stop non-differentiable values from being materialized in aotautograd #110721
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110721
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit a10ef1b with merge base e3bf500.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This seems plausible but you still have test errors

Where do you see test errors 🤔

I'm letting Brian review this

@pytorchbot merge

This PR needs to be approved by an authorized maintainer before merge.
    inp_tangents_filtered = [
        x
        for x, info_idx in zip(inp_tangents, mutated_inp_indices)
        if input_info[info_idx].mutates_data and input_info[info_idx].requires_grad
Hmm, I think the `input_info[info_idx].requires_grad` part of the filter should be unnecessary now?
Previously, inp_tangents corresponded to "every forward input that had a (data or metadata) mutation". We had to filter this down to the actual inputs to our backward graph that correspond to user forward inputs, which are fw inputs that had a data mutation and require grad. (Actually - double checking, I think we're already correctly filtering out metadata mutations).
But now that you're filtering down to fw inputs that have a mutation and require grad even earlier (as part of constructing the joint fw/bw), we should only have to filter out inputs with metadata-only mutations in this check.
For future discussion, this isn't true. In this case, inp_tangents represents all of the inputs to the backwards pass, which will include values that correspond to tensors that we don't actually trace with (for example, outputs of the forwards pass that are non-differentiable).
So we still need to filter them out in this place; the logic is pretty analogous to the existing filtering.
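As a rough illustration of the filter under discussion, here is a minimal, self-contained sketch in plain Python. It is not the AOTAutograd implementation: `InputInfo` is a hypothetical stand-in that carries only the two fields referenced in the diff (`mutates_data`, `requires_grad`), the helper name `filter_inp_tangents` is made up for this example, and the tangents are plain strings rather than tensors.

```python
from dataclasses import dataclass


@dataclass
class InputInfo:
    # Hypothetical stand-in for per-input metadata; only the two fields
    # used by the filter in the diff are modeled here.
    mutates_data: bool
    requires_grad: bool


def filter_inp_tangents(inp_tangents, mutated_inp_indices, input_info):
    # Keep a tangent only if its forward input had a data mutation
    # and requires grad; everything else is dropped.
    return [
        x
        for x, info_idx in zip(inp_tangents, mutated_inp_indices)
        if input_info[info_idx].mutates_data and input_info[info_idx].requires_grad
    ]


input_info = [
    InputInfo(mutates_data=True, requires_grad=True),    # index 0: kept
    InputInfo(mutates_data=False, requires_grad=True),   # index 1: not a mutated input
    InputInfo(mutates_data=True, requires_grad=False),   # index 2: dropped, no grad
    InputInfo(mutates_data=False, requires_grad=True),   # index 3: dropped, metadata-only mutation
]
mutated_inp_indices = [0, 2, 3]       # positions of inputs that were mutated
inp_tangents = ["t0", "t2", "t3"]     # one placeholder tangent per mutated input

filtered = filter_inp_tangents(inp_tangents, mutated_inp_indices, input_info)
print(filtered)
```

In this toy setup, dropping the `requires_grad` half of the condition would let the tangent for index 2 through, which is the kind of value the reply above says still has to be filtered out at this point.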
@pytorchbot merge

Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team


Stack from ghstack (oldest at bottom):