added zbv_algorithm #138444
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138444
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures as of commit ecbcbb4 with merge base e7ec294. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "release notes: distributed (pipeline)"
@pytorchbot label "topic: not user facing"
Force-pushed c24df0c to 4f576b9 (Compare)
```python
count = []
for i in range(pipeline_parallel_size):
    count.append([0] * 6)
fbw_mem = [39, -7, -32]
```
Add a comment? What are these numbers?
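These constants aren't explained in the diff. A plausible reading (an assumption, not confirmed by the PR) is that they are relative activation-memory deltas for the three compute categories: a Forward allocates activations, the Backward-input pass frees a small part, and the Weight-grad pass frees the rest, so the three deltas sum to zero and a completed F/B/W triple is memory-neutral:

```python
# Assumed meaning (not stated in the PR): fbw_mem[k] is the activation-memory
# delta of running a Forward (k=0), Backward-input (k=1) or Weight-grad (k=2)
# step. Note 39 - 7 - 32 == 0: W releases whatever F allocated and B did not.
fbw_mem = [39, -7, -32]

def peak_memory(schedule):
    """Track peak activation memory over a sequence of category indices."""
    mem = peak = 0
    for cat in schedule:
        mem += fbw_mem[cat]
        peak = max(peak, mem)
    return peak
```

Under this reading, `peak_memory([0, 0, 1, 2, 1, 2])` shows how delaying W raises the peak, which is exactly the trade-off zero-bubble scheduling manages.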
```python
def get_compute_schedule(pipeline_parallel_size, num_microbatches):
    n_node = 6 * pipeline_parallel_size * num_microbatches

def get_id(cat, chunk, rank, micro):
```
Add a comment explaining what 'id' is for?
```python
compute_schedules = {}

def get_compute_schedule(pipeline_parallel_size, num_microbatches):
    n_node = 6 * pipeline_parallel_size * num_microbatches
```
Add a comment: what is 6?
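A likely answer (my assumption from the surrounding code, not stated in the diff): 6 = 3 compute categories (F, B, W) × 2 model chunks per rank, so `n_node` counts one scheduling node per (category, chunk, rank, microbatch) combination. A named-constant sketch:

```python
# Assumed decomposition of the magic constant 6 in n_node:
NUM_CATEGORIES = 3   # F (forward), B (backward-input), W (weight-grad)
NUM_CHUNKS = 2       # a V-schedule places two model chunks on each rank

def count_nodes(pipeline_parallel_size, num_microbatches):
    # one node per (category, chunk, rank, microbatch) combination
    return NUM_CATEGORIES * NUM_CHUNKS * pipeline_parallel_size * num_microbatches
```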
```python
schedule[i] = []
stage_str = [" " * i for i in range(pipeline_parallel_size)]
approved_bubble = [-1] * pipeline_parallel_size
max_approved_bubble = max(approved_bubble)
```
is max_approved_bubble always -1?
```python
approved_bubble = [-1] * pipeline_parallel_size
max_approved_bubble = max(approved_bubble)

def get_max_rank_bubble(rank=-1):
```
I'm pretty confused by this helper. It looks like it should do something complex, but it seems like all its inputs are static and it should return 0 all the time or something.
```python
_tmp = _no_bubble = cur_time[rank] + 1
_cnt = count[rank][cat * 2 + chunk]
stage_str[rank] += (
    "FfBbWw"[cat * 2 + chunk]
```
what is this?
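A plausible reading (an assumption from the indexing scheme, not confirmed by the PR): `"FfBbWw"` is a per-rank ASCII timeline, one printable character per (category, chunk) pair, uppercase for chunk 0 and lowercase for chunk 1, so `stage_str[rank]` ends up as a visual trace like `"FfBbWw..."`:

```python
# Assumed glyph table: index cat * 2 + chunk selects one character,
# uppercase = chunk 0, lowercase = chunk 1.
GLYPHS = "FfBbWw"

def glyph(cat, chunk):
    # cat: 0=Forward, 1=Backward-input, 2=Weight-grad; chunk: 0 or 1
    return GLYPHS[cat * 2 + chunk]
```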
```python
end_time[_id] = _tmp
cur_time[rank] = _tmp
mem[rank] += fbw_mem[cat]
# noinspection PyTypeChecker
```
We actually should not disable the type checker; we should add the mypy hints everywhere and make sure it's correct. That helps catch some bugs.
```python
_, chunk_, _ = pending_w[rank].popleft()
put(2, chunk_, rank)

def put(cat, chunk, rank, assert_cnt=True):
```
It would help to have a description of what 'cat' means, especially. Maybe 'chunk' is intuitive, but I don't know yet.
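Judging from call sites like `put(2, chunk_, rank)` flushing pending weight-grad work, 'cat' appears to be the integer compute-category code. This mapping is my assumption from context, not documented in the diff:

```python
# Assumed meaning of the integer codes used by put(cat, chunk, rank):
F, B, W = 0, 1, 2   # forward, backward-input, weight-gradient

def describe(cat, chunk, rank):
    """Render a (cat, chunk, rank) triple as human-readable text."""
    name = {F: "forward", B: "backward", W: "weight-grad"}[cat]
    return f"{name} pass of chunk {chunk} on rank {rank}"
```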
Force-pushed d43b4aa to f00c129 (Compare)
_get_zbv_schedule() is confusing to me; I need some time to read over the paper and digest this implementation more.
Can you include a short description of the differences between ZBV and the existing InterleavedZeroBubble schedule?
What are the differences between this fn and `_add_bubbles_to_actions(self, num_stages_global)` in pytorch/torch/distributed/pipelining/schedules.py (line 2100 at 780b28f)?
Removed the duplicated function.
@H-Huang IIUC this function is only needed so that the zbv schedule can be compatible with the preexisting PP runtime in PipelineScheduleMulti. If we consolidate on the newer PipelineScheduleRuntime class, we don't have to add bubbles anymore and we can simplify our schedule IR generation.
@haocizhang did you rebase? A couple of weeks ago I landed some PRs to support the dW/dI runner, and I also renamed the IR; I think it's BACKWARD_WEIGHT, BACKWARD_INPUT, FULL_BACKWARD now.
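For readers following along, the renaming matters because zero-bubble schedules split a full backward into its two halves. The enum below is only an illustrative sketch using the names from the comment above, not the actual torch internal class:

```python
from enum import Enum

class ComputationType(Enum):
    # Names follow the renamed IR mentioned in the review comment above;
    # this enum is a sketch, not the real torch.distributed.pipelining type.
    FORWARD = "F"
    BACKWARD_INPUT = "I"
    BACKWARD_WEIGHT = "W"
    FULL_BACKWARD = "B"

# a zero-bubble schedule decomposes FULL_BACKWARD into its two halves
ZB_SPLIT = {
    ComputationType.FULL_BACKWARD: (
        ComputationType.BACKWARD_INPUT,
        ComputationType.BACKWARD_WEIGHT,
    )
}
```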
Yeah realized I didn't rebase :) Rebased and updated the PR
```python
    I: 1,
    W: 2,
}
chunk_0 = 0
```
does zbv hardcode that there are exactly 2 model-chunks per PP rank? should we make that more of an assertion if so?
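If ZBV does assume exactly two model chunks per rank (the two legs of the V), a cheap guard makes the assumption explicit. A hypothetical sketch of such an assertion (the function name and shape are mine, not from the PR):

```python
def check_zbv_preconditions(num_model_chunks, pipeline_parallel_size):
    """ZBV's V-shape assumes exactly 2 model chunks per pipeline rank."""
    if num_model_chunks != 2:
        raise ValueError(
            f"ZBV requires exactly 2 model chunks per rank, got {num_model_chunks}"
        )
    return 2 * pipeline_parallel_size  # total number of pipeline stages
```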
cc @ufotalent who is one of the zero bubble paper authors. This PR implements the ZBV variant of zero bubble. Is there a simpler heuristic which we can use to guide the ordering of
```python
category = category_map[op]
# Number of ops (F/B/W) with the same (action, chunk) on current rank
_op_count = ops_count[rank][op][chunk]
if chunk == chunk_1 and op in (F, I):
```
Maybe these 2 ifs would make for a nice helper function whose name would further clarify their purpose. IIUC the idea here is to figure out the earliest time this op/chunk can run on the current rank, given the time that its dependency is scheduled on another rank? Something like `get_earliest_time_based_on_dependency(op, chunk, rank)`.
```python
num_chunks = 2
n_node = len(category_map) * num_chunks * pipeline_parallel_size * num_microbatches

def get_id(op, chunk, rank, microbatch_id):
```
IIUC the purpose of this helper is to hash the op uniquely? I think you could just directly key a dict off the op itself and achieve the same thing.
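The reviewer's suggestion in miniature: since Python dicts hash tuples of hashables natively, the flattened integer id can be replaced by keying directly on the tuple (names here are illustrative):

```python
# Keying directly on the (op, chunk, rank, microbatch) tuple avoids
# hand-rolled index arithmetic like get_id().
end_time = {}

def record_end(op, chunk, rank, microbatch_id, t):
    end_time[(op, chunk, rank, microbatch_id)] = t

record_end("F", 0, 2, 5, 17)
```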
```python
# For BACKWARD and WEIGHT operations, we schedule chunk 1 before chunk 0,
# so invert the order before adding to the schedule
temp_chunk = chunk if op == F else 1 - chunk
schedule[rank].append((_op_count, op, temp_chunk))
```
nit: could you directly construct an _Action() object here as you append, instead of creating a similar-but-not-the-same representation? OTOH I understand that for the logic here it is convenient to refer to the local chunk ID (0, 1) rather than the global stage_id, which combines local chunk_id with rank. IIUC _op_count is the same as microbatch_id?
```python
fbw_mem = [3, -1, -2]
max_mem = 3 * (pipeline_parallel_size * 2)
end_time = [-1] * n_node
```
As I read through the code below, I realize I'm confused between cur_time and end_time. Maybe it will become more clear.
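One way to disambiguate the two (my reading of the code, stated as an assumption): `cur_time[rank]` is when that rank finishes its last scheduled op, while `end_time[node]` records the finish time of one specific op node. A toy update showing how they relate:

```python
# Assumed roles:
#   cur_time[rank] - time at which `rank` finishes its last scheduled op
#   end_time[node] - finish time of one specific (op, chunk, microbatch) node
P, n_node = 2, 8
cur_time = [0] * P
end_time = [-1] * n_node

def schedule_op(node_id, rank, start, duration=1):
    finish = max(start, cur_time[rank]) + duration  # a rank runs ops serially
    end_time[node_id] = finish
    cur_time[rank] = finish
    return finish
```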
```python
put(F, chunk_1, cur_rank)

iter_chunk_ = 0
# Ensure forward operation synchronization across pipeline stages
```
Define 'synchronization'? At first I thought this was aiming to ensure the same number of microbatches per chunk, but that's not quite what the logic below does; it seems more like ensuring the same number of actions per rank, but not necessarily the same number of chunk 0 vs chunk 1.
Edit: OK, I was confused because I only looked at the 'while' logic. The if condition in the for loop does look like it ensures all F's are scheduled for both chunks.
```python
+ ops_count[current_rank][F][chunk_1]
< ops_count[previous_rank][F][chunk_0]
+ ops_count[previous_rank][F][chunk_1]
or ops_count[current_rank][F][chunk_1]
```
The logic after the `or` is confusing to me. Why would a 'previous rank' ever have more chunk-1 ops scheduled than the current rank?
```python
for rank in range(pipeline_parallel_size):
    chunk_0_ops = ops_count[rank][I][chunk_0]
    chunk_1_ops = ops_count[rank][I][chunk_1]
    if chunk_1_ops >= chunk_0_ops:
```
`chunk_1_ops == chunk_0_ops` doesn't make sense to me here. Only on the last rank would a chunk-1 op directly unblock a chunk-0 op. For other ranks, shouldn't the chunk-0-ready logic be dependent on the 'dependency_id' condition below?
```python
# Schedule backward operations for each rank
for rank, chunk in scheduled_ranks:
    dependency_id = -1
    if chunk == chunk_1 and rank < pipeline_parallel_size - 1:
```
I must be confusing myself, but chunk 1 on rank 0 would be the first chunk-1 B to run, wouldn't it? So then chunk 1 on rank 1 would have a dependency on chunk 1 on rank 0, and dependency_id should use rank - 1 instead of rank + 1? And vice versa for the chunk-0 logic below?
@H-Huang IIUC, you only want a handcrafted ZB-V schedule here. If so, you don't need any heuristic/greedy methods; a deterministic rule/pattern can be used to directly generate the ZB-V schedule. In the zero-bubble paper, we implemented a complicated greedy method based on profiled running times. Additionally, we have another version of the ZB-V implementation in another paper (NeurIPS 2024) and code, which is conceptually simpler than our previous greedy method. If you also want an adaptive version given running times/memories as inputs, maybe we can help to simplify the implementation (the current implementation also supports other schedules like V-Half).
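One deterministic piece of ZB-V that the comment above alludes to is the V-shaped stage placement itself: with P pipeline ranks and 2P stages, rank r hosts stages r and 2P-1-r, so the stage sequence traces a V across the ranks. A small sketch (function name is illustrative):

```python
def zbv_stage_to_rank(pipeline_parallel_size):
    """V-shaped placement: rank r hosts stages r and 2P-1-r."""
    P = pipeline_parallel_size
    placement = {}
    for r in range(P):
        placement[r] = (r, 2 * P - 1 - r)
    return placement
```

For P = 4 this puts stages (0, 7) on rank 0 and (3, 4) on rank 3, which is why the last rank's chunk-1 work can immediately follow its chunk-0 work.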
To support ZB-V in native PyTorch: pytorch/pytorch#138444
Adds ZBV schedule which is explained in https://arxiv.org/pdf/2401.10241, Section 6. Tested it works under the new PipelineScheduleRuntime by fixing a small bug in handling V-shaped schedules. This PR is a replacement for #138444. cc the original authors: @QPHutu @ufotalent #138444 (comment). Pull Request resolved: #142084. Approved by: https://github.com/kwen2501
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as


Added ZBV algorithm to PP schedules. See https://arxiv.org/pdf/2401.10241 for details.
Tested schedule using `python test_schedules.py`.
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o