[autograd] match 0-dim gradients device type regardless of subclassness by xmfan · Pull Request #160165 · pytorch/pytorch

Conversation

@xmfan
Member

@xmfan xmfan commented Aug 8, 2025

@pytorch-bot

pytorch-bot bot commented Aug 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160165

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job, 1 Unrelated Failure

As of commit 83763d7 with merge base ba37f58:

CANCELLED JOB - The following job was cancelled. Please retry:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

xmfan added a commit that referenced this pull request Aug 8, 2025
@xmfan xmfan changed the title from "[autograd] match 0-dim gradients regardless of subclassness" to "[autograd] match 0-dim gradients device type regardless of subclassness" Aug 8, 2025


Not sure if there are some subclasses where the outer.dim() == 0 but you wouldn't want to move it?

FIXES #160084
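For context, a minimal sketch of the engine behavior in question (assuming a CUDA device is available; the custom Function below is illustrative and not from this PR): autograd moves a 0-dim gradient to the device recorded in the input's metadata, and this change extends that matching to tensor-subclass gradients as well.

```python
import torch

class ScalarGradOnWrongDevice(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad):
        # Deliberately hand back a 0-dim gradient on another device.
        return grad.sum().to("cuda")

x = torch.ones((), requires_grad=True)  # 0-dim parameter on CPU
ScalarGradOnWrongDevice.apply(x).backward()
print(x.grad.device)  # cpu -- the engine moved the 0-dim grad to match x
```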

[ghstack-poisoned]
xmfan added a commit that referenced this pull request Aug 8, 2025
@xmfan xmfan added the module: autograd Related to torch.autograd, and the autograd engine in general label Aug 8, 2025
@xmfan xmfan marked this pull request as ready for review August 8, 2025 16:11
@xmfan xmfan requested review from albanD and soulitzer as code owners August 8, 2025 16:11
Collaborator

@albanD albanD left a comment


Sounds good. A bit more testing is needed.

```cpp
if (grad.dim() == 0) {
  grad = grad.to(metadata.device());
} else {
  // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but
```
Collaborator


@ezyang now that we do have proper wrapper subclasses, should we remove this, since we no longer have to use the meta device for all subclasses?
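A sketch of what proper wrapper-subclass support enables (a minimal illustration using torch.Tensor._make_wrapper_subclass; the class name is hypothetical): the outer tensor can mirror the inner tensor's real device, so engine-side device checks no longer need a blanket exemption for subclasses.

```python
import torch

class ProperWrapper(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # The outer tensor mirrors the inner tensor's metadata, including
        # its device, so autograd's device checks see accurate information.
        # A real subclass would also define __torch_dispatch__ to route ops
        # to the wrapped tensor.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self.elem = elem  # the wrapped storage
```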

[ghstack-poisoned]
xmfan added a commit that referenced this pull request Aug 8, 2025
Collaborator

@albanD albanD left a comment


Small nit, SGTM otherwise

Comment on lines 12411 to 12412

```python
model = RegressionModel()
```
Collaborator


Suggested change

```diff
-model = RegressionModel()
+# Keep the model on cpu as we do want to test the mixed cpu/accelerator behavior here
+model = RegressionModel()
```

[ghstack-poisoned]
xmfan added a commit that referenced this pull request Aug 8, 2025
@xmfan
Member Author

xmfan commented Aug 9, 2025

@pytorchbot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Aug 9, 2025
@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team (raised by workflow job)

[ghstack-poisoned]
xmfan added a commit that referenced this pull request Aug 9, 2025

```python
@staticmethod
def __new__(cls, elem, *args, **kwargs):
    # Wrong device here!
```
Member Author


@albanD This test intentionally diverges the outer tensor's device from the inner tensor's device. Is this something we still want to support? Assuming it's not, the test wouldn't make sense anymore if we kept the two in sync, so I just removed it.
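For readers without the diff open, a hedged reconstruction of the removed test's setup (the class name and the exact device choice are guesses; only the __new__ skeleton above appears in the excerpt): the outer tensor deliberately reports a device different from the wrapped tensor's.

```python
import torch

class WrongDeviceTensor(torch.Tensor):  # hypothetical name
    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # Wrong device here! The outer metadata claims "meta" while the
        # wrapped data stays on elem's real device.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=torch.device("meta")
        )

    def __init__(self, elem, *args, **kwargs):
        self.elem = elem
```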

@xmfan
Member Author

xmfan commented Aug 11, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team (raised by workflow job)

@xmfan xmfan added the release notes: autograd (release notes category) label Aug 11, 2025
@xmfan
Member Author

xmfan commented Aug 11, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.2)

Details for Dev Infra team (raised by workflow job)

@xmfan
Member Author

xmfan commented Aug 11, 2025

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 2 checks: pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge, unstable), trunk / linux-jammy-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.2)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@xmfan
Member Author

xmfan commented Aug 12, 2025

@pytorchbot cherry-pick --onto release/2.8 -c critical

@pytorchbot
Collaborator

Cherry picking #160165

Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x c8205cb35435f39d2c26f6c94b45e4adeb6dcb23 returned non-zero exit code 1

```
Auto-merging test/dynamo/test_repros.py
CONFLICT (content): Merge conflict in test/dynamo/test_repros.py
Auto-merging test/test_autograd.py
Auto-merging test/test_python_dispatch.py
error: could not apply c8205cb3543... [autograd] match 0-dim gradients device type regardless of subclassness (#160165)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
```

Details for Dev Infra team (raised by workflow job)

samanklesaria pushed a commit to samanklesaria/pytorch that referenced this pull request Aug 12, 2025
…ss (pytorch#160165)

Not sure if there are some subclasses where the outer.dim() == 0 but you wouldn't want to move it?

FIXES pytorch#160084

Pull Request resolved: pytorch#160165
Approved by: https://github.com/ezyang, https://github.com/albanD
@github-actions github-actions bot deleted the gh/xmfan/276/head branch September 11, 2025 02:10
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
…ss (pytorch#160165)

Not sure if there are some subclasses where the outer.dim() == 0 but you wouldn't want to move it?

FIXES pytorch#160084

Pull Request resolved: pytorch#160165
Approved by: https://github.com/ezyang, https://github.com/albanD

Labels

  • ciflow/inductor
  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • module: autograd (Related to torch.autograd, and the autograd engine in general)
  • module: dynamo
  • release notes: autograd (release notes category)
