[TP] Refactor style to make it work with torch.compile #111625
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111625
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 7718059 with merge base 0617f7f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Please see inline suggestions.
[ghstack-poisoned]
lgtm, thanks for refactoring this!
    if is_seq_parallel
    else PrepareModuleInput(input_layouts=Replicate())
)
no_input_prepare_colwise_style = ColwiseParallel(input_layouts=None)
hmmm I don't think our input_layouts accepts None, since it's a Union[Placement, Tuple[Placement]]?
We should probably just do a check inside prepare_input_fn to ensure input_layouts matches the passed-in DTensor; this can be done in follow-up PRs.
sure, let me merge this first and do that in a follow-up PR.
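For illustration, a hedged sketch of the kind of check being suggested; `_check_input_layouts` is a hypothetical helper, not the actual prepare_input_fn inside torch.distributed.tensor.parallel:

```python
from torch.distributed._tensor import DTensor


def _check_input_layouts(input_tensor, input_layouts):
    # Hypothetical helper: if the input is already a DTensor, verify that its
    # placements agree with the layouts declared on the parallel style instead
    # of silently assuming they match.
    if isinstance(input_tensor, DTensor):
        declared = input_layouts if isinstance(input_layouts, tuple) else (input_layouts,)
        if tuple(input_tensor.placements) != tuple(declared):
            raise ValueError(
                f"input_layouts {declared} does not match DTensor placements "
                f"{tuple(input_tensor.placements)}"
            )
    return input_tensor
```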
    torch.empty_like(tensor) for _ in range(self.world_size)
]
dist.all_gather(gathered_tensors, tensor)
gathered_tensors = torch.cat(gathered_tensors, dim=0).contiguous()
nit: we can just use a functional collective instead of manually creating the gathered tensors. Can fix this later.
oh ok, sure. Will do in a follow-up PR, where I will also try to change the test.
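For reference, a hedged sketch of what the suggested functional-collective version could look like (not the code in this PR); it assumes the default process group is initialized and `tensor` is this rank's local shard:

```python
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# Gather every rank's shard along dim 0 in a single call; no manual list of
# empty buffers or torch.cat is needed.
gathered_tensors = funcol.all_gather_tensor(tensor, gather_dim=0, group=dist.group.WORLD)
```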
| "mlp_0.net2": RowwiseParallel(), | ||
| "mlp_1.net1": ColwiseParallel(), | ||
| "mlp_1.net2": RowwiseParallel(), | ||
| "mlp_0": module_prepare_input, |
i'm not sure what this does: does it mean that the input is already coming in as local tensors (maybe sharded or replicated) and this wraps them in DTensors? Or does this actually 'shard' the inputs (assuming whole inputs come in first)?
"input is already coming as local tensors (may be sharded or replicated) and this wraps them in DTensors"

Yes. What I see in xlformer is that it has two or three column-wise linears. Instead of registering a hook on each nn.Linear and ending up calling all-gather multiple times, we use module_prepare_input to register it only once on the parent module.
@wconstab it's coming in as local tensors; then in this PrepareModuleInput we mark it as a DTensor and do an all-gather (redistribute).
@fduwjj another follow-up we should probably do is to remove the default layouts for PrepareModuleInput/Output and require the user to set them explicitly, so that users know what is being done under the hood.
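To make the discussion concrete, here is a hedged sketch of the pattern being described, built from the styles visible in the diff above; the "mlp_*" names, `model`, and the mesh setup are illustrative assumptions, not the exact test code:

```python
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    parallelize_module,
)

# Assumes torch.distributed is already initialized and `model` has an `mlp_0`
# submodule containing a column-wise `net1` followed by a row-wise `net2`.
mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

plan = {
    # One hook on the parent module: the local input is wrapped as a DTensor
    # (and redistributed/all-gathered once, e.g. in the sequence-parallel case) ...
    "mlp_0": PrepareModuleInput(input_layouts=Replicate()),
    # ... so the column-wise child skips its own input preparation,
    "mlp_0.net1": ColwiseParallel(input_layouts=None),
    # ... and the row-wise child handles its output as usual.
    "mlp_0.net2": RowwiseParallel(),
}
model = parallelize_module(model, mesh, plan)
```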
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
We are refactoring the parallel styles to accomplish the following:
1. Further simplify the code logic to make it more readable for users.
2. Remove the tuple check so that it works with dynamo for now. Ideally dynamo should support this case, and we will fix that in parallel.
3. Add tests for the newly added parallel styles in the unit tests and the torch.compile test so that we can catch regressions caused by code changes.
4. Move the placements early-return check into DTensor, since it is bypassed by dynamo.
5. Remove PairwiseParallelStyle from the unit tests and use the new Col/Rowwise parallel styles instead.

Pull Request resolved: pytorch#111625
Approved by: https://github.com/wanchaol
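As a minimal, hedged illustration of the torch.compile coverage mentioned in point 3 (not the actual test added by this PR), `tp_model` below stands in for a module already parallelized with parallelize_module, and the input shape is purely illustrative:

```python
import torch

# Assumes `tp_model` is an nn.Module that has already been parallelized with
# parallelize_module(...) and lives on this rank's GPU.
compiled_model = torch.compile(tp_model, fullgraph=True)

# The compiled module should match eager execution; comparing the two is the
# kind of regression check the description refers to.
x = torch.randn(8, 16, device="cuda")  # illustrative batch and feature sizes
torch.testing.assert_close(compiled_model(x), tp_model(x))
```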
Stack from ghstack (oldest at bottom):
We are refactoring the parallel styles to accomplish the following: