-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[PP] Remove modifications to autograd nodes in ZB #136678
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136678
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit c9bbe2d with merge base 9992084 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
cc XilunWu awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
|
||
# backward of loss with respect to weights | ||
dweights = stage_backward_weight(mod.parameters(), param_groups) | ||
stage_backward_weight(mod.parameters(), param_groups, retain_graph=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why retain graph=true? Is it because we reuse the same graph for subsequent micro batches? Or does each micro batch have its own graph?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh this is a test. Never mind
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
||
def reverse_closure( | ||
roots: List[Node], target_nodes: Set[Node] | ||
roots: List[Node], target_nodes: Set[Node], reverse_edges_dict |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: add doc to the reverse_edges_dict
argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will do!
def stage_backward_weight( | ||
weights: Iterator[Parameter], param_groups: List[Dict[str, Any]] | ||
weights: Iterator[Parameter], param_groups: List[Dict[str, Any]], retain_graph=False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It the retain_graph
flag is just for test purpose, shall we make it _retain_graph
and add a banner saying "Test only; don't use"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added retain_graph
to better align with the .backward()
API in the case that someone wants to perform multiple backwards (double backward) and accumulate the gradients. I was using this in testing and I dont think anyone besides us is using the stage_backward_input
and stage_backward_weight
, but they dont necessarily need to only apply to stages and could be used as a more general API.
The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
ghstack-source-id: 2169168 Pull Request resolved: pytorch/pytorch#136678
Stack from ghstack (oldest at bottom):
cc @XilunWu @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o