KEMBAR78
Support complex numbers in DTensor redistribute by wconstab · Pull Request #157329 · pytorch/pytorch · GitHub
Skip to content

Conversation

@wconstab
Copy link
Contributor

@wconstab wconstab commented Jun 30, 2025

Stack from ghstack (oldest at bottom):

Add complex number unwrapping in functional collectives used by DTensor.

Complex tensors are not directly supported by underlying comm kernels
(e.g. nccl) but complex tensors can be viewed as real tensors of a
higher rank (added size-2 tensor dim represents real vs im component).
Collective output is then viewed as complex to restore the
original/expected shape and dtype.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k

Add complex number unwrapping in functional collectives used by DTensor.

Complex tensors are not directly supported by underlying comm kernels
(e.g. nccl) but complex tensors can be viewed as real tensors of a
higher rank (added size-2 tensor dim represents real vs im component).
Collective output is then viewed as complex to restore the
original/expected shape and dtype.

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Jun 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157329

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Unrelated Failure

As of commit d1d069e with merge base 070aa59 (image):

NEW FAILURES - The following jobs have failed:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category labels Jun 30, 2025
wconstab added a commit that referenced this pull request Jun 30, 2025
Add complex number unwrapping in functional collectives used by DTensor.

Complex tensors are not directly supported by underlying comm kernels
(e.g. nccl) but complex tensors can be viewed as real tensors of a
higher rank (added size-2 tensor dim represents real vs im component).
Collective output is then viewed as complex to restore the
original/expected shape and dtype.

ghstack-source-id: da2fdae
Pull Request resolved: #157329
@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 30, 2025
Add complex number unwrapping in functional collectives used by DTensor.

Complex tensors are not directly supported by underlying comm kernels
(e.g. nccl) but complex tensors can be viewed as real tensors of a
higher rank (added size-2 tensor dim represents real vs im component).
Collective output is then viewed as complex to restore the
original/expected shape and dtype.

cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k

Differential Revision: [D77564148](https://our.internmc.facebook.com/intern/diff/D77564148)

[ghstack-poisoned]
Add complex number unwrapping in functional collectives used by DTensor.

Complex tensors are not directly supported by underlying comm kernels
(e.g. nccl) but complex tensors can be viewed as real tensors of a
higher rank (added size-2 tensor dim represents real vs im component).
Collective output is then viewed as complex to restore the
original/expected shape and dtype.

cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k

Differential Revision: [D77564148](https://our.internmc.facebook.com/intern/diff/D77564148)

[ghstack-poisoned]
Comment on lines 130 to 131
# TODO(whc) it appears complex-allreduce is already being supported becuase this test passes,
# but I did not see where the support is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, thanks for pointing that out. That was the last piece i was missing.

The new piece of info was that CI failed this test for GLOO bc of gloo not supporting allreduce. I therefore added the complex support at Functional.cpp for allreduce, fixing the gloo test, and now I understand why nccl was already passing without it. let me remove this TODO.

Add complex number unwrapping in functional collectives used by DTensor.

Complex tensors are not directly supported by underlying comm kernels
(e.g. nccl) but complex tensors can be viewed as real tensors of a
higher rank (added size-2 tensor dim represents real vs im component).
Collective output is then viewed as complex to restore the
original/expected shape and dtype.

cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k


[ghstack-poisoned]
Copy link
Contributor

@XilunWu XilunWu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test change LGTM but the functional collective part needs small change. Thanks Will for adding complex number support!

Comment on lines +84 to +88
auto input_real = input.is_complex() ? at::view_as_real(input) : input;
auto output = input_real.clone(at::MemoryFormat::Contiguous);
auto output_ret =
all_reduce_(output, std::move(reduce_op), std::move(group_name));
return input.is_complex() ? at::view_as_complex(output_ret) : output_ret;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need the same logic (check reduce_op) (https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L4393-L4400) because this approach only holds numeric correctness for these 4 ops.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't know what the best choice is here to reuse the existing helper. I ended up writing a new string-matching check. Lmk if you think its worthwhile to do more refactoring to use the same helper from ProcessGroupNccl, but note that to use that i'd have to first convert into ReduceOp enum, and the conversion helper in this file is incomplete for some reason. (premul_sum)

Add complex number unwrapping in functional collectives used by DTensor.

Complex tensors are not directly supported by underlying comm kernels
(e.g. nccl) but complex tensors can be viewed as real tensors of a
higher rank (added size-2 tensor dim represents real vs im component).
Collective output is then viewed as complex to restore the
original/expected shape and dtype.

cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k


[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jul 1, 2025
Add complex number unwrapping in functional collectives used by DTensor.

Complex tensors are not directly supported by underlying comm kernels
(e.g. nccl) but complex tensors can be viewed as real tensors of a
higher rank (added size-2 tensor dim represents real vs im component).
Collective output is then viewed as complex to restore the
original/expected shape and dtype.

ghstack-source-id: 052ebcd
Pull Request resolved: #157329
self.assertEqual(new_tensor.stride(), new_meta_tensor.stride())


instantiate_parametrized_tests(RedistributeTest)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to myself: use instantiate_parametrized_tests in DTensor tests

Copy link
Contributor

@XilunWu XilunWu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stamp to unblock. We can address the comment in follow-up PR

TORCH_CHECK(
// TODO - ideally use 'to_reduce_op' helper but it currently errors on
// premul_sum
reduce_op == "sum" || reduce_op == "avg" || reduce_op == "premul_sum" ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO it would be good to reuse complexViewAsRealAllowed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i agree but it is not trivial.

  1. convert string to enum: there is a helper in this file, but it is not complete. if i completed it by filling in missing premul_sum, i would be affecting the behavior of other ops using this helper
  2. then i could refactor the complexViewAsRealAllowed out of ProcessGroupNccl.cpp and make it a util and use it.

Happy to do this in another PR, lmk if you have thoughts about (1)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Low prio though.

lmk if you have thoughts about (1)

I think you're right. And I don't have a good solution either that can ensure consistency among helpers and potential extension of ReduceOp enum.

@wconstab
Copy link
Contributor Author

wconstab commented Jul 2, 2025

@pytorchbot merge -i

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged while ignoring the following 3 checks: pull / cuda12.8-py3.10-gcc9-sm75 / test (pr_time_benchmarks, 1, 1, linux.g4dn.metal.nvidia.gpu, unstable), trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m1-13), trunk / win-vs2022-cpu-py3 / test (default, 1, 3, lf.windows.4xlarge.nonephemeral)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@wconstab wconstab deleted the gh/wconstab/421/head branch July 3, 2025 00:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants