ln + amax + fp8 quant inductor enablement by ipiszy · Pull Request #109301 · pytorch/pytorch · GitHub

Conversation

ipiszy (Contributor) commented Sep 14, 2023

pytorch-bot (bot) commented Sep 14, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109301

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0c1bfbc with merge base 59592ce:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ipiszy added a commit that referenced this pull request Sep 14, 2023
ghstack-source-id: 3e1ad9f
Pull Request resolved: #109301


# Utility functions are copied from
# https://github.com/pytorch-labs/float8_playground/blob/main/float8_playground/float8_utils.py.
Contributor:

nit: maybe we can remove the link, since this repo isn't public yet? I think it's ok to not cite it in the landed version


def ln_fp8(x: Tensor, scale: float, amax_buffer: Tensor):
    x = torch.nn.functional.layer_norm(x, [hidden_size], weight=None, bias=None, eps=1e-05)
    amax_buffer.fill_(torch.max(torch.abs(x)))
Contributor:

If it helps, this function can be moved to after the pointwise stuff in user code. The way we originally wrote this code isn't super intuitive.

Contributor Author (ipiszy):

I feel it doesn't matter. This requires a reduction kernel anyway.
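
For context, a rough self-contained sketch of the ln + amax + fp8 quant pattern this PR targets; the scaling/cast step, the fp8 dtype choice, and the hidden_size value below are illustrative assumptions rather than the exact code under review:

    import torch
    from torch import Tensor

    hidden_size = 4096  # hypothetical hidden dimension

    def ln_amax_fp8(x: Tensor, scale: float, amax_buffer: Tensor) -> Tensor:
        # layer norm over the last dimension, no affine parameters
        x = torch.nn.functional.layer_norm(x, [hidden_size], weight=None, bias=None, eps=1e-05)
        # record the abs-max for delayed-scaling fp8 recipes; this reduction is
        # why a reduction kernel is needed no matter where the call sits
        amax_buffer.fill_(torch.max(torch.abs(x)))
        # scale and cast to fp8; the goal is for Inductor to fuse this with the
        # layer norm and the amax reduction above
        return (x * scale).to(torch.float8_e5m2)

In practice this would be wrapped in torch.compile so that Inductor generates the fused kernels being discussed here.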

ipiszy added a commit that referenced this pull request Sep 19, 2023
ghstack-source-id: 184687f
Pull Request resolved: #109301
ipiszy added a commit that referenced this pull request Sep 19, 2023
ghstack-source-id: c582881
Pull Request resolved: #109301
and len(self.numels) == 2
and self.numels[-1] >= 256
)
# self.no_x_dim = (
Contributor Author (ipiszy):

Note: the no_x_dim-related logic needs to be removed so that XBLOCK is configurable.
However, according to @jansel there could be perf regressions if XBLOCK is added.

  return False
  threshold = {
-     ReductionHint.INNER: 1024,
+     ReductionHint.INNER: 256,
Contributor Author (ipiszy):

Note: the threshold for persistent kernels needs to be decreased to make sure that, for cases like max(x).to(fp8):

  1. The XBLOCK size fulfills the fp8 min_element_per_thread requirement. E.g. for fp8_e5m2 the min_element_per_thread is 4, so XBLOCK size = 4 * NUM_WARPS (which is RBLOCK size / 128 by default) * 32 (warp_size).
  2. XBLOCK * RBLOCK < 131072, which is Triton's maximum tensor numel. So XBLOCK * RBLOCK = RBLOCK_SIZE * RBLOCK_SIZE < 131072, and the max RBLOCK_SIZE is 256 (see the arithmetic sketch below).

There are lots of things that could potentially be tuned, e.g.:

  1. We may only want to apply this rule when an fp8 conversion is followed by a reduction. This would need some code refactoring, as there seems to be no way to get this information without actually running ops.to_dtype().
  2. We could also reduce NUM_WARPS. However, that would reduce parallelism, which doesn't seem ideal.
  3. For a normal reduction kernel (i.e. not a persistent reduction kernel), we may also want to decrease the split-reduction threshold so that each block handles fewer elements.
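
As a back-of-the-envelope check of the arithmetic above (constants copied from this comment; the warp size and NUM_WARPS heuristic are as described here, not read out of Inductor):

    warp_size = 32
    min_elem_per_thread = 4        # fp8_e5m2 requirement mentioned above
    triton_max_numel = 131072      # Triton's maximum tensor numel

    def xblock_for(rblock: int) -> int:
        num_warps = rblock // 128  # default NUM_WARPS heuristic from this comment
        return min_elem_per_thread * num_warps * warp_size

    for rblock in (128, 256, 512):
        xblock = xblock_for(rblock)  # equals rblock with these constants
        print(rblock, xblock, xblock * rblock < triton_max_numel)
    # only RBLOCK <= 256 keeps XBLOCK * RBLOCK below 131072, hence the 256 threshold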


ipiszy marked this pull request as ready for review September 21, 2023 02:23
ipiszy (Contributor Author) commented Sep 21, 2023

Hi @jansel @eellison @Chillee, please help review this PR, thanks!
Meanwhile, I filed an issue on the Triton GitHub to collect more feedback: triton-lang/triton#2354.

Comment on lines +826 to +827
self.min_elem_per_thread_reduction_block = 0
self.min_elem_per_thread_non_reduction_block = 0
jansel (Contributor) commented Sep 21, 2023:

Perhaps make this a list with the same length as len(groups)?

Today we have XBLOCK, YBLOCK, RBLOCK, and a ZBLOCK (disabled by config).

You might also want to test this on a pointwise kernel that gets 2D tiling. You can trigger that codepath by doing a pointwise kernel with transposed inputs.

In the future we may have tiled reduction kernels as well.
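
A minimal sketch of the kind of test being suggested: a purely pointwise op where one input has transposed strides, which is what pushes Inductor down the 2D-tiled pointwise codepath (the shapes and the op itself are illustrative assumptions):

    import torch

    @torch.compile
    def pointwise_mul_add(a, b):
        # purely elementwise, so Inductor emits a pointwise kernel
        return a * b + 1.0

    a = torch.randn(1024, 2048, device="cuda")
    b = torch.randn(2048, 1024, device="cuda").t()  # transposed, non-contiguous input
    out = pointwise_mul_add(a, b)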

ipiszy mentioned this pull request Oct 10, 2023
ipiszy closed this Oct 20, 2023
facebook-github-bot deleted the gh/ipiszy@gmail.com/8/head branch November 19, 2023 15:27