[dtensor] support convolution ops #113123
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113123
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit b6a6728 with merge base 7963aaa.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks for contributing! The convolution op is quite important and also hard to implement in the distributed tensor context, so I'm glad you already got this working :)
I have a few inline comments regarding how the ops (conv and the other ops) are currently implemented. The main thing is that we should see whether the convolution op can be implemented like other ops, without special-casing to do pre/post communication.
Cool! Would you be willing to spend some time rewriting the ops using the "strategy"-based approach? We recently changed how ops are implemented, and going forward we want all ops implemented with the strategy-based approach. Some examples:
- https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/tensor_ops.py#L152
- https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/math_ops.py#L208
It's also fine to land some of these ops using your existing approach, but we'll need to refactor them later. We can chat about the details of how the new approach works if you want.
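As a rough illustration of the strategy-based approach mentioned above, here is a minimal sketch modeled on the linked tensor_ops.py and math_ops.py examples. It is not code from this PR; the module paths, signatures, and the choice of `aten.clone.default` reflect the DTensor internals of that era and are assumptions that may have changed since:

```python
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.op_schema import OpSchema, OpStrategy, PlacementStrategy
from torch.distributed._tensor.ops.utils import register_op_strategy

aten = torch.ops.aten


# Register a sharding strategy for an op instead of hand-writing pre/post comms:
# the sharding propagator picks a PlacementStrategy and inserts any needed
# redistribution automatically.
@register_op_strategy(aten.clone.default)
def clone_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
    # The first argument arrives as an OpStrategy describing the ways the input
    # may be placed on the mesh; the output simply follows the input's placement.
    input_strategy = op_schema.args_schema[0]
    return OpStrategy(
        [
            PlacementStrategy(output_spec=strategy.output_spec)
            for strategy in input_strategy.strategies
        ]
    )
```

The key difference from the pre/post-communication style is that the op author only declares which placements are valid, and the dispatcher takes care of any communication.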
I'd prefer to keep the current approach and refactor it later.
torch/distributed/_tensor/tp_conv.py
Outdated
I am trying to see whether we should make these batch send/recv calls a kind of redistribute/resharding, so that instead of implementing a custom convolution op and special-casing it in torch_dispatch, we should try to see whether:
- for the input data exchange, we can express these input send/recvs as a data-order permutation (i.e. something like mesh [0, 1, 2, 3] to mesh [1, 2, 3, 0]).
- I see duplicate comms happening in the backward op too; we should merge these two into a common function.
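For readers following the thread, here is a minimal sketch of the kind of batched boundary ("halo") exchange being discussed, using torch.distributed's batched P2P API. It is not the PR's actual code; the sharding along the last dimension, the halo width, and the helper name `exchange_halo` are assumptions for illustration:

```python
import torch
import torch.distributed as dist


def exchange_halo(local_chunk: torch.Tensor, halo: int) -> torch.Tensor:
    """Exchange `halo` boundary columns with the left/right neighbor ranks so a
    convolution on the last-dim-sharded input sees the columns it needs.
    Assumes a 1D sharding over the last dimension, equal-width chunks of at
    least `halo` columns, and an initialized process group."""
    rank, world = dist.get_rank(), dist.get_world_size()
    left, right = rank - 1, rank + 1

    send_left = local_chunk[..., :halo].contiguous()    # my leftmost columns
    send_right = local_chunk[..., -halo:].contiguous()  # my rightmost columns
    recv_left = torch.empty_like(send_left)             # left neighbor's rightmost columns
    recv_right = torch.empty_like(send_right)           # right neighbor's leftmost columns

    ops = []
    if left >= 0:
        ops += [dist.P2POp(dist.isend, send_left, left),
                dist.P2POp(dist.irecv, recv_left, left)]
    if right < world:
        ops += [dist.P2POp(dist.isend, send_right, right),
                dist.P2POp(dist.irecv, recv_right, right)]
    if ops:
        for work in dist.batch_isend_irecv(ops):
            work.wait()

    pieces = []
    if left >= 0:
        pieces.append(recv_left)
    pieces.append(local_chunk)
    if right < world:
        pieces.append(recv_right)
    return torch.cat(pieces, dim=-1)
```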
I have merged the communications into a common function for the conv fwd/bwd.
Regarding replacing the batch send/recv calls with redistribute/resharding, we can continue that discussion, but I don't think it will happen in this PR.
torch/distributed/_tensor/tp_conv.py
Outdated
It's a bit hard to follow what this dist comm is doing; could you add some comments explaining what this batch send/recv is trying to achieve for the conv input?
Sure, I have sent you the slides about the implementation details in the Slack channel.
Given that those "other ops" are not specifically related to convolution, it would be better to separate the other-ops enablement from convolution itself and cover it with its own tests.
Can you elaborate on this? Currently we need to register slice_backward, bernoulli_, nll_loss_forward, and nll_loss_backward in other_ops.py to run the TP training example convnext_example.py.
Would it work to keep other_ops.py as it is and add unit tests in a new file, test_other_ops.py, under test/distributed/_tensor?
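For concreteness, a minimal sketch of what such a test in test_other_ops.py could look like, assuming the DTensor test utilities available at the time (DTensorTestBase, with_comms); the class and test names here are hypothetical, not taken from this PR:

```python
import torch
import torch.nn.functional as F
from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class DistOtherOpsTest(DTensorTestBase):
    @with_comms
    def test_nll_loss_forward(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        log_probs = F.log_softmax(torch.randn(8, 4), dim=-1)
        labels = torch.randint(0, 4, (8,))

        # Replicate the inputs so the DTensor loss should match the local loss.
        d_log_probs = distribute_tensor(log_probs, mesh, [Replicate()])
        d_labels = distribute_tensor(labels, mesh, [Replicate()])

        d_loss = F.nll_loss(d_log_probs, d_labels)
        expected = F.nll_loss(log_probs, labels)
        # The result is replicated, so the local shard already holds the full value.
        self.assertEqual(d_loss.to_local(), expected)


if __name__ == "__main__":
    run_tests()
```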
I have added unit tests for the other ops and included the reference log in the description of this PR.
Force-pushed from 21f098c to 42e1b89.
@wanchaol shouldn't this resharding happen in _operator_dispatch?
I have moved it to convolution_backward_handler to avoid this pollution.
@wanchaol what can we do to avoid this customized DTensor op?
I have moved it to convolution_backward_handler to avoid this pollution.
Thanks for the hard work! I'm working on backward layer norm support; I hope it's not blocking you right now. Left some questions.
Force-pushed from 3506413 to 7ebfc8a.
Force-pushed from e49339f to c3a52cd.
Thanks for fixing all the lint errors! Please see the inline comments.
I think the other changes look good. My major comment now is that we shouldn't hijack the dispatch logic like this in dispatch.py. Given that we are doing custom op handling, let's just make convolution a custom op handler and keep the custom logic inside tp_conv.py.
Force-pushed from 1d76745 to eecddbe.
LGTM. I have some small suggestions inline; please address them before landing. Thanks for contributing!
Force-pushed from eecddbe to b6a6728.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR creates a prototype for training convolutional neural networks with DTensor.
In this prototype we shard the activations and replicate the model weights. With this approach we can scale out to multiple GPUs, reduce the per-GPU memory footprint, and achieve weak scaling in training performance (i.e., time per iteration).
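For illustration, here is a minimal sketch of the sharded-activation / replicated-weight setup described above. It is not the PR's code; it assumes torch.distributed is already initialized with each rank bound to its own GPU, and sharding the last spatial dimension is an assumption made for the example:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._tensor import DeviceMesh, Replicate, Shard, distribute_tensor

# 1D mesh over all ranks (assumes the default process group is initialized
# and torch.cuda.set_device(rank) has been called on each rank).
mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

conv = torch.nn.Conv2d(3, 64, kernel_size=7, padding=3).cuda()
x = torch.randn(8, 3, 224, 224, device="cuda")

# Shard the activation along a spatial dimension (dim 3 here is illustrative)
# and replicate the weights on every rank.
d_x = distribute_tensor(x, mesh, [Shard(3)])
d_w = distribute_tensor(conv.weight, mesh, [Replicate()])
d_b = distribute_tensor(conv.bias, mesh, [Replicate()])

# The convolution runs on the sharded activation; the boundary exchange is
# handled by the DTensor convolution support added in this PR.
out = F.conv2d(d_x, d_w, d_b, padding=3)
```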
Reference log (on 2x A100 GPUs):
Unit Test
ConvNeXt Example
@wanchaol @fduwjj FYI
cc @wanchaol @XilunWu