fix FakeTensor creation on noncontiguous subclasses #124399
Conversation
…ontiguous subclasses [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124399
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 0aac007 with merge base e16f1ee.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…ion on noncontiguous subclasses" [ghstack-poisoned]
torch/autograd/graph.py (outdated):

    for it: we need the FakeTensor to have accurate is_leaf information,
    even though we don't actually plan to run autograd through the graph input.
    """
    torch._C._forbid_in_autograd(tensor)
@albanD does this look ok to you as public API + the docs I wrote? Let me know if you think I should make it clearer that it's a sharp-edged API.
That sounds fair as a public API to me. I'll let @soulitzer give his opinion in case he prefers to keep it private.
Either way, it needs testing in test_autograd!
torch/csrc/autograd/variable.cpp (outdated):

    }

    void forbid_in_autograd(const Variable& self) {
      TORCH_CHECK(
You need a bit more error checking here to ensure that:
- This is a leaf (no existing grad_fn)
- This is not already part of a graph (no grad_accumulator)
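The two preconditions above can be sketched in pure Python. This is a hypothetical mock, not PyTorch's actual C++ implementation: `MockTensor`, `forbid_in_autograd_checks`, and the attribute names are invented for illustration.

```python
# Hypothetical sketch of the two validity checks the review asks for.
# MockTensor and the attribute names are invented for illustration;
# this is not PyTorch's actual C++ implementation.
class MockTensor:
    """Stand-in tracking only the autograd fields the checks need."""
    def __init__(self, grad_fn=None, grad_accumulator=None):
        self.grad_fn = grad_fn
        self.grad_accumulator = grad_accumulator

def forbid_in_autograd_checks(tensor):
    # Check 1: must be a leaf, i.e. no existing grad_fn.
    if tensor.grad_fn is not None:
        raise RuntimeError(
            "forbid_in_autograd: tensor is not a leaf (it has a grad_fn)")
    # Check 2: must not already be part of a graph, i.e. no grad
    # accumulator has been allocated for it.
    if tensor.grad_accumulator is not None:
        raise RuntimeError(
            "forbid_in_autograd: tensor already participates in a graph "
            "(a grad accumulator exists)")
    # Only once both checks pass is it safe to install the Error node.
    tensor.grad_fn = "ErrorNode"
```

Without these checks, installing an Error node on a tensor that already has a `grad_fn` would silently sever an existing graph instead of failing loudly.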
torch/csrc/autograd/variable.cpp (outdated):

        self.defined(), "cannot call forbid_in_autograd() on undefined tensor");
    auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
        new torch::autograd::Error(
            "Cannot backprop through Error node, file a bug in PyTorch"),
Cannot backprop through a Tensor that was marked as forbidden in backward.
Or something similar
torch/autograd/graph.py (outdated):

    torch._C._increment_version(tensor)


    def forbid_in_autograd(tensor):
forbid_in_backward()?
It's not forbidden in forward mode AD or from interacting in a non-differentiable way with autograd.
torch/autograd/graph.py (outdated):

    def forbid_in_autograd(tensor):
        """Replaces the current tensor's grad_fn with an Error node.
While I like the factuality and it makes it clear to me what this does, I think we need a bit more sugar-coating for end users. See below.
Also, maybe add a small note that this is an advanced API that we don't expect most users to use; detach() and no_grad() should be used by most users to locally disable autograd, as discussed in https://pytorch.org/docs/stable/notes/autograd.html#locally-disabling-gradient-computation
torch/autograd/graph.py (outdated):

    If the tensor was originally an autograd leaf (tensor.is_leaf == False),
    setting the tensor's grad_fn to an error node will flip tensor.is_leaf to True.
You inverted the True/False here
torch/autograd/graph.py (outdated):

    def forbid_in_autograd(tensor):
        """Replaces the current tensor's grad_fn with an Error node.
        This effectively forbids the tensor from having a gradient computed during backward.
This sounds like the right intro for just above
torch/autograd/graph.py (outdated):

    If the tensor was originally an autograd leaf (tensor.is_leaf == False),
    setting the tensor's grad_fn to an error node will flip tensor.is_leaf to True.
    This is a convenient API used in torch.compile internals, when we need to
.. note:: One example where this API is used is ...
Public API sounds okay, but preferably with a clear story around when we'd use this over existing APIs that do similar "forbid_in_autograd" things like …
…ion on noncontiguous subclasses" Fixes #124090, context on the issue [ghstack-poisoned]
I updated the PR with @soulitzer's idea to not add a new API: it seems like … Adding that ErrorNode also caused some test failures, which made me realize that …
Hmm... the errors I'm getting are because …
Oh, it's probably because the SparseTensorImpl C++ subclasses don't override shallow_copy_and_detach properly... We probably are not able to handle "tensor subclass holding a fake sparse tensor" today for other reasons, so I'm going to leave the sparse fakify logic alone and have it continue using clone for now.
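A rough Python analogy for why a missing shallow_copy_and_detach override is a problem. All names below are invented for the sketch; the real issue lives in C++ virtual dispatch on TensorImpl, and here we simulate the metadata loss directly.

```python
# Rough Python analogy (invented names) for a TensorImpl subclass that
# fails to propagate its own type through shallow_copy_and_detach.
class BaseImpl:
    def __init__(self, sizes, strides):
        self.sizes, self.strides = sizes, strides

    def shallow_copy_and_detach(self):
        # Correct behavior: the detached copy keeps the subclass type
        # and the exact layout metadata (e.g. the strides of a
        # noncontiguous tensor).
        return type(self)(self.sizes, self.strides)

class SparseLikeImpl(BaseImpl):
    def shallow_copy_and_detach(self):
        # Simulates a broken override: the detached copy comes back as
        # the base type, dropping subclass-specific state. A fakify
        # path relying on shallow_copy_and_detach would then misbehave,
        # which is why falling back to clone() is the safer choice.
        return BaseImpl(self.sizes, self.strides)
```

With the broken override, any code that later checks the copy's type (or subclass-only fields) sees a plain base impl, so the fakified tensor no longer round-trips.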
Fixes #125287 Fixes #124090, context on the issue cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k ezyang msaroufim anijain2305 voznesenskym EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng [ghstack-poisoned]
…or is noncontiguous (#124400) Fixes #124397 Pull Request resolved: #124400 Approved by: https://github.com/ezyang, https://github.com/yoyoyocmu ghstack dependencies: #124398, #124399
…spatch__ (#123347)" (#125288) Re-land of #123347. The original PR broke internal because of a circular import due to importing dynamo in the DTensor code. The new version uses `torch._dynamo_disable` to work around This reverts commit 9d88339. Pull Request resolved: #125288 Approved by: https://github.com/ezyang, https://github.com/yanboliang, https://github.com/yoyoyocmu, https://github.com/anijain2305, https://github.com/fegin ghstack dependencies: #124398, #124399, #124400
Fixes pytorch#125287 Fixes pytorch#124090, context on the issue Pull Request resolved: pytorch#124399 Approved by: https://github.com/soulitzer ghstack dependencies: pytorch#124398
Fixes #125287
Fixes #124090, context on the issue
Stack from ghstack (oldest at bottom):
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @ezyang @msaroufim @anijain2305 @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng