[optim] Allow torch.float64 scalars for forloop + foreach implementations by janeyx99 · Pull Request #115841 · pytorch/pytorch · GitHub

Conversation

@janeyx99 (Contributor) commented on Dec 14, 2023:

This should allow for the use cases mentioned in #110940.

This would allow scalars to also be float64 in the foreach implementation. The fused implementation would still create a float32 step on Adam and AdamW. This PR also does NOT address performance and is mainly for enablement.

Next steps:

  • Relax the constraint on fused adam(w) and allow torch.float64 scalars there
  • Allow performant mixed dtypes in foreach (a bigger project in itself).

This PR will conflict with my other PRs; I will figure out a landing order.
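As a rough illustration of the enablement (a sketch assuming a PyTorch build that includes this change, not a definitive reproduction), running a foreach optimizer under a float64 default dtype keeps the scalar step state in float64 instead of always creating it as float32:

```python
import torch

# Assumes a PyTorch build that includes this PR; before it, the step
# state was created as float32 regardless of the default dtype.
torch.set_default_dtype(torch.float64)

p = torch.zeros(3, requires_grad=True)        # float64 under the new default
opt = torch.optim.Adam([p], lr=1e-3, foreach=True)

p.sum().backward()
opt.step()

# The scalar step state now follows the default dtype.
print(opt.state[p]["step"].dtype)

torch.set_default_dtype(torch.float32)        # restore the usual default
```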

Stack from ghstack (oldest at bottom):

@janeyx99 janeyx99 requested a review from albanD as a code owner December 14, 2023 17:32
@pytorch-bot pytorch-bot bot added the release notes: foreach_frontend release notes category label Dec 14, 2023
@pytorch-bot (bot) commented on Dec 14, 2023:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/115841

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8260136 with merge base 0978482:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

janeyx99 added a commit that referenced this pull request Dec 14, 2023
@janeyx99 janeyx99 changed the title Allow torch.float64 scalars for foreach implementations [optim] Allow torch.float64 scalars for forloop + foreach implementations Dec 14, 2023
def _get_scalar_dtype(is_fused=None):
    if is_fused:
        return torch.float32
    return torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32
A Collaborator commented:

I don't remember the story for default dtype vs default Tensor type. Do we need to check both here?

janeyx99 (Contributor, Author) replied:

I don't think we do; get_default_dtype seems to work as expected:
[screenshot omitted]
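The inline screenshot does not survive the extraction; a quick standalone check along these lines (a sketch, not the screenshot's actual content) shows that torch.get_default_dtype tracks torch.set_default_dtype, which is all _get_scalar_dtype consults:

```python
import torch

# torch.get_default_dtype() reflects torch.set_default_dtype(), so no
# separate default-tensor-type check is needed in _get_scalar_dtype.
torch.set_default_dtype(torch.float64)
assert torch.get_default_dtype() == torch.float64

torch.set_default_dtype(torch.float32)   # restore the usual default
assert torch.get_default_dtype() == torch.float32
```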

janeyx99 added a commit that referenced this pull request Dec 26, 2023
@albanD (Collaborator) left a comment:
Sounds good!

janeyx99 added a commit that referenced this pull request Dec 27, 2023
@janeyx99 (Contributor, Author) commented:

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 27, 2023
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced debugging: check the merge workflow status.

@facebook-github-bot facebook-github-bot deleted the gh/janeyx99/117/head branch December 30, 2023 15:20

Labels

ciflow/trunk · Merged · release notes: foreach_frontend · release notes: optim · topic: improvements

3 participants