Add APIs to separate norm calculation and gradient scaling in nn.utils.clip_grad_norm_
#139662
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139662
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure. As of commit 890bdcd with merge base 6bdbc86, one new job has failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Fixes #139467. Add `nn.utils.get_grad_norm` and `nn.utils.clip_grads_`.

Later renamed `clip_grads_` to `scale_grads_`: there is also `clip_grad_value_`, which does clamping, so `clip_grads_` did not feel representative (see the sketch below).

Refactor `nn.utils.clip_grad_norm_` into `nn.utils.get_grad_norm` and then `nn.utils.scale_grads_`. Since `clip_grad_norm_` calls into these two new ops, the prior testing applies.

[WIP] Going to change `get_grad_norm` to `get_total_norm` so it can be used with any list of Tensors.

Fixes #139467. Refactor `nn.utils.clip_grad_norm_` into `nn.utils.get_total_norm` and then `nn.utils.clip_grads_with_norm_`. `clip_grad_norm_` now calls into these two new ops; `get_total_norm` is generalized (rather than `get_grad_norm`) due to the discussion on the issue from @awgu.

cc @wconstab @zijian-hu
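To make the naming distinction concrete, here is a minimal sketch (plain tensor ops, not the new APIs) contrasting elementwise clamping, which `clip_grad_value_` does, with norm-based rescaling, which the new op performs:

```python
import torch

g = torch.tensor([3.0, -4.0])  # a gradient with L2 norm 5.0

# clip_grad_value_-style clamping: each element is clamped independently,
# which changes the direction of the gradient.
clamped = g.clamp(min=-1.0, max=1.0)  # tensor([ 1., -1.])

# Norm-based clipping: every element is multiplied by the same factor
# max_norm / total_norm, so the direction is preserved.
max_norm = 1.0
total_norm = g.norm(2)                # 5.0
scaled = g * (max_norm / total_norm)  # tensor([ 0.6000, -0.8000])
```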
Looks good to me!
```python
from . import parametrizations, rnn, stateless
from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_
from .clip_grad import (
    _clip_grads_with_norm_ as clip_grads_with_norm_,
```
Just curious: is there a particular reason the APIs are prefixed with `_` if they are going to be public anyway?
This is to avoid the public API docs test complaining that `torch.nn.utils.clip_grad.{foo}` is not documented (when it is already documented and publicly exposed as `torch.nn.utils.{foo}`).
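Concretely, the re-export pattern from the diff above looks like this (a sketch of `torch/nn/utils/__init__.py`; the `_get_total_norm` line is assumed to follow the same pattern as the alias visible in the diff):

```python
# torch/nn/utils/__init__.py (sketch)
# The implementations live in clip_grad.py under underscore-prefixed names,
# so the docs test does not flag torch.nn.utils.clip_grad.<name> as an
# undocumented public API; the aliases expose the documented public names.
from .clip_grad import (
    _clip_grads_with_norm_ as clip_grads_with_norm_,
    _get_total_norm as get_total_norm,
)
```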
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed: linux-binary-manywheel / manywheel-py3_9-cuda11_8-test / test. Details for Dev Infra team: raised by workflow job.
The failure `Could not find a version that satisfies the requirement jinja2 (from torch) (from versions: none)` in linux-binary-manywheel / manywheel-py3_9-cuda11_8-test / test (gh) is unrelated.
@pytorchbot merge -i
Merge started: Your change will be merged while ignoring the following 1 check: linux-binary-manywheel / manywheel-py3_9-cuda11_8-test / test. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Add APIs to separate norm calculation and gradient scaling in `nn.utils.clip_grad_norm_` (pytorch#139662)

Fixes pytorch#139467

Refactor `nn.utils.clip_grad_norm_` into `nn.utils.get_total_norm` and then `nn.utils.clip_grads_with_norm_`. `clip_grad_norm_` now calls into these two new ops; `get_total_norm` is generalized (rather than `get_grad_norm`) due to the discussion on the issue from @awgu.

Pull Request resolved: pytorch#139662
Approved by: https://github.com/H-Huang
Fixes #139467

Refactor `nn.utils.clip_grad_norm_` into `nn.utils.get_total_norm` and then `nn.utils.clip_grads_with_norm_`. `clip_grad_norm_` now calls into these two new ops; `get_total_norm` is generalized (rather than `get_grad_norm`) due to the discussion on the issue from @awgu.

Stack from ghstack (oldest at bottom):

- Add APIs to separate norm calculation and gradient scaling in `nn.utils.clip_grad_norm_` #139662

cc @wconstab @zijian-hu
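As a usage sketch of the split (assuming the signatures described in this PR: `get_total_norm(tensors, norm_type=2.0, ...)` and `clip_grads_with_norm_(parameters, max_norm, total_norm)`):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 4)
model(torch.randn(2, 8)).sum().backward()

grads = [p.grad for p in model.parameters() if p.grad is not None]

# Step 1: compute the total norm. get_total_norm takes any list of
# Tensors, not just parameter gradients (hence the generalized name).
total_norm = torch.nn.utils.get_total_norm(grads, norm_type=2.0)

# Step 2: scale the gradients in place so their total norm is at most
# max_norm (effectively a no-op when total_norm <= max_norm).
torch.nn.utils.clip_grads_with_norm_(
    model.parameters(), max_norm=1.0, total_norm=total_norm
)

# Together, the two steps match the existing one-shot call:
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```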