[DTensor][XLA] Support Xla backend in distribute_tensor API #110275
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110275
Note: links to docs will display an error until the doc builds have completed.
✅ No failures as of commit d435177 with merge base 3ca81ae. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Looks pretty good and almost ready to merge! I left a few more comments with testing and doc suggestions; we can land this once they're addressed :)
Force-pushed from 0348a41 to dff47ce
Force-pushed from ada34d3 to d435177
Looks good; we can do a follow-up PR to move most of the logic to pytorch/xla directly.
@with_xla
def xla_distribute_tensor(
Something suggested by @bdhirsh that could further simplify the PyTorch integration and be more flexible on your side: we can essentially move the whole logic to pytorch/xla and expose an `xla_distribute_tensor` in the pytorch/xla package; then on the PyTorch side we can just use that API directly without having this _xla.py.
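For illustration, a rough sketch of what that follow-up could look like on the PyTorch side; the `torch_xla` import path and the exact signature are assumptions, not the actual integration:

```python
# Hypothetical sketch (not the actual integration): dispatch distribute_tensor
# to an entry point exposed by the pytorch/xla package. The torch_xla import
# path below is an assumption for illustration.
from typing import Sequence

import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import Placement


def distribute_tensor(
    tensor: torch.Tensor,
    device_mesh: DeviceMesh,
    placements: Sequence[Placement],
):
    if device_mesh.device_type == "xla":
        # Call straight into the API owned by pytorch/xla, so no _xla.py shim
        # is needed on the PyTorch side.
        from torch_xla.distributed.spmd import xla_distribute_tensor  # assumed path
        return xla_distribute_tensor(tensor, device_mesh, placements)
    # ... fall through to the existing eager DTensor implementation ...
```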
Sounds good, thanks Brian @bdhirsh
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
LGTM except one comment.
assert dt_mesh.size() == xr.global_runtime_device_count()
return Mesh(
Unclear to me how HybridMesh would work here. But we can follow up later.
Good question - HybridMesh is a wrapper class (`class HybridMesh(Mesh)`) for multi-pod (ICI & DCN shapes), which still uses xs.Mesh. It's probably not relevant in DTensor yet; we should define a new mesh type, like HybridDeviceMesh, and integrate there when the time comes for a follow-up. Thanks @alanwaketan
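For context, a rough sketch of the DeviceMesh-to-xs.Mesh conversion this thread refers to; the helper name `convert_to_xla_mesh` and the exact constructor arguments are illustrative rather than the PR's code verbatim:

```python
# Illustrative sketch, assuming a single-pod setup (no HybridMesh / DCN axes).
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch.distributed._tensor import DeviceMesh


def convert_to_xla_mesh(dt_mesh: DeviceMesh) -> xs.Mesh:
    # The logical DTensor mesh must cover every addressable XLA device.
    assert dt_mesh.size() == xr.global_runtime_device_count()
    return xs.Mesh(
        dt_mesh.mesh.flatten().tolist(),  # flat list of global device ids
        tuple(dt_mesh.mesh.size()),       # same logical shape as the DeviceMesh
    )
```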
) -> None:
    if TORCH_XLA_INITIALIZED:
        # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
        os.environ["XLA_USE_SPMD"] = "1"
Any reason we don't use xr.use_spmd() directly today?
Yeah, we should do some more testing and possibly update the error message. For instance, we should block if there is pre-existing non-XLA or non-sharded data and tell the user about it. We'll add more testing downstream before deprecating the flag.
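For reference, a minimal sketch of the guard under discussion, assuming `TORCH_XLA_INITIALIZED` is set by a guarded `torch_xla` import (the try/except shape is an assumption):

```python
import os

try:
    import torch_xla.runtime as xr  # noqa: F401
    TORCH_XLA_INITIALIZED = True
except ImportError:
    TORCH_XLA_INITIALIZED = False

if TORCH_XLA_INITIALIZED:
    # Enable XLA SPMD mode via the environment flag for now; once the flag is
    # deprecated this is intended to become a direct xr.use_spmd() call.
    os.environ["XLA_USE_SPMD"] = "1"
```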
| ), "XLAShardedTensor `tensor` is already annotated with non-replication sharding. " | ||
| "Clear the existing sharding annotation first, by callling torch_xla.experimental.xla_sharding.clear_sharding API." | ||
| global_tensor = tensor.global_tensor # type:ignore[attr-defined] | ||
| assert global_tensor is not None, "distributing a tensor should not be None" |
Do we need an `else: raise ValueError` here? Can we handle a non-XLAShardedTensor?
We handle torch.Tensor or XLAShardedTensor (via its global representation). We decided to block DTensor inputs to be consistent with DTensor's eager API.
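A hedged sketch of the input handling described above; the helper name and the XLAShardedTensor import path are assumptions for illustration:

```python
import torch
# Assumed import path; XLAShardedTensor exposes a .global_tensor attribute.
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor


def _unwrap_input(tensor: torch.Tensor) -> torch.Tensor:
    if isinstance(tensor, XLAShardedTensor):
        # Only accept sharded tensors that are not already annotated with a
        # non-replication sharding (otherwise the caller must clear it first).
        global_tensor = tensor.global_tensor
        assert global_tensor is not None, "distributing a tensor should not be None"
        return global_tensor
    if type(tensor) is torch.Tensor:
        return tensor
    # DTensor (and anything else) is rejected, mirroring DTensor's eager API.
    raise ValueError(f"Unsupported input type: {type(tensor)}")
```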
mostly lgtm
…110275) This addresses pytorch#92909 and enables XLA backend support for the `distribute_tensor` API. Test plan: added a unit test case & tested with Cloud TPU. The CI should skip this unless it's an XLA workflow. Pull Request resolved: pytorch#110275. Approved by: https://github.com/wanchaol, https://github.com/alanwaketan, https://github.com/JackCaoG
This addresses #92909 and enables XLA backend support for the `distribute_tensor` API.

Test plan: added a unit test case & tested with Cloud TPU. The CI should skip this unless it's an XLA workflow.
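For reference, a usage sketch of the path this PR enables; the mesh size and tensor shape are illustrative, and it assumes an XLA runtime with SPMD mode enabled:

```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# Illustrative 1-D mesh over four XLA devices.
mesh = DeviceMesh("xla", list(range(4)))
big_tensor = torch.randn(100_000, 88)

# With an "xla" device mesh, distribute_tensor routes to the XLA SPMD backend
# instead of the eager DTensor path, sharding dim 0 across the mesh.
sharded = distribute_tensor(big_tensor, mesh, [Shard(0)])
```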
cc @bdhirsh @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @wanchaol