[TP] Refactor style to make it work with torch.compile by fduwjj · Pull Request #111625 · pytorch/pytorch · GitHub

Conversation

@fduwjj (Contributor) commented Oct 20, 2023

Stack from ghstack (oldest at bottom):

We are refactoring the parallel styles to accomplish the following:

  1. Further simplify the code logic to make it more readable for users.
  2. Remove the tuple check so that it works with dynamo for now. Ideally dynamo should support this case, and we will fix that in parallel.
  3. Add tests for the newly added parallel styles in the unit tests and in the torch.compile test so that we can catch regressions caused by code changes.
  4. Move the placements early-return check into DTensor, since it is bypassed by dynamo.
  5. Remove PairwiseParallelStyle from the unit tests in favor of the new Colwise/Rowwise parallel styles (see the sketch after this list).
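A minimal sketch of what item 5 could look like in a test, assuming a two-layer MLP whose linears are named net1 and net2 and an already-built device_mesh (module and mesh names are placeholders, not code from this PR):

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Explicit per-linear styles instead of a single pairwise style:
# shard the first linear column-wise and the second row-wise.
parallelize_plan = {
    "net1": ColwiseParallel(),
    "net2": RowwiseParallel(),
}
# model = parallelize_module(model, device_mesh, parallelize_plan)  # assumed objects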

@pytorch-bot (bot) commented Oct 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111625

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7718059 with merge base 0617f7f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@wanchaol (Collaborator) left a comment

Please see inline suggestions.

fduwjj added a commit that referenced this pull request Oct 20, 2023
@fduwjj added the module: dtensor, release notes: distributed (dtensor), ciflow/trunk, and ciflow/periodic labels Oct 20, 2023
@fduwjj requested a review from wanchaol October 20, 2023 05:06
@wanchaol (Collaborator) left a comment

lgtm, thanks for refactoring this!

if is_seq_parallel
else PrepareModuleInput(input_layouts=Replicate())
)
no_input_prepare_colwise_style = ColwiseParallel(input_layouts=None)
Collaborator

Hmm, I don't think our input_layouts accepts None, since it's a Union[Placement, Tuple[Placement]]?

We should probably just add a check inside prepare_input_fn to ensure input_layouts matches the passed-in DTensor; this can be done in follow-up PRs.

Contributor Author

Sure, let me merge this first and do that in a follow-up PR.
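A rough sketch of the check suggested above, with a hypothetical _prepare_input_fn whose name, signature, and wrapping logic are illustrative only (not the actual code from this PR), assuming a 1-D device mesh:

from torch.distributed._tensor import DTensor

def _prepare_input_fn(inputs, device_mesh, input_layouts):
    # Validate that each incoming DTensor matches its declared layout;
    # wrap plain local tensors with that layout otherwise.
    prepared = []
    for inp, layout in zip(inputs, input_layouts):
        if isinstance(inp, DTensor):
            # assumes a 1-D mesh, so placements is a 1-tuple
            assert inp.placements == (layout,), (
                f"expected placement {layout}, got {inp.placements}"
            )
            prepared.append(inp)
        else:
            prepared.append(
                DTensor.from_local(inp, device_mesh, (layout,), run_check=False)
            )
    return tuple(prepared)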

torch.empty_like(tensor) for _ in range(self.world_size)
]
dist.all_gather(gathered_tensors, tensor)
gathered_tensors = torch.cat(gathered_tensors, dim=0).contiguous()
Collaborator

Nit: we could just use a functional collective instead of manually creating the list of gathered tensors. Can fix this later.

Contributor Author

Oh OK, sure. Will do in a follow-up PR, in which I will also try to change the test.
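A hedged sketch of that nit, assuming the (private, version-dependent) torch.distributed._functional_collectives module and using the default world process group purely for illustration:

import torch.distributed as dist
from torch.distributed._functional_collectives import all_gather_tensor

# Replaces the manually pre-allocated list of torch.empty_like buffers plus
# dist.all_gather: the functional collective returns the already-concatenated
# tensor along the given dimension.
gathered_tensors = all_gather_tensor(tensor, gather_dim=0, group=dist.group.WORLD)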

"mlp_0.net2": RowwiseParallel(),
"mlp_1.net1": ColwiseParallel(),
"mlp_1.net2": RowwiseParallel(),
"mlp_0": module_prepare_input,
Contributor

I'm not sure what this does. Does it mean that the input is already coming in as local tensors (possibly sharded or replicated) and this wraps them in DTensors? Or does this actually 'shard' the inputs (assuming whole inputs first)?

Contributor Author

input is already coming as local tensors (may be sharded or replicated) and this wraps them in DTensors

Yes. What I see in xlformer is that it has two or three col-wise linears. Instead of registering a hook on each nn.Linear and ending up calling all-gather multiple times, we use module_prepare_input to register it only once on the parent module.
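A minimal sketch of that pattern, reconstructed from the quoted test plan; PrepareModuleInput(input_layouts=Replicate()) stands in for the test's module_prepare_input object (an assumption based on the other quoted snippet), and model/device_mesh are assumed to exist:

from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    parallelize_module,
)

plan = {
    # Register the input preparation once on the parent module, so the local
    # tensor is wrapped as a DTensor (and redistributed) a single time instead
    # of once per child nn.Linear.
    "mlp_0": PrepareModuleInput(input_layouts=Replicate()),
    "mlp_0.net1": ColwiseParallel(),
    "mlp_0.net2": RowwiseParallel(),
    "mlp_1.net1": ColwiseParallel(),
    "mlp_1.net2": RowwiseParallel(),
}
# model = parallelize_module(model, device_mesh, plan)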

Collaborator

@wconstab it's coming in as local tensors; then in this PrepareModuleInput we mark it as a DTensor and do an all-gather (redistribute).

@fduwjj another follow-up we should probably do is to remove the default layouts for PrepareModuleInput/Output and require users to set them explicitly, so that users know what is being done under the hood.

@fduwjj (Contributor Author) commented Oct 20, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@facebook-github-bot deleted the gh/fduwjj/107/head branch October 24, 2023 14:23
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Pull Request resolved: pytorch#111625
Approved by: https://github.com/wanchaol