[DTensor][Bug Fix]Fix 2D DTensor mm with mesh_shape (1, n) or (n, 1) by wz337 · Pull Request #139134 · pytorch/pytorch · GitHub

wz337 commented Oct 29, 2024

Stack from ghstack (oldest at bottom):

Fixes #138742. In the issue, matrix multiplication with DTensor failed when the mesh has more than one dimension and one of its dimensions has size 1. We were missing tests for this corner case, where mesh_shape is (n, 1) or (1, n). The DTensor mm op is correct for a 1D mesh of shape (self.world_size,) and for a 2D mesh in which no mesh dimension has size 1.

In this PR, we fix the corner case by updating `gen_einsum_strategies` in `_einsum_strategy.py`. Specifically, we cannot skip generating `mesh_dim_strategies` when the size of a mesh dimension is <= 1, because that skip is not valid for an nD mesh in which one of the mesh dimensions has size 1.
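The effect of that skip can be sketched in plain Python. This is a simplified model, not the code in `_einsum_strategy.py`: the names `PER_DIM_MM_STRATEGIES` and `enumerate_strategies` and the `skip_size_one_dims` flag are hypothetical, and placements are written as short strings rather than placement objects.

```python
import itertools

# Per-mesh-dimension choices for mm, written as (input_a, input_b, output).
# Hypothetical stand-in for the placement lists printed in the PR description.
PER_DIM_MM_STRATEGIES = [
    ("R", "R", "R"),
    ("S(1)", "S(0)", "P"),
    ("S(0)", "R", "S(0)"),
    ("R", "S(1)", "S(1)"),
]


def enumerate_strategies(mesh_shape, skip_size_one_dims):
    """Combine per-dimension choices into one placement tuple per tensor."""
    per_dim = []
    for size in mesh_shape:
        if skip_size_one_dims and size <= 1:
            # Modeled bug: this mesh dimension contributes no entry at all,
            # so the combined placements end up shorter than the mesh rank.
            continue
        per_dim.append(PER_DIM_MM_STRATEGIES)

    strategies = []
    for combo in itertools.product(*per_dim):
        a = tuple(choice[0] for choice in combo)
        b = tuple(choice[1] for choice in combo)
        out = tuple(choice[2] for choice in combo)
        strategies.append((a, b, out))
    return strategies


buggy = enumerate_strategies((4, 1), skip_size_one_dims=True)
fixed = enumerate_strategies((4, 1), skip_size_one_dims=False)

# Number of strategies, and number of placements per tensor, in each case.
print(len(buggy), len(buggy[0][0]))
print(len(fixed), len(fixed[0][0]))
```

In this model, skipping the size-1 dimension of a (4, 1) mesh yields 4 strategies with one placement per tensor, while keeping it yields 16 strategies with two placements per tensor, which matches the mesh rank.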

Without the fix, the OpStrategy generated for a 2D mesh with mesh_shape (1, n) or (n, 1) is wrong: it contains 1D placements even though the mesh is 2D.

all_mesh_dim_strategies=[[[Replicate(), Replicate(), Replicate()], [Partial(sum), Shard(dim=1), Shard(dim=0)], [Shard(dim=0), Shard(dim=0), Replicate()], [Shard(dim=1), Replicate(), Shard(dim=1)]]]
OpStrategy(all_strategies):::   [(R, R) -> R, (S(1), S(0)) -> P, (S(0), R) -> S(0), (R, S(1)) -> S(1)] @ mesh: (4, 1)[(R, R) -> R, (S(1), S(0)) -> P, (S(0), R) -> S(0), (R, S(1)) -> S(1)] @ mesh: (4, 1)

After the fix, the generated OpStrategy is correct and contains 2D placements.

all_mesh_dim_strategies=[[[Replicate(), Replicate(), Replicate()], [Partial(sum), Shard(dim=1), Shard(dim=0)], [Shard(dim=0), Shard(dim=0), Replicate()], [Shard(dim=1), Replicate(), Shard(dim=1)]]][[[Replicate(), Replicate(), Replicate()], [Partial(sum), Shard(dim=1), Shard(dim=0)], [Shard(dim=0), Shard(dim=0), Replicate()], [Shard(dim=1), Replicate(), Shard(dim=1)]]] 
OpStrategy(all_strategies) = [(RR, RR) -> RR, (RS(1), RS(0)) -> RP, (RS(0), RR) -> RS(0), (RR, RS(1)) -> RS(1), (S(1)R, S(0)R) -> PR, (S(1)S(1), S(0)S(0)) -> PP, (S(1)S(0), S(0)R) -> PS(0), (S(1)R, S(0)S(1)) -> PS(1), (S(0)R, RR) -> S(0)R, (S(0)S(1), RS(0)) -> S(0)P, (S(0)S(0), RR) -> S(0)S(0), (S(0)R, RS(1)) -> S(0)S(1), (RR, S(1)R) -> S(1)R, (RS(1), S(1)S(0)) -> S(1)P, (RS(0), S(1)R) -> S(1)S(0), (RR, S(1)S(1)) -> S(1)S(1)] @ mesh: (4, 1)
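For readers unfamiliar with the notation, `(S(1), S(0)) -> P` says that sharding the first operand on dim 1 and the second on dim 0 leaves each rank holding a partial sum of the output. A minimal single-process sketch with plain Python lists illustrates this; no DTensor is involved, and `matmul`, `split_cols`, and `split_rows` are helper names made up for this illustration.

```python
def matmul(a, b):
    """Naive dense matrix product of two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def split_cols(m, parts):
    """Shard on dim 1: each part keeps a contiguous block of columns."""
    step = len(m[0]) // parts
    return [[row[i * step:(i + 1) * step] for row in m] for i in range(parts)]


def split_rows(m, parts):
    """Shard on dim 0: each part keeps a contiguous block of rows."""
    step = len(m) // parts
    return [m[i * step:(i + 1) * step] for i in range(parts)]


a = [[1, 2, 3, 4], [5, 6, 7, 8]]      # 2 x 4
b = [[1, 0], [0, 1], [2, 0], [0, 2]]  # 4 x 2

# "Ranks" 0 and 1 each hold one column shard of a and one row shard of b.
partials = [matmul(sa, sb) for sa, sb in zip(split_cols(a, 2), split_rows(b, 2))]

# Partial(sum): the full result is the elementwise sum over ranks.
reduced = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
assert reduced == matmul(a, b)
```

On a 2D mesh each tensor carries one such placement per mesh dimension, so an entry like `(S(1)R, S(0)R) -> PR` applies this pattern on mesh dimension 0 and replicates on mesh dimension 1.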

As a follow-up, we should add more test coverage for DTensor ops on 2D meshes, including 2D meshes in which one mesh dimension has size 1.


cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wconstab @d4l3k @c-p-i-o @tianyu-l @XilunWu

pytorch-bot bot commented Oct 29, 2024

✅ No Failures

As of commit e012ff9 with merge base 7c7b2d8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot added the "oncall: distributed" label Oct 29, 2024
wz337 marked this pull request as draft October 29, 2024 00:27
wz337 added the "module: dtensor", "release notes: distributed (dtensor)", and "topic: bug fixes" labels Oct 29, 2024
wz337 changed the title from "fix mm with 2D mesh with a mesh_dim size = 1" to "[DTensor][Bug Fix] Fix 2D DTensor mm with mesh_shape (1, n) or (n, 1)" Oct 29, 2024
wz337 changed the title from "[DTensor][Bug Fix] Fix 2D DTensor mm with mesh_shape (1, n) or (n, 1)" to "[DTensor][Bug Fix]Fix 2D DTensor mm with mesh_shape (1, n) or (n, 1)" Oct 29, 2024
wz337 marked this pull request as ready for review October 29, 2024 01:03
wz337 requested review from XilunWu, awgu and tianyu-l October 29, 2024 01:03

wz337 commented Oct 29, 2024

@pytorchmergebot rebase

@pytorchmergebot

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict.

@pytorchmergebot

Successfully rebased gh/wz337/40/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/139134)


fegin commented Oct 29, 2024

The failing tests look legit. We should fix them first.

wz337 added a commit that referenced this pull request Oct 29, 2024
ghstack-source-id: 39f3ec9
Pull Request resolved: #139134

wz337 commented Oct 29, 2024

> The failing tests look legit. We should fix them first.

There seem to be some issues when using the "gloo" backend, so I am skipping the non-GPU tests for now and will follow up by looking into DTensorTestBase.

Comment on lines +354 to +356
mesh_0 = init_device_mesh(self.device_type, (self.world_size // 2, 2))
mesh_1 = init_device_mesh(self.device_type, (self.world_size, 1))
mesh_2 = init_device_mesh(self.device_type, (1, self.world_size))

nit: better to put init_device_mesh in the loop?
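One way to read this suggestion, as a sketch rather than the test that was merged: keep the mesh shapes in a list and create each mesh inside the loop. `mesh_shapes_for` and `check_mm_on_all_meshes` are hypothetical names, and the second function assumes an already initialized process group, so only the shape helper runs standalone.

```python
def mesh_shapes_for(world_size):
    """2D mesh shapes to cover, including the size-1 corner cases."""
    return [(world_size // 2, 2), (world_size, 1), (1, world_size)]


def check_mm_on_all_meshes(device_type, world_size):
    # Imported lazily so the shape helper above works without torch installed.
    from torch.distributed.device_mesh import init_device_mesh

    for mesh_shape in mesh_shapes_for(world_size):
        mesh = init_device_mesh(device_type, mesh_shape)
        # ... distribute the operands on `mesh` and compare against a local mm ...


print(mesh_shapes_for(4))
```

Creating the mesh inside the loop keeps each shape next to the check that uses it and avoids one variable per mesh.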


wz337 commented Oct 30, 2024

@pytorchmergebot merge

pytorch-bot bot added the "ciflow/trunk" label Oct 30, 2024
@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
…ytorch#139134)

Pull Request resolved: pytorch#139134
Approved by: https://github.com/fegin
@github-actions github-actions bot deleted the gh/wz337/40/head branch November 30, 2024 02:08