added zbv_algorithm by haocizhang · Pull Request #138444 · pytorch/pytorch · GitHub

Conversation

@haocizhang
Contributor

@haocizhang haocizhang commented Oct 21, 2024

Added the ZBV algorithm to the pipeline-parallel (pp) schedules. See https://arxiv.org/pdf/2401.10241 for details.
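For readers skimming the diff, the core placement idea behind ZB-V is that each pipeline rank holds two model chunks arranged in a "V" shape: chunk-0 stages run down the ranks and chunk-1 stages run back up. A minimal sketch of that mapping (the helper name is illustrative, not part of this PR):

```python
def zbv_stage_ids(rank: int, pp_size: int) -> list:
    """Return the two stage ids held by `rank` under ZB-V's V-shaped
    placement: chunk 0's stages are 0..pp_size-1 going down the ranks,
    and chunk 1's stages fold back up, so rank r holds stages
    r and 2*pp_size - 1 - r."""
    return [rank, 2 * pp_size - 1 - rank]
```

For example, with 4 ranks, rank 0 holds stages 0 and 7 (the first and last stages of the pipeline), which is what lets the last backward finish on the same rank that ran the first forward.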

Tested the schedule using `python test_schedules.py`.

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Oct 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138444

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit ecbcbb4 with merge base e7ec294:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the `oncall: distributed` label Oct 21, 2024
@haocizhang
Contributor Author

@pytorchbot label "release notes: distributed (pipeline)"

@pytorch-bot pytorch-bot bot added the `release notes: distributed (pipeline)` label Oct 22, 2024
@haocizhang
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the `topic: not user facing` label Oct 22, 2024
@haocizhang haocizhang force-pushed the zbv_algo branch 2 times, most recently from c24df0c to 4f576b9 on October 22, 2024 04:00
count = []
for i in range(pipeline_parallel_size):
count.append([0] * 6)
fbw_mem = [39, -7, -32]
Contributor

Could you add a comment? What are these numbers?

def get_compute_schedule(pipeline_parallel_size, num_microbatches):
n_node = 6 * pipeline_parallel_size * num_microbatches

def get_id(cat, chunk, rank, micro):
Contributor

Could you add a comment explaining what 'id' is for?
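For context, the id here appears to be a flat index over (category, chunk, rank, microbatch), which is also where the 6 in `n_node` comes from (3 op categories F/B/W times 2 chunks). A hedged sketch of that mixed-radix computation (names are illustrative, not the PR's exact code):

```python
def flat_node_id(cat, chunk, rank, micro, pp_size, num_microbatches):
    # Flatten (cat, chunk, rank, micro) into a unique index in
    # [0, 6 * pp_size * num_microbatches): 3 categories (F/B/W)
    # times 2 chunks times pp_size ranks times num_microbatches.
    return ((cat * 2 + chunk) * pp_size + rank) * num_microbatches + micro
```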

compute_schedules = {}

def get_compute_schedule(pipeline_parallel_size, num_microbatches):
n_node = 6 * pipeline_parallel_size * num_microbatches
Contributor

Could you add a comment explaining what the 6 represents?

schedule[i] = []
stage_str = [" " * i for i in range(pipeline_parallel_size)]
approved_bubble = [-1] * pipeline_parallel_size
max_approved_bubble = max(approved_bubble)
Contributor

is max_approved_bubble always -1?

approved_bubble = [-1] * pipeline_parallel_size
max_approved_bubble = max(approved_bubble)

def get_max_rank_bubble(rank=-1):
Contributor

I'm pretty confused by this helper. It looks like it should do something complex, but all of its inputs seem static, so it looks like it should return 0 (or some other constant) all the time.

_tmp = _no_bubble = cur_time[rank] + 1
_cnt = count[rank][cat * 2 + chunk]
stage_str[rank] += (
"FfBbWw"[cat * 2 + chunk]
Contributor

what is this?

end_time[_id] = _tmp
cur_time[rank] = _tmp
mem[rank] += fbw_mem[cat]
# noinspection PyTypeChecker
Contributor

We actually should not disable the type checker; we should add the mypy hints everywhere and make sure they're correct. That helps catch some bugs.

_, chunk_, _ = pending_w[rank].popleft()
put(2, chunk_, rank)

def put(cat, chunk, rank, assert_cnt=True):
Contributor

It would help to have a description of what 'cat' means, especially. Maybe 'chunk' is intuitive, but I don't know yet.

@haocizhang haocizhang force-pushed the zbv_algo branch 2 times, most recently from d43b4aa to f00c129 on November 11, 2024 19:54
Member

@H-Huang H-Huang left a comment

_get_zbv_schedule() is confusing to me; I need some time to read over the paper and digest this implementation more.

Member

Can you include a short description of the differences between ZBV and the existing InterleavedZeroBubble schedule?

Member

What are the differences between this fn and

def _add_bubbles_to_actions(self, num_stages_global):
, could you comment it?

Contributor Author

Removed the duplicated function.

Contributor

@H-Huang IIUC this function is only needed so that the zbv schedule can be compatible with the preexisting PP runtime in PipelineScheduleMulti. If we consolidate on the newer PipelineScheduleRuntime class, we don't have to add bubbles anymore and we can simplify our schedule IR generation.

Contributor

@haocizhang did you rebase? A couple of weeks ago I landed some PRs to support the dW/dI runner, and I also renamed the IR; I think it's BACKWARD_WEIGHT, BACKWARD_INPUT, FULL_BACKWARD now.

Contributor Author

Yeah, I realized I didn't rebase :) Rebased and updated the PR.

I: 1,
W: 2,
}
chunk_0 = 0
Contributor

Does ZBV hardcode that there are exactly 2 model chunks per PP rank? If so, should we make that an explicit assertion?

@H-Huang
Member

H-Huang commented Nov 12, 2024

cc @ufotalent who is one of the zero bubble paper authors.

This PR implements the ZBV variant of zero bubble:
[figure: ZB-V schedule diagram from the paper]

Is there a simpler heuristic we can use to guide the ordering of F-B-W for each device, regardless of the number of ranks and stages?

category = category_map[op]
# Number of ops (F/B/W) with the same (action, chunk) on current rank
_op_count = ops_count[rank][op][chunk]
if chunk == chunk_1 and op in (F, I):
Contributor

Maybe these two ifs would make for a nice helper function whose name would further clarify their purpose. IIUC, the idea here is to figure out the earliest time this op/chunk can run on the current rank, given the time its dependency is scheduled on another rank?

"get_earliest_time_based_on_dependency(op, chunk, rank)"
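A sketch of what that suggested helper might look like (the one-tick `comm_delay` and all names are assumptions, not this PR's code):

```python
def get_earliest_time_based_on_dependency(cur_time, end_time, rank, dep_id, comm_delay=1):
    # An op can start once its rank is free; if it also has a cross-rank
    # dependency (dep_id >= 0), it must additionally wait for that node
    # to finish plus one communication hop.
    earliest = cur_time[rank] + 1
    if dep_id >= 0:
        earliest = max(earliest, end_time[dep_id] + comm_delay + 1)
    return earliest
```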

num_chunks = 2
n_node = len(category_map) * num_chunks * pipeline_parallel_size * num_microbatches

def get_id(op, chunk, rank, microbatch_id):
Contributor

IIUC, the purpose of this helper is to hash the op uniquely? I think you could just key a dict directly off the op itself and achieve the same thing.
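For comparison, the dict-keying alternative the comment suggests could look like this (a sketch, not the PR's actual representation):

```python
# Key the timing/bookkeeping dicts directly off the op tuple instead of
# a hand-rolled flat id; tuples of small values hash cheaply and the
# keys stay self-describing.
end_time = {}
key = ("F", 0, 1, 3)  # (op, chunk, rank, microbatch)
end_time[key] = 7
```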

# For BACKWARD and WEIGHT operation, we will schedule chunk 1 before 0 so
# inversing the order before adding to the schedule
temp_chunk = chunk if op == F else 1 - chunk
schedule[rank].append((_op_count, op, temp_chunk))
Contributor

nit: could you directly construct an _Action() object here as you append, instead of creating a similar-but-not-identical representation? On the other hand, I understand that for the logic here it is convenient to refer to the local chunk ID (0, 1) rather than the global stage_id, which combines the local chunk_id with the rank. IIUC, _op_count is the same as microbatch_id?


fbw_mem = [3, -1, -2]
max_mem = 3 * (pipeline_parallel_size * 2)
end_time = [-1] * n_node
Contributor

As I read through the code below, I realize I'm confused between cur_time and end_time. Maybe it will become more clear.
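One hedged reading of the fbw_mem excerpt above (units are arbitrary activation slots): F allocates 3, B frees 1, and W frees the remaining 2, so one full F+B+W round per microbatch is memory-neutral and peak memory is driven by how many forwards are in flight:

```python
fbw_mem = [3, -1, -2]  # assumed memory deltas for F, B, W respectively

# A full F + B + W round nets out to zero: the schedule's memory
# accounting relies on B and W together releasing what F allocated.
assert sum(fbw_mem) == 0

# Peak usage is set by how many forwards are outstanding before their
# matching B/W run; four outstanding forwards occupy 12 slots here.
mem = 0
peak = 0
for delta in [3, 3, 3, 3, -1, -2, -1, -2, -1, -2, -1, -2]:
    mem += delta
    peak = max(peak, mem)
```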

put(F, chunk_1, cur_rank)

iter_chunk_ = 0
# Ensure forward operation synchronization across pipeline stages
Contributor

@wconstab wconstab Nov 12, 2024

Define 'synchronization'? At first I thought this was aiming to ensure the same number of microbatches per chunk, but that's not quite what the logic below does; it seems more like ensuring the same number of actions per rank, but not necessarily the same number of chunk0 vs chunk1.

Edit: OK, I was confused because I only looked at the 'while' logic, but the if condition in the for loop does look like it ensures all F's are scheduled for both chunks.

+ ops_count[current_rank][F][chunk_1]
< ops_count[previous_rank][F][chunk_0]
+ ops_count[previous_rank][F][chunk_1]
or ops_count[current_rank][F][chunk_1]
Contributor

The logic after the `or` is confusing to me. Why would a 'previous rank' ever have more chunk1's scheduled than the current rank?

for rank in range(pipeline_parallel_size):
chunk_0_ops = ops_count[rank][I][chunk_0]
chunk_1_ops = ops_count[rank][I][chunk_1]
if chunk_1_ops >= chunk_0_ops:
Contributor

chunk_1_ops == chunk_0_ops doesn't make sense to me here. Only on the last rank would a chunk1 op directly unblock a chunk0 op. For other ranks, shouldn't the chunk0-ready logic depend on the 'dependency_id' condition below?

# Schedule backward operations for each rank
for rank, chunk in scheduled_ranks:
dependency_id = -1
if chunk == chunk_1 and rank < pipeline_parallel_size - 1:
Contributor

I must be confusing myself, but chunk1 on rank0 would be the first chunk1 B to run, wouldn't it? So then chunk1 on rank1 would have a dependency on chunk1 on rank0, meaning dependency_id should use rank - 1 instead of rank + 1? And vice versa for the chunk0 logic below?

@QPHutu

QPHutu commented Nov 13, 2024

cc @ufotalent who is one of the zero bubble paper authors.

This PR implements the ZBV variant of zero bubble: [figure: ZB-V schedule diagram from the paper]

Is there a simpler heuristic we can use to guide the ordering of F-B-W for each device, regardless of the number of ranks and stages?

@H-Huang IIUC, you only want a handcrafted ZB-V schedule here. If so, you don't need any heuristic/greedy methods; a deterministic rule/pattern can be used to directly generate the ZB-V schedule.

In the zero-bubble paper, we implemented a complicated greedy method based on profiled $$T_F, T_B, T_W$$ to minimize the bubble caused by the inequality of these running times. However, in this code, $$T_F, T_B, T_W$$ are hardcoded as 1, which means you don't need to implement our greedy method; a specific pattern should work for you. Please refer to this handcrafted ZB-V implementation.

Additionally, we have another implementation of ZB-V in another paper (NeurIPS 2024) and its code, which is conceptually simpler than our previous greedy method. If you also want an adaptive version that takes running times/memories as inputs, maybe we can help simplify the implementation (the current implementation also supports other schedules like V-Half).
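To make "a deterministic rule/pattern" concrete, a handcrafted schedule can simply be materialized as a fixed per-rank action list rather than the output of a search. The sketch below is illustrative only: the `Action` type and the warmup rule are assumptions for exposition, not the paper's rule or PyTorch's schedule IR.

```python
from collections import namedtuple

# Illustrative IR: one action per (op, chunk, microbatch) on a rank.
Action = namedtuple("Action", ["op", "chunk", "microbatch"])

def chunk0_warmup(rank, pp_size):
    # Plausible warmup under T_F = T_B = T_W = 1: earlier ranks must
    # queue more chunk-0 forwards up front, since their first chunk-1
    # work can only arrive after activations travel down the pipeline
    # and back.  This is a sketch, not the exact ZB-V warmup count.
    return [Action("F", 0, mb) for mb in range(pp_size - rank)]
```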

QPHutu added a commit to sail-sg/zero-bubble-pipeline-parallelism that referenced this pull request Nov 13, 2024
To support ZB-V in native pytorch

pytorch/pytorch#138444
QPHutu added a commit to sail-sg/zero-bubble-pipeline-parallelism that referenced this pull request Nov 13, 2024
QPHutu added a commit to sail-sg/zero-bubble-pipeline-parallelism that referenced this pull request Nov 14, 2024
H-Huang added a commit that referenced this pull request Dec 10, 2024
Adds ZBV schedule which is explained in https://arxiv.org/pdf/2401.10241, Section 6. Tested it works under the new PipelineScheduleRuntime by fixing a small bug in handling V-shaped schedules. This PR is a replacement for #138444

cc the original authors: QPHutu ufotalent #138444 (comment)

cc awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Dec 11, 2024
Adds ZBV schedule which is explained in https://arxiv.org/pdf/2401.10241, Section 6. Tested it works under the new PipelineScheduleRuntime by fixing a small bug in handling V-shaped schedules. This PR is a replacement for #138444

cc the original authors: @QPHutu @ufotalent #138444 (comment)

Pull Request resolved: #142084
Approved by: https://github.com/kwen2501
@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the `Stale` label Jan 12, 2025
@github-actions github-actions bot closed this Feb 11, 2025
@github-actions github-actions bot deleted the zbv_algo branch March 14, 2025 02:07