[DTensor][XLA] Support Xla backend in distribute_tensor API by yeounoh · Pull Request #110275 · pytorch/pytorch · GitHub

Conversation

@yeounoh
Contributor

@yeounoh yeounoh commented Sep 29, 2023

This addresses #92909 and enables XLA backend support for the `distribute_tensor` API.

Test plan: added a unit test case and tested with Cloud TPU. The CI should skip this unless it's an XLA workflow.

cc @bdhirsh @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @wanchaol
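
For context, a minimal usage sketch of what this change is meant to enable is shown below. It is not an excerpt from the diff: the mesh size, import paths, and dispatch behavior are assumptions based on the description above, and it presumes an XLA-enabled build with torch_xla installed and SPMD mode turned on.

import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# Hypothetical 1-D mesh over 4 XLA devices; the device count is illustrative.
world_size = 4
mesh = DeviceMesh("xla", list(range(world_size)))

big_tensor = torch.randn(8, 1024)
# Shard along dim 0 across the mesh; with an "xla" mesh this call is expected
# to route to the torch_xla SPMD sharding path instead of the eager DTensor path.
sharded = distribute_tensor(big_tensor, mesh, placements=[Shard(0)])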

@pytorch-bot

pytorch-bot bot commented Sep 29, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110275

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d435177 with merge base 3ca81ae (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Collaborator

@wanchaol wanchaol left a comment


Looks pretty good and almost ready to merge! I left a few more comments about testing and doc suggestions; we can land this once they're addressed :)

@yeounoh yeounoh force-pushed the distributed_tensor_xla_api branch 2 times, most recently from 0348a41 to dff47ce Compare October 20, 2023 04:21
@yeounoh yeounoh force-pushed the distributed_tensor_xla_api branch from ada34d3 to d435177 Compare October 20, 2023 06:34
Collaborator

@wanchaol wanchaol left a comment


Looks good. We can do a follow-up PR to move most of the logic to pytorch/xla directly.



@with_xla
def xla_distribute_tensor(

Collaborator

Something suggested by @bdhirsh that could further simplify the PyTorch integration and be more flexible on your side: we could move the whole logic into pytorch/xla and expose an xla_distribute_tensor in the pytorch/xla package; then, on the PyTorch side, we could just call that API directly without needing this _xla.py.
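
A hedged sketch of what that could look like on the PyTorch side; the torch_xla import path, the placeholder _eager_distribute_tensor name, and the dispatch check below are illustrative assumptions, not the actual API.

def distribute_tensor(tensor, device_mesh, placements):
    if device_mesh.device_type == "xla":
        # Defer entirely to an implementation shipped with pytorch/xla, so the
        # PyTorch side no longer needs to carry its own _xla.py shim.
        from torch_xla.experimental import xla_distribute_tensor  # hypothetical path
        return xla_distribute_tensor(tensor, device_mesh, placements)
    # Otherwise fall through to the existing eager DTensor implementation.
    return _eager_distribute_tensor(tensor, device_mesh, placements)  # placeholder name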

Contributor Author

Sounds good, thanks Brian @bdhirsh

@yeounoh
Contributor Author

yeounoh commented Oct 20, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 20, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

Collaborator

@alanwaketan alanwaketan left a comment


LGTM except one comment.

"""
assert dt_mesh.size() == xr.global_runtime_device_count()
return Mesh(

Collaborator

Unclear to me how HybridMesh would work here. But we can follow up later.

Contributor Author

Good question - HybridMesh is a wrapper class (class HybridMesh(Mesh)) for multi-pod topologies (ICI & DCN shapes), and it still uses xs.Mesh underneath. It's probably not relevant to DTensor yet; we should define a new mesh type, like HybridDeviceMesh, and integrate there when the time comes for a follow-up. Thanks @alanwaketan
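
For context, a rough sketch of the kind of DeviceMesh-to-xs.Mesh conversion the excerpt above belongs to; the helper name and axis-name handling are illustrative, with xs and xr standing for torch_xla.experimental.xla_sharding and torch_xla.runtime.

import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

def convert_to_xla_mesh(dt_mesh):  # illustrative helper name
    """Map a DTensor DeviceMesh onto a single global xs.Mesh."""
    # SPMD expects the logical mesh to cover every addressable XLA device.
    assert dt_mesh.size() == xr.global_runtime_device_count()
    device_ids = dt_mesh.mesh.flatten().tolist()
    mesh_shape = tuple(dt_mesh.mesh.shape)
    return xs.Mesh(device_ids, mesh_shape, dt_mesh.mesh_dim_names)

A HybridMesh (or a future HybridDeviceMesh) would presumably slot in at this conversion point.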

) -> None:
    if TORCH_XLA_INITIALIZED:
        # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
        os.environ["XLA_USE_SPMD"] = "1"

Collaborator

Any reason we don't use xr.use_spmd() directly today?

Contributor Author

Yeah, we should do some more testing and possibly update the error message. For instance, we should block if there is pre-existing non-XLA or non-sharded data and tell the user about it. We will do some more testing downstream before deprecating the flag.
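
A small sketch of the transition discussed here, assuming xr is torch_xla.runtime; the hasattr guard is an assumption for older torch_xla builds, not code from this PR.

import os
import torch_xla.runtime as xr

def enable_xla_spmd():
    if hasattr(xr, "use_spmd"):
        # Newer torch_xla exposes an explicit API to switch into SPMD mode.
        xr.use_spmd()
    else:
        # Fall back to the legacy flag until the deprecation is complete.
        os.environ["XLA_USE_SPMD"] = "1"

Until the flag is deprecated, the env-var path is what this PR sets.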

), "XLAShardedTensor `tensor` is already annotated with non-replication sharding. "
"Clear the existing sharding annotation first, by callling torch_xla.experimental.xla_sharding.clear_sharding API."
global_tensor = tensor.global_tensor # type:ignore[attr-defined]
assert global_tensor is not None, "distributing a tensor should not be None"

Collaborator

Do we need an

else:
    raise ValueError

here? Can we handle a non-XLAShardedTensor?

Contributor Author

We handle torch.Tensor or XLAShardedTensor (via its global representation). We decided to block DTensor inputs to be consistent with DTensor's eager API.
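
A hedged sketch of the input handling described above; the helper name and error message are illustrative, and the XLAShardedTensor import path is an assumption based on torch_xla at the time.

import torch
from torch.distributed._tensor import DTensor
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor

def _unwrap_input(tensor: torch.Tensor) -> torch.Tensor:  # illustrative helper
    # Reject DTensor to stay consistent with the eager distribute_tensor API.
    if isinstance(tensor, DTensor):
        raise ValueError(
            "Cannot distribute a DTensor with the XLA backend; pass a plain "
            "torch.Tensor or an XLAShardedTensor instead."
        )
    # Unwrap an XLAShardedTensor to its global (unsharded) representation.
    if isinstance(tensor, XLAShardedTensor):
        return tensor.global_tensor
    return tensor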

Collaborator

@JackCaoG JackCaoG left a comment

mostly lgtm

xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
…110275)

This addresses pytorch#92909 and enables XLA backend support for the `distribute_tensor` API.

Test plan: added a unit test case and tested with Cloud TPU. The CI should skip this unless it's an XLA workflow.

Pull Request resolved: pytorch#110275
Approved by: https://github.com/wanchaol, https://github.com/alanwaketan, https://github.com/JackCaoG
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023

Labels

ciflow/trunk - Trigger trunk jobs on your pull request
Merged
module: dtensor - distributed tensor tag
module: xla - Related to XLA support
open source
release notes: distributed (dtensor) - release notes category
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module


8 participants