ln + amax + fp8 quant inductor enablement #109301
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109301
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 0c1bfbc with merge base 59592ce. This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/inductor/test_fp8.py
```python
# Utility functions are copied from
# https://github.com/pytorch-labs/float8_playground/blob/main/float8_playground/float8_utils.py.
```
nit: maybe we can remove the link, since this repo isn't public yet? I think it's ok to not cite it in the landed version
```python
def ln_fp8(x: Tensor, scale: float, amax_buffer: Tensor):
    x = torch.nn.functional.layer_norm(x, [hidden_size], weight=None, bias=None, eps=1e-05)
    amax_buffer.fill_(torch.max(torch.abs(x)))
```
if it helps, this function can be moved to after the pointwise stuff in user code. The way we originally wrote this code isn't super intuitive
I feel it doesn't matter. This requires a reduction kernel anyways.
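For context, here is a minimal, self-contained sketch of the layer_norm → amax → fp8-quant pattern under discussion. It is not the exact test code from this PR: the value of `hidden_size`, the scale value, the use of `float8_e5m2`, and the final scale-and-cast lines are assumptions, and it requires a CUDA device with fp8 support in the installed PyTorch build.

```python
import torch
from torch import Tensor

hidden_size = 4096  # assumed size of the normalized dimension

def ln_fp8(x: Tensor, scale: float, amax_buffer: Tensor) -> Tensor:
    # layer norm over the last dimension, no affine parameters
    x = torch.nn.functional.layer_norm(
        x, [hidden_size], weight=None, bias=None, eps=1e-05
    )
    # record the absolute max for delayed scaling -- this is the reduction
    amax_buffer.fill_(torch.max(torch.abs(x)))
    # scale and cast to fp8 -- the pointwise part (assumed dtype: e5m2)
    return (x * scale).to(torch.float8_e5m2)

compiled_ln_fp8 = torch.compile(ln_fp8)
x = torch.randn(16, hidden_size, device="cuda", dtype=torch.float16)
amax_buffer = torch.zeros((), device="cuda", dtype=torch.float16)
out = compiled_ln_fp8(x, 1.0, amax_buffer)
```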
```python
    and len(self.numels) == 2
    and self.numels[-1] >= 256
)
# self.no_x_dim = (
```
Note: the no_x_dim-related logic needs to be removed so that XBLOCK is configurable.
However, according to @jansel there could be perf regressions if XBLOCK is added.
torch/_inductor/codegen/triton.py
```diff
     return False
 threshold = {
-    ReductionHint.INNER: 1024,
+    ReductionHint.INNER: 256,
```
Note: the threshold for persistent kernels needs to be decreased so that, for cases like max(x).to(fp8):
- the XBLOCK size fulfills the fp8 min_element_per_thread requirement. For example, for fp8_e5m2 the min_element_per_thread is 4, so XBLOCK size = 4 * NUM_WARPS (which is RBLOCK size / 128 by default) * 32 (warp_size);
- XBLOCK * RBLOCK < 131072, which is Triton's maximum tensor numel.
So XBLOCK * RBLOCK = RBLOCK_SIZE * RBLOCK_SIZE < 131072, and the max RBLOCK_SIZE is 256.
There are lots of things that could potentially be tuned, e.g.:
- We may only update this rule when an fp8 conversion is followed by a reduction. This would need some code refactoring, as there seems to be no way to get this information without actually running ops.to_dtype().
- We could also reduce NUM_WARPS. However, this would reduce parallelism, which doesn't seem ideal.
- For a normal reduction kernel (i.e. not a persistent reduction kernel), we may also want to decrease the split-reduction threshold so that each block handles a smaller number of elements.
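As a sanity check on the numbers in the comment above, a small back-of-the-envelope script is sketched here; the constants are the assumed values quoted in the comment, not read from the inductor source.

```python
# Assumed constants, taken from the review comment above.
TRITON_MAX_TENSOR_NUMEL = 131072  # Triton's maximum tensor numel
WARP_SIZE = 32
MIN_ELEM_PER_THREAD = 4           # e.g. fp8_e5m2

def xblock_for(rblock: int) -> int:
    num_warps = rblock // 128     # assumed default: RBLOCK / 128
    return MIN_ELEM_PER_THREAD * num_warps * WARP_SIZE

for rblock in (128, 256, 512, 1024):
    xblock = xblock_for(rblock)
    fits = xblock * rblock < TRITON_MAX_TENSOR_NUMEL
    print(f"RBLOCK={rblock:5d} XBLOCK={xblock:5d} numel={xblock * rblock:7d} fits={fits}")

# Under these assumptions XBLOCK == RBLOCK, so XBLOCK * RBLOCK = RBLOCK**2,
# and the largest power-of-two RBLOCK with RBLOCK**2 < 131072 is 256.
```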
Hi @jansel @eellison @Chillee please help review this PR, thanks!
```python
self.min_elem_per_thread_reduction_block = 0
self.min_elem_per_thread_non_reduction_block = 0
```
Perhaps make this a list with the same length as len(groups)?
Today we have XBLOCK, YBLOCK, RBLOCK, and a (disabled by config) ZBLOCK.
You might also want to test this on a pointwise kernel that gets 2D tiling. You can trigger that codepath with a pointwise kernel on transposed inputs.
In the future we may have tiled reduction kernels as well.
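A minimal sketch of the kind of test case suggested here: a pointwise op that mixes a contiguous tensor with a transposed one, which is the usual way to get inductor to pick a 2D-tiled pointwise kernel. The device and sizes are assumptions, not taken from this PR.

```python
import torch

def pointwise_transposed(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a is contiguous, b.t() is not: the two iteration orders disagree,
    # which typically makes inductor emit a 2D-tiled pointwise kernel.
    return a + b.t()

compiled = torch.compile(pointwise_transposed)
a = torch.randn(1024, 2048, device="cuda")
b = torch.randn(2048, 1024, device="cuda")
out = compiled(a, b)
```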
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov