KEMBAR78
[dtensor] relax device_mesh argument constraint in local_map by wanchaol · Pull Request #157049 · pytorch/pytorch · GitHub
Skip to content

Conversation

@wanchaol
Copy link
Collaborator

@wanchaol wanchaol commented Jun 27, 2025

This PR relaxes the device_mesh argument constraint in the local_map API. The current restriction is too strict, i.e. all the input arguments must have the same device mesh if they are DTensors. But many times user might want to pass in DTensors to this function that lives on different device mesh, i.e. weight and activation could live in different device mesh.

When using the local_map, we are extracting the local tensors from DTensors, and as long as the placements user specified match with the actual DTensor placements, user knows clearly that the inputs are intended to live in different mesh. So this PR removes the same mesh check and update doc to clearly document the behavior.

The device_mesh argument now serves for a main purpose, allow user to specify the device_mesh for the output DTensor reconstruction

Fixes #ISSUE_NUMBER

cc @H-Huang @awgu @fegin @fduwjj @wz337 @wconstab @d4l3k

This PR relaxes the device_mesh argument constraint in the local_map
API. The current restriction is too strict, i.e. all the input arguments
must have the same device mesh if they are DTensors. But many times user
might want to pass in DTensors to this function that lives on different
device mesh, i.e. weight and activation could live in different device
mesh.

When using the local_map, we are extracting the local tensors from
DTensors, and as long as the placements user specified match with the
actual DTensor placements, user knows clearly that the inputs are
intended to live in different mesh. So this PR removes the same mesh
check and update doc to clearly document the behavior.

The `device_mesh` argument now serves for a main purpose, allow user to
specify the device_mesh for the output DTensor reconstruction
@pytorch-bot
Copy link

pytorch-bot bot commented Jun 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157049

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Cancelled Job, 1 Unrelated Failure

As of commit aa77732 with merge base 81759af (image):

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Jun 27, 2025
@wanchaol wanchaol requested review from Chillee, XilunWu, awgu and zpcore June 27, 2025 21:35
@wanchaol wanchaol added release notes: distributed (dtensor) release notes category ciflow/trunk Trigger trunk jobs on your pull request labels Jun 27, 2025
Copy link
Collaborator

@Chillee Chillee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me!

@zpcore
Copy link
Member

zpcore commented Jun 27, 2025

LGTM!

Overall this seems to be a hack to get around the cross mesh assertion. We basically leave the responsibility to the user to make sure the shape is correct so that we don't need to write the complicated "cross mesh sharding propagation rule". My concern is that the code can easily fail once we run the model on top of a different world_size or input.

@wanchaol
Copy link
Collaborator Author

Overall this seems to be a hack to get around the cross mesh assertion. We basically leave the responsibility to the user to make sure the shape is correct so that we don't need to write the complicated "cross mesh sharding propagation rule". My concern is that the code can easily fail once we run the model on top of a different world_size or input.

@zpcore The role of the local_map API is to give back control to user to operate directly on the input DTensor's local shard, and it returns a DTensor according to user provided out_placements, so that the user function that wraps with local_map could do whatever user want. For the wrapped user function, there should be no "cross mesh sharding propagation rule" be written, as the in_placements and out_placements has already been specified by user when using the local_map API/contract (and it act like a "propagation rule" in some sense).

I think the previous assertion asserts the different input DTensors must have the same mesh is for extra safety in the beginning, but it turns out not necessary. The restriction is that it would limit the expressiveness that the in_placements user can specify (suppose user pass two DTensors, one is on 1-D mesh, the other is on 2-D mesh, even if user can specify like this, the previous assertion would error out to user). Let me know if that does not make sense to you.

@wanchaol
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@wanchaol
Copy link
Collaborator Author

@pytorchbot merge -f "ci failure not related"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue open source release notes: distributed (dtensor) release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants