[dtensor] support convolution ops by KingsleyLiu-NV · Pull Request #113123 · pytorch/pytorch · GitHub

Conversation

@KingsleyLiu-NV
Contributor

@KingsleyLiu-NV KingsleyLiu-NV commented Nov 7, 2023

This PR creates a prototype for training convolutional neural networks based on DTensor.

  • Register required ops and implement operator dispatch
  • Add unit tests and example

Basically, in this prototype we shard the activations and replicate the model weights. With this approach we can scale out to multiple GPUs, reduce the per-GPU memory footprint, and achieve weak scaling in training performance (i.e., time per iteration).
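
To make the sharding scheme concrete, here is a minimal, hedged sketch (not the PR's actual test code) of distributing a conv2d with DTensor: the activations get a Shard placement and the weights/bias a Replicate placement, so each rank computes the convolution on its local shard. The mesh size, shapes, and the choice of sharding dim 3 (width) are illustrative assumptions; the script is meant to run under torchrun with one process per GPU.

import torch
import torch.nn.functional as F
from torch.distributed._tensor import DeviceMesh, Shard, Replicate, distribute_tensor

# assumes torch.distributed is already initialized (e.g. via torchrun + NCCL)
mesh = DeviceMesh("cuda", list(range(2)))        # 1-D mesh over 2 GPUs (illustrative)

x = torch.randn(8, 64, 56, 56)                   # activations
w = torch.randn(128, 64, 3, 3)                   # conv weight
b = torch.randn(128)                             # conv bias

# shard activations along a spatial dim, replicate the parameters;
# distribute_tensor places the local shards on the mesh's device
x_dt = distribute_tensor(x, mesh, [Shard(3)])
w_dt = distribute_tensor(w, mesh, [Replicate()])
b_dt = distribute_tensor(b, mesh, [Replicate()])

# with the convolution ops registered by this PR, this call dispatches through
# DTensor and each rank only sees its local shard of the input
out = F.conv2d(x_dt, w_dt, b_dt, padding=1)
print(out.shape, out.placements)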

Reference log (on 2xA100 GPU):

Unit Test

root@luna-prod-78-80gb:/pytorch# python3 test/distributed/_tensor/test_convolution_ops.py
/opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (Triggered internally at /opt/conda/conda-bld/pytorch_1699257304556/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2170.)
  return F.conv2d(input, weight, bias, self.stride,
/opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (Triggered internally at /opt/conda/conda-bld/pytorch_1699257304556/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2170.)
  return F.conv2d(input, weight, bias, self.stride,
..
----------------------------------------------------------------------
Ran 2 tests in 30.354s

OK
root@luna-prod-78-80gb:/pytorch# python3 test/distributed/_tensor/test_other_ops.py
[rank0]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank0]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank1]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
[rank1]:[W ProcessGroupNCCL.cpp:2170] Warning: 0TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point collectives. (function operator())
...
----------------------------------------------------------------------
Ran 3 tests in 16.343s

OK

ConvNeXt Example

root@luna-prod-78-80gb:/pytorch# python3 torch/distributed/_tensor/examples/convnext_example.py
rank 3, 20 iterations, latency     584.80 ms, forward     102.84 ms, backward     297.80 ms, max reserved    16.34 GiB, max allocated    14.75 GiB
rank 1, 20 iterations, latency     584.64 ms, forward     104.85 ms, backward     297.60 ms, max reserved    16.40 GiB, max allocated    14.74 GiB
rank 0, 20 iterations, latency     584.48 ms, forward     104.64 ms, backward     297.90 ms, max reserved    16.39 GiB, max allocated    14.75 GiB
rank 2, 20 iterations, latency     584.96 ms, forward      93.21 ms, backward     297.95 ms, max reserved    16.40 GiB, max allocated    14.74 GiB

@wanchaol @fduwjj FYI

cc @wanchaol @XilunWu

@pytorch-bot

pytorch-bot bot commented Nov 7, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113123

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b6a6728 with merge base 7963aaa:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla bot commented Nov 7, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

@Aidyn-A Aidyn-A added the topic: new features, module: dtensor, and release notes: distributed (dtensor) labels Nov 8, 2023
@Aidyn-A Aidyn-A requested a review from wanchaol November 9, 2023 03:57
Collaborator

@wanchaol wanchaol left a comment

Thanks for contributing! The convolution op is quite important and also hard to implement in the distributed tensor context, so I'm glad you already got this working :)

I have a few inline comments about how the ops (conv and the other ops) are currently implemented. The main thing is that we should see whether the convolution op can be implemented like other ops, without special-casing to do pre/post communication.

Collaborator

Cool! I am wondering if you'd be willing to spend some time rewriting the ops to the "strategy"-based approach? We recently changed the way ops are implemented, and going forward we want all ops implemented with the strategy-based approach; there are some examples.

It's also fine to land some of these ops using the existing approach you have, but we'll need to refactor them later. We can chat in more detail about how the new approach works if you want.

Contributor Author

I'd prefer to keep the current approach and refactor it later.

Collaborator

I am trying to see whether we should make these batch send/recv calls a type of redistribute/resharding, so that instead of having a custom convolution op implemented and special-cased in torch_dispatch, we should try to see if:

  • for the case of input data exchange, we can make these input send/recvs a data-order permutation (i.e., something like mesh [0, 1, 2, 3] to mesh [1, 2, 3, 0]); see the sketch after this list for what the send/recvs achieve.
  • I see duplicate comms happening in the backward op too; we should merge these two into a common function.
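
To make the intent of those batch send/recvs concrete, here is an illustrative, hedged sketch (not the code in this PR) of a halo exchange built on torch.distributed.batch_isend_irecv: each rank holds a width-shard of the activation and trades halo-width boundary columns with its neighbours before running the local convolution. The ring-style neighbour wrap-around is a simplifying assumption; boundary ranks in a real implementation would rely on zero padding instead.

import torch
import torch.distributed as dist

def exchange_halo(local_shard: torch.Tensor, halo: int) -> torch.Tensor:
    """Append the neighbours' boundary columns to a width-sharded activation."""
    rank, world = dist.get_rank(), dist.get_world_size()
    left, right = (rank - 1) % world, (rank + 1) % world   # ring neighbours (assumption)

    send_left = local_shard[..., :halo].contiguous()       # my leftmost columns -> left rank
    send_right = local_shard[..., -halo:].contiguous()     # my rightmost columns -> right rank
    recv_left = torch.empty_like(send_left)                # left rank's rightmost columns
    recv_right = torch.empty_like(send_right)              # right rank's leftmost columns

    ops = [
        dist.P2POp(dist.isend, send_left, left),
        dist.P2POp(dist.isend, send_right, right),
        dist.P2POp(dist.irecv, recv_left, left),
        dist.P2POp(dist.irecv, recv_right, right),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()

    # the padded shard now has enough context for the local conv window
    return torch.cat([recv_left, local_shard, recv_right], dim=-1)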

Contributor Author

I have merged the communications into a common function for the conv forward/backward.

Regarding replacing the batch send/recv calls with redistribute/resharding, we can continue that discussion, but I don't think it will happen in this PR.

Collaborator

It's a bit hard to follow what this distributed communication is doing; could you add some comments to explain what this batch send/recv is trying to achieve for the conv input?

Contributor Author

Sure, I have sent you the slides covering the implementation details in the Slack channel.

Collaborator

Given that those "other ops" are not specifically related to convolution, it would be better to separate the other-ops enablement from convolution itself, with its own tests.

Contributor Author

Can you elaborate on this? Currently we need to register slice_backward, bernoulli_, nll_loss_forward, and nll_loss_backward within other_ops.py to run the TP training example convnext_example.py.

Would it be acceptable to keep other_ops.py as it is and add some unit tests in a new file test_other_ops.py within test/distributed/_tensor?
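
As a rough illustration of what a case in such a test_other_ops.py could look like (a hedged sketch using the existing DTensor test utilities; the specific op, shapes, and assertion here are assumptions, not the test ultimately added in this PR):

import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class TestOtherOps(DTensorTestBase):
    @with_comms
    def test_bernoulli_(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        # shard a tensor of probabilities and sample in place; bernoulli_ is one
        # of the ops this PR registers in other_ops.py
        dt = distribute_tensor(torch.full((8, 8), 0.5), mesh, [Shard(0)])
        dt.bernoulli_()
        local = dt.to_local()
        self.assertTrue(((local == 0) | (local == 1)).all())


if __name__ == "__main__":
    run_tests()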

Contributor Author

I have added unit tests for the other ops and included the reference log in this PR's description.

@ezyang ezyang added the triaged label Nov 9, 2023
Comment on lines 125 to 93
Contributor

@wanchaol shouldn't this resharding happen in _operator_dispatch?

Contributor Author

I have moved it to convolution_backward_handler to avoid polluting the dispatch logic.

Comment on lines 306 to 370
Contributor

@wanchaol what can we do to avoid this customized DTensor op?

Contributor Author

I have moved it to convolution_backward_handler to avoid polluting the dispatch logic.

Contributor

@XilunWu XilunWu left a comment

Thanks for the hard work! I'm working on backward layer norm support; I hope it's not blocking you right now. I left some questions.

Collaborator

@wanchaol wanchaol left a comment

Thanks for fixing all the lint errors! Please see inlined comments.

I think the other changes look good. My major comment now is that we shouldn't hijack the dispatch logic like this in dispatch.py; given that we are doing custom op handling, let's just make convolution a custom op handler and keep the custom logic inside tp_conv.py.

Collaborator

@wanchaol wanchaol left a comment

LGTM, I have some small suggestions inlined; please address them before landing. Thanks for contributing!

@wanchaol
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Nov 20, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.
