KEMBAR78
[Distributed] add FP8 support to NaN checker by kwen2501 · Pull Request #135891 · pytorch/pytorch · GitHub
Skip to content

Conversation

@kwen2501
Copy link
Contributor

@kwen2501 kwen2501 commented Sep 12, 2024

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category labels Sep 12, 2024
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135891

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 27e3426 with merge base 0216936 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kwen2501 added a commit that referenced this pull request Sep 12, 2024
ghstack-source-id: 8054946
Pull Request resolved: #135891
@rghosh08
Copy link

PR Reviewer Guide 🔍

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
🔒 No security concerns identified
⚡ Key issues to review

Code Smell
The comment in line 370 suggests that filling values into a FP8 tensor is currently not supported. This might be worth considering for future improvements.

Code Smell
The AT_DISPATCH_FLOATING_TYPES_AND4 macro now includes support for FP8 types. However, the comment above it still mentions support for only Half and BFloat16 types. This might be worth updating for consistency.

Copy link
Contributor

@wconstab wconstab left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, except for missing optimized template for the f8 kernel. Can you determine if its worth adding before deciding to land this as is or optimize further?

Adding support for `torch.float8_e4m3fn` and `torch.float8_e5m2`

cc XilunWu H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
@kwen2501
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 13, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Sep 14, 2024
We check 8 x FP8 simultaneously, at size of 8 bytes.

Pull Request resolved: #135961
Approved by: https://github.com/yifuwang, https://github.com/Skylion007
ghstack dependencies: #135891
pytorchmergebot pushed a commit that referenced this pull request Sep 15, 2024
Add support for Float8_e5m2, following similar algorithm used for Float8_e4m3fn (i.e. overflow check).

Made `HasNanFP8x8` a template so that it is extendable based on dtype.

Pull Request resolved: #136115
Approved by: https://github.com/Skylion007
ghstack dependencies: #135891, #135961
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Adding support for `torch.float8_e4m3fn` and `torch.float8_e5m2`

Pull Request resolved: pytorch#135891
Approved by: https://github.com/wconstab
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
We check 8 x FP8 simultaneously, at size of 8 bytes.

Pull Request resolved: pytorch#135961
Approved by: https://github.com/yifuwang, https://github.com/Skylion007
ghstack dependencies: pytorch#135891
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Add support for Float8_e5m2, following similar algorithm used for Float8_e4m3fn (i.e. overflow check).

Made `HasNanFP8x8` a template so that it is extendable based on dtype.

Pull Request resolved: pytorch#136115
Approved by: https://github.com/Skylion007
ghstack dependencies: pytorch#135891, pytorch#135961
@github-actions github-actions bot deleted the gh/kwen2501/59/head branch October 14, 2024 06:24
KnAwnime pushed a commit to KnAwnime/Biblioteka that referenced this pull request Oct 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants