-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Support complex numbers in DTensor redistribute #157329
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Add complex number unwrapping in functional collectives used by DTensor. Complex tensors are not directly supported by underlying comm kernels (e.g. nccl) but complex tensors can be viewed as real tensors of a higher rank (added size-2 tensor dim represents real vs im component). Collective output is then viewed as complex to restore the original/expected shape and dtype. [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157329
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures, 1 Unrelated FailureAs of commit d1d069e with merge base 070aa59 ( NEW FAILURES - The following jobs have failed:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Add complex number unwrapping in functional collectives used by DTensor. Complex tensors are not directly supported by underlying comm kernels (e.g. nccl) but complex tensors can be viewed as real tensors of a higher rank (added size-2 tensor dim represents real vs im component). Collective output is then viewed as complex to restore the original/expected shape and dtype. ghstack-source-id: da2fdae Pull Request resolved: #157329
Add complex number unwrapping in functional collectives used by DTensor. Complex tensors are not directly supported by underlying comm kernels (e.g. nccl) but complex tensors can be viewed as real tensors of a higher rank (added size-2 tensor dim represents real vs im component). Collective output is then viewed as complex to restore the original/expected shape and dtype. cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k Differential Revision: [D77564148](https://our.internmc.facebook.com/intern/diff/D77564148) [ghstack-poisoned]
Add complex number unwrapping in functional collectives used by DTensor. Complex tensors are not directly supported by underlying comm kernels (e.g. nccl) but complex tensors can be viewed as real tensors of a higher rank (added size-2 tensor dim represents real vs im component). Collective output is then viewed as complex to restore the original/expected shape and dtype. cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k Differential Revision: [D77564148](https://our.internmc.facebook.com/intern/diff/D77564148) [ghstack-poisoned]
# TODO(whc) it appears complex-allreduce is already being supported becuase this test passes, | ||
# but I did not see where the support is |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this support is only for NCCL and some reduceOp: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L4393-L4400
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, thanks for pointing that out. That was the last piece i was missing.
The new piece of info was that CI failed this test for GLOO bc of gloo not supporting allreduce. I therefore added the complex support at Functional.cpp for allreduce, fixing the gloo test, and now I understand why nccl was already passing without it. let me remove this TODO.
Add complex number unwrapping in functional collectives used by DTensor. Complex tensors are not directly supported by underlying comm kernels (e.g. nccl) but complex tensors can be viewed as real tensors of a higher rank (added size-2 tensor dim represents real vs im component). Collective output is then viewed as complex to restore the original/expected shape and dtype. cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test change LGTM but the functional collective part needs small change. Thanks Will for adding complex number support!
auto input_real = input.is_complex() ? at::view_as_real(input) : input; | ||
auto output = input_real.clone(at::MemoryFormat::Contiguous); | ||
auto output_ret = | ||
all_reduce_(output, std::move(reduce_op), std::move(group_name)); | ||
return input.is_complex() ? at::view_as_complex(output_ret) : output_ret; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need the same logic (check reduce_op) (https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L4393-L4400) because this approach only holds numeric correctness for these 4 ops.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i don't know what the best choice is here to reuse the existing helper. I ended up writing a new string-matching check. Lmk if you think its worthwhile to do more refactoring to use the same helper from ProcessGroupNccl, but note that to use that i'd have to first convert into ReduceOp enum, and the conversion helper in this file is incomplete for some reason. (premul_sum)
Add complex number unwrapping in functional collectives used by DTensor. Complex tensors are not directly supported by underlying comm kernels (e.g. nccl) but complex tensors can be viewed as real tensors of a higher rank (added size-2 tensor dim represents real vs im component). Collective output is then viewed as complex to restore the original/expected shape and dtype. cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k [ghstack-poisoned]
Add complex number unwrapping in functional collectives used by DTensor. Complex tensors are not directly supported by underlying comm kernels (e.g. nccl) but complex tensors can be viewed as real tensors of a higher rank (added size-2 tensor dim represents real vs im component). Collective output is then viewed as complex to restore the original/expected shape and dtype. ghstack-source-id: 052ebcd Pull Request resolved: #157329
self.assertEqual(new_tensor.stride(), new_meta_tensor.stride()) | ||
|
||
|
||
instantiate_parametrized_tests(RedistributeTest) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note to myself: use instantiate_parametrized_tests
in DTensor tests
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stamp to unblock. We can address the comment in follow-up PR
TORCH_CHECK( | ||
// TODO - ideally use 'to_reduce_op' helper but it currently errors on | ||
// premul_sum | ||
reduce_op == "sum" || reduce_op == "avg" || reduce_op == "premul_sum" || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO it would be good to reuse complexViewAsRealAllowed
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i agree but it is not trivial.
- convert string to enum: there is a helper in this file, but it is not complete. if i completed it by filling in missing premul_sum, i would be affecting the behavior of other ops using this helper
- then i could refactor the complexViewAsRealAllowed out of ProcessGroupNccl.cpp and make it a util and use it.
Happy to do this in another PR, lmk if you have thoughts about (1)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Low prio though.
lmk if you have thoughts about (1)
I think you're right. And I don't have a good solution either that can ensure consistency among helpers and potential extension of ReduceOp enum.
@pytorchbot merge -i |
Merge startedYour change will be merged while ignoring the following 3 checks: pull / cuda12.8-py3.10-gcc9-sm75 / test (pr_time_benchmarks, 1, 1, linux.g4dn.metal.nvidia.gpu, unstable), trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m1-13), trunk / win-vs2022-cpu-py3 / test (default, 1, 3, lf.windows.4xlarge.nonephemeral) Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Stack from ghstack (oldest at bottom):
Add complex number unwrapping in functional collectives used by DTensor.
Complex tensors are not directly supported by underlying comm kernels
(e.g. nccl) but complex tensors can be viewed as real tensors of a
higher rank (added size-2 tensor dim represents real vs im component).
Collective output is then viewed as complex to restore the
original/expected shape and dtype.
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k