Basic fp8 support in Inductor by ipiszy · Pull Request #109168 · pytorch/pytorch

Conversation

@ipiszy
Contributor

@ipiszy ipiszy commented Sep 13, 2023

Add basic fp8 support in Inductor, including:

* Fix fp8 Triton codegen issues;
* Add a min_elements_per_thread requirement for fp8-related dtype conversions. More details on the Triton implementation can be found at https://github.com/openai/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10.

Note that the current implementation only works for Pointwise; follow-up PRs will cover Reduction.
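For illustration, a minimal sketch (assuming a recent PyTorch build that exposes the fp8 dtypes and an fp8-capable GPU such as H100) of the kind of pointwise fp8 round-trip this PR lets Inductor compile into a Triton kernel:

import torch

def fp8_cast(x):
    # Pointwise op followed by a cast to fp8 and back to fp32.
    y = (x * 2.0).to(torch.float8_e4m3fn)
    return y.to(torch.float32)

compiled = torch.compile(fp8_cast)
out = compiled(torch.randn(16, 16, device="cuda"))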

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Sep 13, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109168

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 76b4cf8 with merge base 6b7b9c7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ipiszy added a commit that referenced this pull request Sep 13, 2023
ghstack-source-id: 3a3e05e
Pull Request resolved: #109168
Contributor

@eellison eellison left a comment

Would you mind separating out the formatting PRs? It makes this one more difficult to review. If you need help configuring VS Code or something else, let me know.

def _get_min_elements_per_thread(
    src_dtype: torch.dtype, dst_dtype: torch.dtype
) -> int:
    # fp8 data type conversions have a min_elements_per_thread requirement.
Collaborator

Is there an explanation for why this is the case? 🤔

Contributor Author

Check the intrinsics in this file (also documented in the comment below): https://github.com/openai/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10. It uses b32 and e4m3x2.

Collaborator

Err, I saw the link; I'm just trying to get a better understanding of why dtype conversions would have min_elements_per_thread requirements.

Contributor Author

I think it's because ptx cvt only provides fp8x2 intrinsics, which are what Triton uses: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt.
However, I also don't understand why the Triton implementation for fp8_e5m2 requires 4 elements per thread. I asked in the Slack channel but didn't get an answer ^^.
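To make the constraint concrete, a rough sketch (not the PR's exact implementation; the 2/4 values are illustrative, following the e4m3x2 packing and the fp8_e5m2 behavior discussed above) of what such a helper could look like:

def _get_min_elements_per_thread(
    src_dtype: torch.dtype, dst_dtype: torch.dtype
) -> int:
    # Only fp8 conversions carry a requirement; everything else is unconstrained.
    fp8_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
    if src_dtype not in fp8_dtypes and dst_dtype not in fp8_dtypes:
        return 0
    # Illustrative values: e4m3 uses the packed e4m3x2 intrinsic (2 per thread),
    # while the e5m2 path in Triton appears to need 4 per thread.
    if torch.float8_e5m2 in (src_dtype, dst_dtype):
        return 4
    return 2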


@staticmethod
def to_dtype(x, dtype: torch.dtype):
def to_dtype(x, dtype: torch.dtype, src_dtype: torch.dtype = None):
Collaborator

Can we not get src_dtype at this stage? Why do we need to plumb src_dtype through the entire lowering?

Contributor Author

This is because only fp8 conversions have a special requirement on min_elements_per_thread. Any suggestions on a better way to implement this logic?

Collaborator

Yeah, I was just a bit surprised that we don't have this information elsewhere during this lowering. I was wondering whether we could just get the dtype directly from x, but I think in general we don't actually keep dtype information around at this stage of the lowering.
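As a rough sketch of the plumbing being discussed (is_fp8_conversion and record_min_elements_per_thread are hypothetical helper names, not the PR's actual ones; the return expression mirrors the existing string-emitting style of Inductor's Triton codegen), the extra argument might be consumed like this:

@staticmethod
def to_dtype(x, dtype: torch.dtype, src_dtype: torch.dtype = None):
    # Only fp8 conversions need the hint, so callers pass src_dtype explicitly;
    # non-fp8 conversions simply ignore it.
    if src_dtype is not None and is_fp8_conversion(src_dtype, dtype):
        record_min_elements_per_thread(src_dtype, dtype)
    return f"{x}.to({triton_compute_type(dtype)})"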

@ipiszy
Contributor Author

ipiszy commented Sep 13, 2023

> Would you mind separating out the formatting PRs? It makes this one more difficult to review. If you need help configuring VS Code or something else, let me know.

Yeah, will do. I was testing on another H100 host, so maybe I did something wrong in the linter configuration...

ipiszy added a commit that referenced this pull request Sep 14, 2023
ghstack-source-id: 72bad6f
Pull Request resolved: #109168
Contributor

@drisspg drisspg left a comment

I think this looks good modulo linting, but Horace + Elias would be better reviewers for this.

Contributor

@eellison eellison left a comment

Do we need to update the reduction heuristics as well? What about Triton templates?

triton_config(size_hints, bs, 1),
triton_config(size_hints, 1, bs),
triton_config(
    size_hints, 32, 32, min_elements_per_thread=min_elements_per_thread
Contributor

nit: doesn't really matter, but maybe slightly less verbose as min_elem_per_thread

if len(size_hints) == 1:
    if disable_pointwise_autotuning() and not (
        config.max_autotune or config.max_autotune_pointwise
    ):
Contributor

For fewer changes you could consider triton_config = functools.partial(triton_config, min_elements_per_thread=min_elements_per_thread).

Contributor Author

I have to rename triton_config to avoid lint errors...
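A small sketch of that suggestion (binding the partial under a new name, as mentioned above, to avoid shadowing/lint issues; triton_config, size_hints, bs, and min_elements_per_thread are assumed in scope as in the snippet above):

import functools

# Bind the extra argument once instead of threading it through every call site.
make_config = functools.partial(
    triton_config, min_elements_per_thread=min_elements_per_thread
)
configs = [
    make_config(size_hints, bs, 1),
    make_config(size_hints, 1, bs),
]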

_integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32)
_float_dtypes = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
_float_dtypes = (
Contributor

Can we replace this with calling dtype.is_floating_point?

Contributor Author

Good point!
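For reference, the suggestion boils down to something like this (a sketch; is_float is a hypothetical name):

def is_float(dtype: torch.dtype) -> bool:
    # Covers fp8, fp16, bf16, fp32, and fp64 without a hard-coded tuple.
    return dtype.is_floating_point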


x_shape = (16, 16, 16)

with self.assertRaises(Exception):
Contributor

Would it make sense to be more specific here re. the Exception type (and maybe text)? And below.
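A sketch of tightening that assertion (the RuntimeError type, the "fp8" message fragment, and compiled_fp8_fn are illustrative assumptions, not taken from the test):

with self.assertRaisesRegex(RuntimeError, "fp8"):
    compiled_fp8_fn(torch.randn(x_shape, device="cuda"))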

@drisspg drisspg self-requested a review September 19, 2023 01:23
Contributor Author

@ipiszy ipiszy left a comment

Thanks @eellison @aakhundov!

@eellison Changes for reduction heuristics will be in the next PR. I'm still trying to figure out how to do it correctly for different reduction types (persistent_reduction and normal reduction, real reduction (like max) and fused_reduction_pointwise (like layer_norm)).

Triton template changes will come last. For now we'll just rely on cuBLAS for GEMMs, since Triton perf on H100 is not ideal.

@ipiszy
Contributor Author

ipiszy commented Sep 23, 2023

I'll merge this PR first to unblock fp8-related testing. Meanwhile, I'm working on adding scalar fp8 conversion support in the Triton repo, and will revisit the Pointwise TritonHeuristics change after that fix.

@ipiszy
Contributor Author

ipiszy commented Sep 23, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Sep 23, 2023
@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job.

@ipiszy
Contributor Author

ipiszy commented Sep 23, 2023

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing label (topic category) on Sep 23, 2023
@ipiszy
Contributor Author

ipiszy commented Sep 23, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/ipiszy@gmail.com/7/head branch September 26, 2023 14:25