fix FakeTensor creation on noncontiguous subclasses by bdhirsh · Pull Request #124399 · pytorch/pytorch · GitHub

Conversation


pytorch-bot bot commented Apr 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124399

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0aac007 with merge base e16f1ee:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…ion on noncontiguous subclasses"

[ghstack-poisoned]
for it: we need the FakeTensor to have accurate is_leaf information,
even though we don't actually plan to run autograd through the graph input.
"""
torch._C._forbid_in_autograd(tensor)
Contributor Author

@albanD does this look ok to you as a public API, plus the docs I wrote? Let me know if you think I should make it clearer that it's a sharp-edged API

Collaborator

albanD left a comment

That sounds fair to me as a public API. I'll let @soulitzer give his opinion in case he prefers to keep it private.
Either way, it needs testing in test_autograd!
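For illustration, a rough sketch of the kind of test_autograd coverage being asked for, written against the API as it was proposed at this point in the PR (the test name is made up, and since the new API was ultimately dropped in favor of DelayedError, this is a sketch rather than the test that landed):

```python
import torch
from torch.testing._internal.common_utils import TestCase, run_tests


class TestForbidInAutograd(TestCase):
    def test_backward_through_forbidden_tensor_raises(self):
        # Build a non-leaf tensor (it has a CloneBackward grad_fn), then
        # forbid autograd from ever backpropagating through it, using the
        # binding proposed (and later dropped) in this PR.
        base = torch.randn(3, requires_grad=True)
        x = base.clone()
        torch._C._forbid_in_autograd(x)

        # Any backward pass that reaches x's (now Error) grad_fn should raise.
        out = x.sum()
        with self.assertRaises(RuntimeError):
            out.backward()


if __name__ == "__main__":
    run_tests()
```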

}

void forbid_in_autograd(const Variable& self) {
  TORCH_CHECK(
Collaborator

You need a bit more error checking here to ensure that:

  • This is a leaf (no existing grad_fn)
  • This is not already part of a graph (no grad_accumulator)

      self.defined(), "cannot call forbid_in_autograd() on undefined tensor");
  auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
      new torch::autograd::Error(
          "Cannot backprop through Error node, file a bug in PyTorch"),
Collaborator

Cannot backprop through a Tensor that was marked as forbidden in backward.
Or something similar

torch._C._increment_version(tensor)


def forbid_in_autograd(tensor):
Collaborator

forbid_in_backward() ?
It's not forbidden in forward mode AD or from interacting in a non-differentiable way with autograd.



def forbid_in_autograd(tensor):
"""Replaces the current tensor's grad_fn with an Error node.
Collaborator

While I like the factuality and it makes it clear to me what this does, I think we need a bit more sugar coating for end users. See below.

Also, maybe add a small note that this is an advanced API that we don't expect most users to need, and that detach() and no_grad() should be used by most users to locally disable autograd, as discussed in https://pytorch.org/docs/stable/notes/autograd.html#locally-disabling-gradient-computation
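For reference, these are the standard, documented ways to locally disable autograd that such a note would point users to (nothing here is specific to this PR):

```python
import torch

x = torch.randn(3, requires_grad=True)

# detach(): a new tensor sharing the same storage but cut off from the graph.
y = x.detach()
assert not y.requires_grad

# no_grad(): operations run inside the context are not recorded by autograd.
with torch.no_grad():
    z = x * 2
assert not z.requires_grad
```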

Comment on lines 211 to 212
If the tensor was originally an autograd leaf (tensor.is_leaf == False),
setting the tensor's grad_fn to an error node will flip tensor.is_leaf to True.
Collaborator

You inverted the True/False here

def forbid_in_autograd(tensor):
"""Replaces the current tensor's grad_fn with an Error node.
This effectively forbids the tensor from having a gradient computed during backward.
Collaborator

This sounds like the right intro for just above

If the tensor was originally an autograd leaf (tensor.is_leaf == False),
setting the tensor's grad_fn to an error node will flip tensor.is_leaf to True.
This is a convenient API used in torch.compile internals, when we need to
Collaborator

.. note:: One example where this API is used is ...

Contributor

soulitzer commented Apr 19, 2024

Public API sounds okay, but preferably with a clear story around when we'd use this over existing APIs that do similar "forbid_in_autograd" things like .detach() or setting .requires_grad=False. (We also have the private torch._C._functions.DelayedError, which is an out-of-place version of this.) If it is difficult to have such examples, I don't mind keeping it private either for now, but no strong preference.
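For context, a rough sketch of what the out-of-place DelayedError pattern looks like next to the in-place proposal. The two-argument constructor (error message, number of inputs) is my reading of the private binding and may vary between versions, so treat this as illustrative:

```python
import torch

x = torch.randn(3, requires_grad=True)

# Out-of-place: DelayedError returns a *new* tensor whose grad_fn raises the
# given message if backward ever reaches it; x itself is left untouched.
# NOTE: the (msg, num_inputs) signature is my reading of the private binding.
y = torch._C._functions.DelayedError("backward through y is forbidden", 1)(x)

# The proposed forbid_in_autograd(x) would instead mutate x's own grad_fn in place.
try:
    y.sum().backward()
except RuntimeError as e:
    print(e)
```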

…ion on noncontiguous subclasses"

Fixes #124090, context on the issue




[ghstack-poisoned]
@bdhirsh changed the title from "add forbid_in_autograd api, use it to fix FakeTensor creation on noncontiguous subclasses" to "fix FakeTensor creation on noncontiguous subclasses" on May 1, 2024
Contributor Author

bdhirsh commented May 1, 2024

I updated the PR with @soulitzer's idea to not add a new API: it seems like torch._C._functions.DelayedError should do the job.

Adding that ErrorNode also caused some test failures, which made me realize that autograd.backward is broken in dynamo: #125287. I just fixed it directly in this PR.
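Roughly, the shape of the DelayedError-based approach on a fakified graph input (a sketch of the idea only, not the exact code in this PR; the helper name and message are made up):

```python
import torch


def _match_leafness(fake, original):
    # Hypothetical helper: if the original tensor was not an autograd leaf,
    # give the fake tensor a grad_fn too so fake.is_leaf matches, but make it
    # an error node since we never intend to actually backprop through it.
    if original.is_leaf:
        return fake
    if not fake.requires_grad:
        fake.requires_grad_(True)  # non-leaf tensors always require grad
    return torch._C._functions.DelayedError(
        "Tried to backward through a graph input that was fakified as a non-leaf",
        1,
    )(fake)
```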

@ezyang ezyang removed their request for review May 1, 2024 02:54
Contributor Author

bdhirsh commented May 1, 2024

Hmm... the errors I'm getting are because `y = torch._C._functions.DelayedError(1)(x)` doesn't "propagate" the fact that x is a fake tensor (it returns a plain tensor)

Contributor Author

bdhirsh commented May 1, 2024

oh it's probably because the SparseTensorImpl C++ subclasses don't override shallow_copy_and_detach properly...

We probably are not able to handle "tensor subclass holding a fake sparse tensor" today for other reasons, so I'm going to leave the sparse fakify logic alone and have it continue using clone for now.

bdhirsh added 2 commits May 1, 2024 08:56
Fixes #125287

Fixes #124090, context on the issue




cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k ezyang msaroufim anijain2305 voznesenskym EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng

[ghstack-poisoned]
@bdhirsh bdhirsh added the release notes: composability release notes category label May 1, 2024
pytorchmergebot pushed a commit that referenced this pull request May 1, 2024
pytorchmergebot pushed a commit that referenced this pull request May 1, 2024
…spatch__ (#123347)" (#125288)

Re-land of #123347.

The original PR broke internal because of a circular import due to importing dynamo in the DTensor code. The new version uses `torch._dynamo_disable` to work around

This reverts commit 9d88339.

Pull Request resolved: #125288
Approved by: https://github.com/ezyang, https://github.com/yanboliang, https://github.com/yoyoyocmu, https://github.com/anijain2305, https://github.com/fegin
ghstack dependencies: #124398, #124399, #124400
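For readers unfamiliar with the workaround being described: the public way to keep dynamo from tracing a function is the decorator sketched below. Whether this is exactly the helper the commit message refers to as `torch._dynamo_disable` I can't confirm, so take it as an illustration of the idea:

```python
import torch
import torch._dynamo


@torch._dynamo.disable
def _helper_not_to_be_traced(x):
    # dynamo calls this eagerly instead of tracing into it, which avoids
    # tracing (and importing) code that would otherwise cause problems.
    return x + 1


@torch.compile
def f(x):
    return _helper_not_to_be_traced(x) * 2
```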
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
pytorch-bot bot pushed a commit that referenced this pull request May 3, 2024
pytorch-bot bot pushed a commit that referenced this pull request May 3, 2024
@github-actions github-actions bot deleted the gh/bdhirsh/555/head branch June 4, 2024 02:01