Add support for float8_e4m3fnuz and _e5m2fnuz #107586
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107586
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 89852c6 with merge base 79e3833.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@albanD @seemethere – would it be possible to get an exception for the 2000 LOC limit? I'm only 6 lines over, and quite a lot of it is just registering the types.
I've removed my changes to add these and the existing types to TypeInfo.
Thanks for working on this! At a high level this looks good, and we will need to do a more detailed review. Given that the branch cut for v2.1 is only a few days away, the high LOC of this change, and our experience with landing previous "new dtype" PRs, I would expect PyTorch v2.2 to be a reasonable target for eventually getting this in.
c10/util/Float8_e4m3fnuz.h
This looks reasonable. Just curious, what made you choose a LUT over bit shifting?
Also, do we expect the hardware to support an accelerated version of these?
There's no particular reason to use a LUT here, I can change to bit shifting if needed!
Yep, Graphcore's C600 hardware has dedicated instructions which can be used to convert to and from both of these FP8 types.
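For context on the bit-shifting alternative being discussed, a standalone decode routine along those lines might look like the sketch below. This is illustrative only (the function name and structure are mine, not the PR's code), assuming the e4m3fnuz layout of 1 sign / 4 exponent / 3 mantissa bits, exponent bias 8, no infinities, and 0x80 as the only NaN encoding.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Decode one float8_e4m3fnuz byte to a 32-bit float.
float fp8e4m3fnuz_to_fp32(uint8_t x) {
  if (x == 0x80) {
    return std::nanf(""); // the single NaN encoding
  }
  const uint32_t sign = (x >> 7) & 0x1;
  const uint32_t exponent = (x >> 3) & 0xF; // 4 exponent bits, bias 8
  const uint32_t mantissa = x & 0x7;        // 3 mantissa bits

  if (exponent == 0) {
    if (mantissa == 0) {
      return 0.0f; // 0x00 is the only zero (0x80 was handled above)
    }
    // Subnormal: (mantissa / 8) * 2^(1 - bias) = mantissa * 2^-10
    const float value = std::ldexp(static_cast<float>(mantissa), -10);
    return sign ? -value : value;
  }

  // Normal: rebias the exponent from 8 (fp8) to 127 (fp32) and shift
  // the fields into their fp32 positions.
  const uint32_t f_bits =
      (sign << 31) | ((exponent - 8 + 127) << 23) | (mantissa << (23 - 3));
  float result;
  std::memcpy(&result, &f_bits, sizeof(result));
  return result;
}
```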
Did we run any perf benchmarks, or consider the binary size impact, of having this lookup table embedded over and over in every op that needs to convert F8E4M3FNUZ to float?
I will move the LUT into the .cpp file.
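As an aside, the usual way to avoid embedding the table in every translation unit is to declare it extern in the header and define it once in a .cpp file. A rough sketch follows; the namespace and table name are illustrative, not the PR's actual identifiers, and only the first two table entries are shown.

```cpp
// Float8_e4m3fnuz.h (sketch): declare the table so every includer shares one copy.
namespace c10 {
namespace detail {
extern const float e4m3fnuz_to_fp32_table[256];
} // namespace detail
} // namespace c10

// Float8_e4m3fnuz.cpp (sketch): the single definition of the table.
namespace c10 {
namespace detail {
const float e4m3fnuz_to_fp32_table[256] = {
    0.0f,          // 0x00 -> +0
    0.0009765625f, // 0x01 -> smallest subnormal, 2^-10
    // ... remaining 254 entries elided ...
};
} // namespace detail
} // namespace c10
```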
The dtype pieces look reasonable to me! Wondering if you could share some context on which hardware supports these float8 flavors now, and which hardware is expected to support them in the future? cc @malfet, would you be up for a more detailed review of the framework pieces and figuring out the best way to land this?
Thanks for the review @vkuzo! I'll make the change to pull that list of dtypes for the test parameters out into a constant.
Graphcore's current C600 card supports these types at the hardware level. It has instructions in the Tile ISA to perform common operations directly on FP8 data, as well as to convert between types.
I've rebased and added back my TypeInfo changes, as I see that TypeInfo support has been added for the other FP8 types too. This might put the PR over the LOC limit again if the number of lines removed is also included in that count.
if (f_bits >= fnuz_max) {
  // NaN -- sign bit set to 1, rest 0s
  return 0x80;
The table here clips float values greater than FNUZ_MAX to FLT_MAX.
https://onnx.ai/onnx/technical/float8.html#cast
Is there a reason behind using NaNs?
The reason is that the existing casting code for the e5m2 and e4m3fn types is also implemented without any saturation, so I chose to do the same to match that behaviour.
I fear it would lead to having more NaNs when doing inference. Ideally there should be a flag for users to set which behaviour they want.
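If a saturating mode were ever added behind such a flag, the overflow branch quoted above might look roughly like the sketch below. All names here (handle_overflow, fnuz_max, saturate) are illustrative, not the PR's actual code; 0x7F is the largest finite e4m3fnuz encoding (|x| = 240).

```cpp
#include <cstdint>

// Sketch: how the overflow branch of a float -> float8_e4m3fnuz cast could
// optionally saturate instead of producing NaN. `f_bits` is the input float's
// bit pattern with the sign stripped, `fnuz_max` is the smallest bit pattern
// that no longer fits in e4m3fnuz, and `sign` is the extracted sign bit.
uint8_t handle_overflow(uint32_t f_bits, uint32_t fnuz_max, uint8_t sign, bool saturate) {
  if (f_bits >= fnuz_max) {
    if (saturate) {
      // Clamp to the largest finite e4m3fnuz value (|x| = 240, encoded 0x7F),
      // keeping the original sign.
      return static_cast<uint8_t>((sign << 7) | 0x7F);
    }
    return 0x80; // NaN -- sign bit set to 1, rest 0s
  }
  // ... normal conversion path elided in this sketch ...
  return 0;
}
```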
A couple of comments to help with your ROCm CI failure. Hope this helps! :)
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 1 check: pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, linux.8xlarge.nvidia.gpu)
Follow up to #107586. Pull Request resolved: #115214 Approved by: https://github.com/peterbell10, https://github.com/malfet
This PR relates to the feature described in this feature submission. It is based on #104242, which adds similar float8 types.
The new types added in this PR are described in the paper at https://arxiv.org/abs/2206.02915. A brief description and comparison of these types with other float8 types can also be found in the OpenXLA RFC.
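For readers unfamiliar with the fnuz variants: relative to float8_e4m3fn and float8_e5m2, they drop infinities, keep a single NaN encoding (0x80), and raise the exponent bias by one (to 8 and 16 respectively). The small sketch below, derived from those format parameters rather than from the PR itself, shows the resulting maximum finite values.

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // Largest finite values, derived from mantissa width, exponent bias, and
  // which encodings each format reserves for NaN/infinity.
  const double e4m3fn_max   = std::ldexp(1.75, 8);   // 448:   bias 7;  exponent 15 with mantissa 111 is NaN
  const double e4m3fnuz_max = std::ldexp(1.875, 7);  // 240:   bias 8;  only 0x80 is NaN, no infinities
  const double e5m2_max     = std::ldexp(1.75, 15);  // 57344: bias 15; exponent 31 reserved for inf/NaN
  const double e5m2fnuz_max = std::ldexp(1.75, 15);  // 57344: bias 16; the extra exponent extends the small end instead
  std::printf("e4m3fn=%g e4m3fnuz=%g e5m2=%g e5m2fnuz=%g\n",
              e4m3fn_max, e4m3fnuz_max, e5m2_max, e5m2fnuz_max);
  return 0;
}
```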
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @EikanWang @albanD