[pipelining] add type checking to _backward functions #140019
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140019
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit c31dde6 with merge base 0a0915f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
fix #139405 cc awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
handles.append(handle)

# Stage 0 inputs do not require grads? Should we skip in that case?
if all(tensor.requires_grad for tensor in input_values):
how come this if condition can be removed now?
we no longer call stage_backward_input for the first stage
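The reply above moves the decision to the caller: if the first stage never reaches the input-gradient step, a per-tensor guard inside that step becomes redundant. A minimal sketch of that idea in plain Python, where `backward_input_step` and `run_backward` are hypothetical stand-ins and not the actual pipelining code:

```python
from typing import List, Tuple


def backward_input_step(stage_index: int, calls: List[int]) -> Tuple[str, ...]:
    # Hypothetical stand-in for the input-gradient step; it only records that it ran.
    calls.append(stage_index)
    return (f"dinput_stage_{stage_index}",)


def run_backward(num_stages: int) -> List[int]:
    calls: List[int] = []
    # Walk the stages in reverse order, as a backward pass would.
    for stage_index in reversed(range(num_stages)):
        if stage_index == 0:
            # The first stage has no previous stage to receive input grads,
            # so the caller skips the step instead of guarding inside it.
            continue
        backward_input_step(stage_index, calls)
    return calls


calls = run_backward(3)
```

With the skip at the call site, the step itself no longer needs to inspect `requires_grad` on its inputs to decide whether to run.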
  weight.grad += dw
  # return grads in the original order weights were provided in
- return weight_grads
+ return tuple(weight_grads)
these changes look pretty innocuous to me, but can you convince me that this change doesn't add any restriction or limitation to the user code?
I can't think of any restrictions unless the user was explicitly checking that the type is list. In terms of consistency, the autograd.grad() API (https://pytorch.org/docs/stable/generated/torch.autograd.grad.html#torch-autograd-grad) also returns a tuple, so this change matches it better.
lgtm, thanks!
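To make the list-versus-tuple point in this thread concrete, here is a small sketch in plain Python. `backward_weight_sketch` and its float "gradients" are hypothetical stand-ins for the real tensors and functions; the only point is which caller patterns can observe the change in return type.

```python
from typing import List, Optional, Tuple


def backward_weight_sketch(
    weights: List[float], upstream: float
) -> Tuple[Optional[float], ...]:
    # Hypothetical stand-in: one "gradient" per weight, in the order given.
    weight_grads: List[Optional[float]] = [w * upstream for w in weights]
    # Accumulate in a list, then return a tuple so the annotation holds.
    return tuple(weight_grads)


grads = backward_weight_sketch([1.0, 2.0, 3.0], upstream=0.5)

# Callers that unpack, index, iterate, or call len() behave the same as before.
first, second, third = grads
total = sum(g for g in grads if g is not None)

# Only an explicit list check would notice the difference.
is_list = isinstance(grads, list)
```

Under these assumptions, only code that checks `isinstance(..., list)` or mutates the returned container in place would be affected, which matches the reasoning given in the reply.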
Clean up methods related to stage input/output shape verification which are no longer needed Pull Request resolved: #140418 Approved by: https://github.com/wconstab ghstack dependencies: #140019
fix pytorch#139405 Pull Request resolved: pytorch#140019 Approved by: https://github.com/wconstab
Stack from ghstack (oldest at bottom):
fix #139405
cc @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o