Basic fp8 support in Inductor #109168
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109168.
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 76b4cf8 with merge base 6b7b9c7.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Would you mind separating out the formatting PRs? It makes this more difficult to review. If you need help configuring VSCode or something else, let me know.
torch/_inductor/codegen/triton.py (Outdated)

```python
def _get_min_elements_per_thread(
    src_dtype: torch.dtype, dst_dtype: torch.dtype
) -> int:
    # fp8 data type conversions has min_elements_per_thread requirements.
```
Is there an explanation for why this is the case? 🤔
Check the intrinsics in this file (also documented in the comment below): https://github.com/openai/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10. It uses b32 and e4m3x2.
Err, I saw the link; I'm just trying to get a better understanding of why dtype conversions would have min_elements_per_thread requirements.
I think it's because PTX cvt only provides fp8x2 intrinsics, which are what Triton uses: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt.
However, I also don't understand why the Triton implementation for fp8_e5m2 requires 4 elements per thread. I asked in the Slack channel but didn't get an answer ^^.
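To make the requirement concrete, here is a minimal sketch of how a per-conversion minimum could be derived. The exact values and dtype coverage are assumptions based on the discussion above, not the PR's actual logic:

```python
import torch

FP8_DTYPES = {torch.float8_e4m3fn, torch.float8_e5m2}

def sketch_min_elements_per_thread(src_dtype: torch.dtype,
                                    dst_dtype: torch.dtype) -> int:
    # PTX cvt only exposes packed fp8 conversions (e.g. e4m3x2), so fp8
    # conversions need each thread to process at least a pair of elements;
    # Triton's e5m2 path is written against 4 elements per thread.
    if src_dtype not in FP8_DTYPES and dst_dtype not in FP8_DTYPES:
        return 0  # non-fp8 conversions have no special requirement
    if torch.float8_e5m2 in (src_dtype, dst_dtype):
        return 4
    return 2

print(sketch_min_elements_per_thread(torch.float32, torch.float8_e5m2))  # 4
```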
torch/_inductor/codegen/triton.py (Outdated)

```diff
 @staticmethod
-def to_dtype(x, dtype: torch.dtype):
+def to_dtype(x, dtype: torch.dtype, src_dtype: torch.dtype = None):
```
Can we not get src_dtype at this stage? Why do we need to plumb src_dtype through the entire lowering?
This is because only fp8 conversions have a special requirement on min_elements_per_thread. Any suggestions on a better way to implement this logic?
Yeah, I was just a bit surprised that we don't have this information elsewhere during this lowering. I was wondering whether we could just get the dtype directly from x, but I think, in general, we don't actually keep dtype information around at this stage of the lowering.
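For illustration, a minimal sketch of what the plumbed signature could look like, assuming the handler only needs src_dtype for the fp8 case; `record_min_elements_per_thread` is a hypothetical hook, not an existing Inductor function:

```python
import torch

FP8_DTYPES = {torch.float8_e4m3fn, torch.float8_e5m2}

def record_min_elements_per_thread(src_dtype, dst_dtype):
    # Hypothetical hook: in real codegen this would raise the requirement
    # on the Triton kernel currently being generated.
    print(f"fp8 conversion {src_dtype} -> {dst_dtype}: bump min elems/thread")

def to_dtype(x: str, dtype: torch.dtype, src_dtype: torch.dtype = None) -> str:
    # x stands in for the Triton source string built by the op handler; the
    # extra src_dtype defaults to None and is only consulted for fp8.
    if src_dtype in FP8_DTYPES or dtype in FP8_DTYPES:
        record_min_elements_per_thread(src_dtype, dtype)
    return f"{x}.to({dtype})"  # illustrative; real codegen emits a Triton cast
```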
Yeah, will do. I was testing on another H100 host, so maybe I did something wrong in the linter configuration...
I think this looks good modulo linting, but I think Horace + Elias would be better reviewers for this.
Do we need to update the reduction heuristics as well? What about Triton templates?
torch/_inductor/triton_heuristics.py (Outdated)

```python
triton_config(size_hints, bs, 1),
triton_config(size_hints, 1, bs),
triton_config(
    size_hints, 32, 32, min_elements_per_thread=min_elements_per_thread
```
nit: doesn't really matter, but maybe slightly less verbose as min_elem_per_thread
```python
if len(size_hints) == 1:
    if disable_pointwise_autotuning() and not (
        config.max_autotune or config.max_autotune_pointwise
    ):
```
For fewer changes, you could consider `triton_config = functools.partial(triton_config, min_elements_per_thread=min_elements_per_thread)`.
I have to rename `triton_config` to avoid lint errors.
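For reference, a runnable sketch of the functools.partial suggestion with the partial bound to a new name so it does not shadow the module-level helper (the lint issue mentioned above); the name `triton_config_with_settings` and the stub below are illustrative, not the merged code:

```python
import functools

def triton_config(size_hints, x, y, min_elements_per_thread=0):
    # Stand-in for the real helper in torch/_inductor/triton_heuristics.py.
    return {"size_hints": size_hints, "x": x, "y": y,
            "min_elements_per_thread": min_elements_per_thread}

min_elements_per_thread = 2  # e.g. required by an fp8 conversion in the kernel

# Bind the requirement once instead of passing it at every call site.
triton_config_with_settings = functools.partial(
    triton_config, min_elements_per_thread=min_elements_per_thread
)

configs = [
    triton_config_with_settings([1024], 32, 1),
    triton_config_with_settings([1024], 1, 32),
]
```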
torch/_prims_common/__init__.py (Outdated)

```diff
 _integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
 _low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32)
-_float_dtypes = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
+_float_dtypes = (
```
Can we replace this with calling dtype.is_floating_point?
Good point!
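For reference, a quick check of the suggestion; assuming a PyTorch build that exposes the fp8 dtypes, `dtype.is_floating_point` already covers them, so a hand-maintained `_float_dtypes` tuple is not needed for this check:

```python
import torch

assert torch.float8_e4m3fn.is_floating_point
assert torch.float8_e5m2.is_floating_point
assert torch.float32.is_floating_point
assert not torch.int8.is_floating_point
```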
test/inductor/test_fp8.py (Outdated)

```python
x_shape = (16, 16, 16)

with self.assertRaises(Exception):
```
Would it make sense to be more specific here re. the Exception type (and maybe text)? And below.
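A sketch of what a tighter assertion could look like; the exception type, the message fragment, and `unsupported_fp8_fn` are assumptions for illustration, not what the test actually raises:

```python
# A possible tightening of the bare assertRaises(Exception) above; the
# exception type (RuntimeError) and message fragment ("fp8") are guesses,
# and unsupported_fp8_fn stands in for whatever op the test exercises.
with self.assertRaisesRegex(RuntimeError, "fp8"):
    torch.compile(unsupported_fp8_fn)(x)
```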
Thanks @eellison @aakhundov!
@eellison Changes for the reduction heuristics will be in the next PR. I'm still trying to figure out how to do it correctly for the different reduction types (persistent_reduction vs. normal reduction, real reductions like max vs. fused reduction + pointwise like layer_norm).
Triton template changes will come last. For now we'll just rely on cuBLAS for GEMMs since Triton's H100 perf is not ideal.
I'll merge this PR first to unblock fp8-related testing. Meanwhile, I'm working on adding scalar fp8 conversion support in the Triton repo, and will revisit the Pointwise TritonHeuristics change after that fix.
@pytorchbot merge
Merge failed. Reason: This PR needs a `release notes:` label (or, if it is not user facing, the `topic: not user facing` label). To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. (Details for Dev Infra team: raised by workflow job.)
@pytorchbot label "topic: not user facing"
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Add basic fp8 support in Inductor, including:
- Fix fp8 Triton codegen issues;
- Add a min_elements_per_thread requirement for fp8-related dtype conversions. More details on the Triton implementation can be found at https://github.com/openai/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10.

Note that the current implementation only works for Pointwise. Will create follow-up PRs for Reduction. A minimal usage sketch follows at the end of this description.
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov
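For context, a minimal example of the kind of pointwise fp8 program this PR is meant to handle; it assumes an fp8-capable GPU (e.g. H100) and a PyTorch build that exposes the fp8 dtypes, and is an illustrative sketch rather than the PR's own test code:

```python
import torch

def fp8_cast(x: torch.Tensor) -> torch.Tensor:
    # Pointwise round-trip through fp8 e4m3; pointwise ops are the only
    # case covered by this PR (reductions come in follow-up PRs).
    return x.to(torch.float8_e4m3fn).to(torch.float32)

compiled = torch.compile(fp8_cast)
x = torch.randn(16, 16, 16, device="cuda")
y = compiled(x)
print(y.shape)
```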