[DTensor] Support matmul in inference_mode by kwen2501 · Pull Request #142197 · pytorch/pytorch · GitHub

Conversation

Contributor

@kwen2501 kwen2501 commented Dec 6, 2024

Stack from ghstack (oldest at bottom):

Fixes #142190 .

The solution is to add a `decompose_handler` for `aten.matmul`, similar to how we handle `aten.linear`.
With the decomposition, `aten.matmul` becomes `aten.mm`, which has a sharding strategy registered with DTensor.
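For context, a minimal sketch of the pattern this unblocks: calling `torch.matmul` on DTensors under `torch.inference_mode()`. This snippet is not taken from the PR or its test; the shapes, the `cuda` device, and the torchrun-style launch are assumptions, and the import paths follow recent PyTorch releases (older versions expose the same APIs under `torch.distributed._tensor`).

```python
# Minimal sketch, assuming a torchrun launch that sets WORLD_SIZE and one GPU
# per rank; shapes and placements are illustrative only.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

world_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh("cuda", (world_size,))

x = torch.randn(8, 16, device="cuda")
A = torch.randn(16, 32, device="cuda")
dx = distribute_tensor(x, device_mesh, [Replicate()])
dA = distribute_tensor(A, device_mesh, [Shard(0)])

with torch.inference_mode():
    # Previously this failed because DTensor saw aten.matmul directly, which
    # has no sharding strategy; with the decompose_handler the call is
    # rewritten into aten.mm, which DTensor knows how to shard.
    out = torch.matmul(dx, dA)
```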

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o


pytorch-bot bot commented Dec 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/142197

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ac1c07d with merge base 61dc5e9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Dec 6, 2024
kwen2501 added a commit that referenced this pull request Dec 6, 2024
ghstack-source-id: 3dd1f5f
Pull Request resolved: #142197
@kwen2501 kwen2501 requested review from bdhirsh and wz337 December 6, 2024 01:28
@kwen2501 kwen2501 added the release notes: distributed (dtensor) release notes category label Dec 6, 2024
Contributor

@XilunWu XilunWu left a comment


LGTM except the var naming.

Comment on lines +132 to +133
dx = distribute_tensor(x, device_mesh, [Replicate()])
dA = distribute_tensor(A, device_mesh, [Shard(0)])

I remember that people complained the variable names dX, dA suggest gradients. Names such as x_dist, A_dist are preferred. cc @awgu
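For illustration only (this snippet is not part of the PR), the suggested rename would read:

```python
# Hypothetical rename following the review suggestion: a _dist suffix avoids
# the d- prefix, which can be misread as "gradient of x" / "gradient of A".
x_dist = distribute_tensor(x, device_mesh, [Replicate()])
A_dist = distribute_tensor(A, device_mesh, [Shard(0)])
```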

Contributor Author

kwen2501 commented Dec 6, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 6, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

Contributor

@wz337 wz337 left a comment


Thanks. I learned how to let DTensor decompose through this diff.

pytorch-bot bot pushed a commit that referenced this pull request Dec 9, 2024
Fixes #142190 .

The solution is to add a `decompose_handler` for `aten.matmul`, similar to how we handle `aten.linear`.
With the decomposition, `aten.matmul` becomes `aten.mm`, which has a sharding strategy registered with DTensor.

Pull Request resolved: #142197
Approved by: https://github.com/XilunWu, https://github.com/wz337
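As a rough sketch of what the decomposition buys (a hypothetical helper, not the PR's actual handler code): for 2-D inputs, matmul reduces to mm, and DTensor already has a sharding strategy registered for mm.

```python
import torch

# Hypothetical illustration of the 2-D case only: aten.matmul on two matrices
# is exactly aten.mm, so once matmul is decomposed, DTensor dispatches to an
# op (mm) for which it has a registered sharding strategy.
def matmul_as_mm(x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    assert x.dim() == 2 and a.dim() == 2
    return torch.mm(x, a)
```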
@github-actions github-actions bot deleted the gh/kwen2501/110/head branch January 6, 2025 02:08

Labels

ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
release notes: distributed (dtensor) (release notes category)

Development

Successfully merging this pull request may close these issues: #142190

4 participants