[PP] Add DualPipeV schedule by H-Huang · Pull Request #159591 · pytorch/pytorch · GitHub

Conversation

@H-Huang (Member) commented Jul 31, 2025

Stack from ghstack (oldest at bottom):

Added the DualPipeV schedule according to http://github.com/deepseek-ai/DualPipe/blob/main/dualpipe/dualpipev.py#L11

(Figure: DualPipeV schedule diagram)

This schedule doesn't perform the actual "overlap" during execution, but it provides the scaffolding and schedule definition we need to run it end-to-end (E2E) in torchtitan. Support for the overlapped operation will be added in follow-up PRs.

Tests:

```sh
python test/distributed/pipelining/test_schedule_multiproc.py -k test_v_shape_schedules
python test/distributed/pipelining/test_schedule.py -k test_pipeline_order_for_v_schedules
```

Also tested in TorchTitan, where it is running.
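For readers who want to try the schedule, here is a minimal sketch of wiring it up with torch.distributed.pipelining. This is not the authoritative API: the V-shaped stage placement and the constructor arguments are assumptions modeled on the existing multi-stage schedules, and the class is resolved by name via `get_schedule_class` rather than guessing its exact identifier.

```python
# Minimal sketch (assumptions noted inline), not the authoritative API: run the
# DualPipeV schedule on 4 ranks with toy stages.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import get_schedule_class

dist.init_process_group("gloo")
rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device("cpu")

# Each rank owns two stages, one from each arm of the "V": 4 ranks -> 8 stages.
# Placement assumed for this sketch: rank r holds stages r and (num_stages - 1 - r).
num_stages = 2 * world_size
stage_indices = [rank, num_stages - 1 - rank]
stage_modules = [nn.Linear(16, 16) for _ in stage_indices]  # toy per-stage modules

stages = [
    PipelineStage(mod, idx, num_stages, device)
    for mod, idx in zip(stage_modules, stage_indices)
]

ScheduleDualPipeV = get_schedule_class("DualPipeV")  # name taken from the torchtitan config
schedule = ScheduleDualPipeV(stages, n_microbatches=10, loss_fn=lambda out, tgt: out.sum())

x = torch.randn(10, 16)  # split into 10 microbatches by the schedule
if rank == 0:
    # In a V schedule the first and last stages both live on rank 0, so rank 0
    # supplies both the inputs and the targets (assumption for this sketch).
    schedule.step(x, target=torch.randn(10, 16))
else:
    schedule.step()
```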

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta

[ghstack-poisoned]
@pytorch-bot bot commented Jul 31, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159591

Note: Links to docs will display an error until the docs builds have been completed.

❌ 5 Cancelled Jobs, 1 Unrelated Failure

As of commit a82d4a3 with merge base 9d37c96:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the oncall: distributed (Add this issue/PR to distributed oncall triage queue) label Jul 31, 2025
H-Huang added a commit that referenced this pull request Jul 31, 2025
ghstack-source-id: 273b3c7
Pull Request resolved: #159591
@H-Huang changed the title from "Add DualPipeV schedule" to "[PP] Add DualPipeV schedule" Jul 31, 2025
@H-Huang requested review from kwen2501 and wconstab July 31, 2025 21:36
@H-Huang added the release notes: distributed (pipeline) label Jul 31, 2025
H-Huang added a commit that referenced this pull request Aug 1, 2025
ghstack-source-id: c14fcb4
Pull Request resolved: #159591
H-Huang added a commit that referenced this pull request Aug 1, 2025
Some changes to validation code and visualizer to support a new computation type that will be used in DualPipeV (see #159591)

The IR looks like:

```
[0F0, 0F1, 0F2, 0F3, 0F4, 0F5, 0F6, 7F0, 7I0, 7W0, 7F1, 7I1, 7W1, 7F2, 7I2, 7W2, 7F3, (0F7;7B3)OVERLAP_F_B, (7F4;0B0)OVERLAP_F_B, (0F8;7B4)OVERLAP_F_B, (7F5;0B1)OVERLAP_F_B, (0F9;7B5)OVERLAP_F_B, (7F6;0B2)OVERLAP_F_B, 7B6, (7F7;0B3)OVERLAP_F_B, 7B7, (7F8;0B4)OVERLAP_F_B, 7B8, (7F9;0B5)OVERLAP_F_B, 7B9, 0I6, 0W6, 0I7, 0W7, 0I8, 0W8, 0I9, 0W9]
[1F0, 1F1, 1F2, 1F3, 1F4, 6F0, 1F5, 6F1, 6I0, 6W0, 6F2, 6I1, 6W1, 6F3, (1F6;6B2)OVERLAP_F_B, (6F4;1B0)OVERLAP_F_B, (1F7;6B3)OVERLAP_F_B, (6F5;1B1)OVERLAP_F_B, (1F8;6B4)OVERLAP_F_B, (6F6;1B2)OVERLAP_F_B, (1F9;6B5)OVERLAP_F_B, (6F7;1B3)OVERLAP_F_B, 6B6, (6F8;1B4)OVERLAP_F_B, 6B7, (6F9;1B5)OVERLAP_F_B, 6B8, 1B6, 6I9, 1I7, 6W9, 1I8, 1W7, 1I9, 1W8, 1W9]
[2F0, 2F1, 2F2, 5F0, 2F3, 5F1, 2F4, 5F2, 5I0, 5W0, 5F3, (2F5;5B1)OVERLAP_F_B, (5F4;2B0)OVERLAP_F_B, (2F6;5B2)OVERLAP_F_B, (5F5;2B1)OVERLAP_F_B, (2F7;5B3)OVERLAP_F_B, (5F6;2B2)OVERLAP_F_B, (2F8;5B4)OVERLAP_F_B, (5F7;2B3)OVERLAP_F_B, (2F9;5B5)OVERLAP_F_B, (5F8;2B4)OVERLAP_F_B, 5B6, (5F9;2B5)OVERLAP_F_B, 5B7, 2B6, 5B8, 2I7, 5I9, 2I8, 2W7, 2I9, 5W9, 2W8, 2W9]
[3F0, 4F0, 3F1, 4F1, 3F2, 4F2, 3F3, 4F3, 3F4, 4B0, (4F4;3B0)OVERLAP_F_B, (3F5;4B1)OVERLAP_F_B, (4F5;3B1)OVERLAP_F_B, (3F6;4B2)OVERLAP_F_B, (4F6;3B2)OVERLAP_F_B, (3F7;4B3)OVERLAP_F_B, (4F7;3B3)OVERLAP_F_B, (3F8;4B4)OVERLAP_F_B, (4F8;3B4)OVERLAP_F_B, (3F9;4B5)OVERLAP_F_B, (4F9;3B5)OVERLAP_F_B, 4B6, 3B6, 4B7, 3B7, 4I8, 3I8, 4I9, 3I9, 4W8, 3W8, 4W9, 3W9]
```

In this PR, the schedule execution will simply treat OVERLAP_F_B as two separate operations, F and B (so there is no actual overlap). The next step is to allow users to plug in a custom function that defines what this combined operation does; a hypothetical sketch follows the code link below.

https://github.com/pytorch/pytorch/blob/814629043a0c31441bc3749204c97f1e24fa3462/torch/distributed/pipelining/schedules.py#L1205-L1216
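As a purely hypothetical illustration of that next step (none of these names are the actual PyTorch API), a user-pluggable overlapped forward-backward callback might look roughly like this:

```python
# Hypothetical sketch only: the PR currently executes an OVERLAP_F_B action as a
# plain forward followed by a plain backward. A follow-up is expected to let users
# plug in their own implementation; the callback shape and registration below are
# assumptions, not the actual PyTorch API.

def my_overlapped_f_b(forward_fn, backward_fn, fwd_inputs, bwd_grad_inputs):
    # A real implementation could overlap the forward microbatch's MoE dispatch /
    # combine communication with the backward microbatch's compute, as described
    # in the DualPipe paper and demonstrated with DeepEP.
    fwd_output = forward_fn(*fwd_inputs)  # no actual overlap here: just F, then B
    backward_fn(bwd_grad_inputs)
    return fwd_output

# Hypothetical registration against the runtime schedule (name and signature assumed):
# schedule.register_custom_function(OVERLAP_F_B, my_overlapped_f_b)
```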

cc awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta

[ghstack-poisoned]
@H-Huang added the ciflow/trunk (Trigger trunk jobs on your pull request) label Aug 1, 2025
pytorchmergebot pushed a commit that referenced this pull request Aug 1, 2025
Pull Request resolved: #158978
Approved by: https://github.com/wconstab
```python
self.pipeline_order[rank] = rank_ops

# Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
self._load_actions(self.pipeline_order)
```
Contributor:

Regretting that this function name doesn't make it clear that it's running all the passes to insert comms.

Member Author (@H-Huang):

I can update it in a follow-up PR! I was also making a dumb mistake of calling _load_actions twice, which doesn't error out but leads to numerics issues, so I'm going to add validation to check for this.
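For illustration, a minimal sketch of the kind of guard being described, assuming the runtime records whether actions were already loaded; the class, attribute names, and error message below are assumptions, not the validation that eventually landed.

```python
# Sketch of the guard described above: fail loudly if _load_actions is called twice,
# since a second call does not error out but silently leads to incorrect numerics.
# The class, attribute, and message here are assumptions for illustration only.
class _ScheduleRuntimeSketch:
    def __init__(self):
        self._actions_loaded = False
        self.pipeline_order_with_comms = None

    def _load_actions(self, pipeline_order):
        if self._actions_loaded:
            raise RuntimeError(
                "_load_actions was already called for this schedule; "
                "loading actions twice leads to incorrect numerics."
            )
        self._actions_loaded = True
        # ... run the lowering passes that insert send/recv communication ops ...
        self.pipeline_order_with_comms = pipeline_order
```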

```python
weight_key = (actual_stage_index, _ComputationType.BACKWARD_WEIGHT)
counters[weight_key] = counters.get(weight_key, 0) + 1

# Step 1: F0
```
Contributor:

Generally, is there some matching between the 6 steps and the rank-based formulas you used here, and the paper/code? It isn't that obvious how this compares to the logic in DualPipeV.

Member Author (@H-Huang):

The logic for DualPipeV is here: https://github.com/deepseek-ai/DualPipe/blob/3da1bbea53606543d7f5f232338fc58096db30e3/dualpipe/dualpipev.py#L331-L396. This schedule copies that logic (minus their communication operations, which we insert ourselves).

The overlapped_f_b in DualPipeV is a dummy implementation using an MLP. The DeepSeek team does not provide a concrete implementation of what they describe in the paper, but they leverage the machinery built in DeepEP for the dispatch + MoE + combine example: https://github.com/deepseek-ai/DeepEP/blob/main/README.md#example-use-in-model-training-or-inference-prefilling

@wconstab (Contributor) commented:

Should the picture in your PR description match this picture exactly (from DualPipe)?

(Figure: DualPipe schedule diagram from the DualPipe repository)

At first I thought there were differences, but I realized I was just confused by the notation. I think it might be a perfect match, but I'd actually like to see a unit test that asserts your code generates this picture for the 4-rank, 10-microbatch config.
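For reference, a rough sketch of the kind of test being requested: generate the order for the 4-rank, 10-microbatch config and compare it to a hard-coded reference transcribed from the DualPipe figure. The helper `make_dualpipev_schedule` and the `pipeline_order` attribute are assumed names used only for illustration.

```python
import unittest

class TestDualPipeVOrder(unittest.TestCase):
    def test_matches_paper_figure_4ranks_10mb(self):
        # `make_dualpipev_schedule` and `pipeline_order` are assumed names here;
        # the expected prefix is transcribed from the rank-0 IR dump above.
        schedule = make_dualpipev_schedule(num_ranks=4, n_microbatches=10)
        expected_rank0_prefix = [
            "0F0", "0F1", "0F2", "0F3", "0F4", "0F5", "0F6", "7F0", "7I0", "7W0",
            # ... remaining rank-0 actions from the figure ...
        ]
        actual_rank0 = [str(a) for a in schedule.pipeline_order[0]]
        self.assertEqual(actual_rank0[: len(expected_rank0_prefix)], expected_rank0_prefix)
```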

H-Huang added a commit that referenced this pull request Aug 12, 2025
ghstack-source-id: b73ada4
Pull Request resolved: #159591
@H-Huang (Member Author) commented Aug 14, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@H-Huang (Member Author) commented Aug 14, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@H-Huang (Member Author) commented Aug 14, 2025

@pytorchbot merge -i

pytorchmergebot pushed a commit that referenced this pull request Aug 14, 2025
Rename method and add validation
Pull Request resolved: #160558
Approved by: https://github.com/wconstab
ghstack dependencies: #159591
pytorchmergebot pushed a commit that referenced this pull request Aug 14, 2025
Update schedule tests to use `world_size=4`. Changes needed:
- Move some tests that require world_size=2 to a new class
- Move helper methods from class level to function level
- Update some initialization to pass the assert, since gradients were extremely small

Pull Request resolved: #160559
Approved by: https://github.com/wconstab
ghstack dependencies: #159591, #160558
H-Huang added a commit to pytorch/torchtitan that referenced this pull request Aug 15, 2025
DualPipeV was added to pt-core (pytorch/pytorch#159591), so this just adds code to support it in titan.

To use, set the following in the .toml file:

```toml
pipeline_parallel_schedule = "DualPipeV"
```

Ideally we wouldn't have this if-statement check, so as a future BE task I can look into removing it.
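For context on how that string reaches PyTorch, a minimal sketch: the schedule name from the .toml is resolved to a schedule class via `get_schedule_class`, which mirrors how torchtitan looks up its other schedules; the surrounding torchtitan plumbing is omitted/assumed here.

```python
# Sketch: resolving a schedule name string from the .toml to a schedule class.
# The surrounding torchtitan configuration plumbing is omitted/assumed.
from torch.distributed.pipelining.schedules import get_schedule_class

schedule_class = get_schedule_class("DualPipeV")
print(schedule_class.__name__)  # the schedule class added by this PR
```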
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
Pull Request resolved: pytorch#159591
Approved by: https://github.com/wconstab
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
@github-actions bot deleted the gh/H-Huang/202/head branch September 14, 2025 02:13
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025

Labels

- ciflow/trunk (Trigger trunk jobs on your pull request)
- Merged
- oncall: distributed (Add this issue/PR to distributed oncall triage queue)
- release notes: distributed (pipeline) (release notes category)

3 participants