[optim] Allow torch.float64 scalars for forloop + foreach implementations #115841
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/115841
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 8260136 with merge base 0978482.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
def _get_scalar_dtype(is_fused=None):
    if is_fused:
        return torch.float32
    return torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32
```
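To illustrate how the helper above behaves, here is a rough sketch; `_get_scalar_dtype` is a private utility, so the import path shown is an assumption and may differ across versions:

```python
import torch
from torch.optim.adam import _get_scalar_dtype  # assumed location of the helper above

torch.set_default_dtype(torch.float32)
print(_get_scalar_dtype())               # torch.float32
print(_get_scalar_dtype(is_fused=True))  # torch.float32 (fused keeps a float32 step)

torch.set_default_dtype(torch.float64)
print(_get_scalar_dtype())               # torch.float64 (forloop/foreach follow the default dtype)
print(_get_scalar_dtype(is_fused=True))  # torch.float32 (fused still pins the step to float32)
```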
I don't remember the story for default dtype vs default Tensor type. Do we need to check both here?
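(For reference, a minimal sketch of the two settings in question, using current public APIs; this is illustrative context rather than part of the PR:)

```python
import torch

# The default dtype only controls the floating-point type of newly created tensors.
torch.set_default_dtype(torch.float64)
print(torch.get_default_dtype())  # torch.float64
print(torch.empty(1).dtype)       # torch.float64

# The older default *tensor type* API also implies a default dtype (and can pin a device).
torch.set_default_tensor_type(torch.DoubleTensor)
print(torch.get_default_dtype())  # torch.float64
```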
…implementations" Should allow for uses cases mentioned in #110940 This would allow scalars to also be float64s in the foreach implementation. The fused implementation would still create a float32 step on Adam and AdamW. This PR also does NOT worry about performance and is mainly for enablement. Next steps: - Relax the constraint on fused adam(w) and allow torch.float64 scalars there - Allow _performant_ mixed dtypes in foreach (a bigger project in itself). This PR will conflict with my other PRs, I will figure out a landing order [ghstack-poisoned]
Sounds good!
…implementations" Should allow for uses cases mentioned in #110940 This would allow scalars to also be float64s in the foreach implementation. The fused implementation would still create a float32 step on Adam and AdamW. This PR also does NOT worry about performance and is mainly for enablement. Next steps: - Relax the constraint on fused adam(w) and allow torch.float64 scalars there - Allow _performant_ mixed dtypes in foreach (a bigger project in itself). This PR will conflict with my other PRs, I will figure out a landing order [ghstack-poisoned]
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions or feedback? Please reach out to the PyTorch DevX Team.

Should allow for use cases mentioned in #110940
This would allow scalars to also be float64 in the foreach implementation. The fused implementation would still create a float32 step on Adam and AdamW. This PR also does NOT worry about performance and is mainly for enablement.
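For example, a minimal sketch of the use case this enables (the model and hyperparameters here are illustrative, not from the PR):

```python
import torch

torch.set_default_dtype(torch.float64)

model = torch.nn.Linear(4, 2)  # parameters are created as float64 under the new default
opt = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=True)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
opt.step()  # optimizer scalars (e.g. the step counter) may now be float64 rather than forced to float32
```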
Next steps:
- Relax the constraint on fused adam(w) and allow torch.float64 scalars there
- Allow _performant_ mixed dtypes in foreach (a bigger project in itself)
This PR will conflict with my other PRs; I will figure out a landing order.
Stack from ghstack (oldest at bottom):