[Pipelining] Support separate dI / dW and V-schedules by wconstab · Pull Request #131762 · pytorch/pytorch · GitHub

Conversation

Contributor

@wconstab wconstab commented Jul 25, 2024

Stack from ghstack (oldest at bottom):

Separate dI / dW:

PipelineScheduleRuntime now supports execution of merged FULL_BACKWARD
or separate dI / dW operations.

Separating B and W may add execution overhead, and it can be suboptimal in cases where B and W are 'fused', but it is worthwhile when the separation lets the schedule fill in bubbles and become more efficient. In some cases the schedule will still issue B immediately followed by W for the same stage and microbatch; in those cases we merge them back into a single BW op and execute it as a full backward rather than executing a B followed by a W.
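Roughly, that merge step looks like the following sketch (illustrative only: it uses a simplified (stage, type, microbatch) tuple rather than the runtime's actual _Action / _ComputationType IR):

```python
from collections import namedtuple

Action = namedtuple("Action", ["stage", "type", "microbatch"])
BACKWARD_INPUT, BACKWARD_WEIGHT, FULL_BACKWARD = "I", "W", "B"

def merge_adjacent_bw(actions):
    """Merge a dI immediately followed by the matching dW (same stage and
    microbatch) into one full-backward action, as described above."""
    merged, i = [], 0
    while i < len(actions):
        a = actions[i]
        nxt = actions[i + 1] if i + 1 < len(actions) else None
        if (
            a.type == BACKWARD_INPUT
            and nxt is not None
            and nxt.type == BACKWARD_WEIGHT
            and nxt.stage == a.stage
            and nxt.microbatch == a.microbatch
        ):
            # Adjacent dI + dW for the same stage/microbatch: run as one full backward.
            merged.append(Action(a.stage, FULL_BACKWARD, a.microbatch))
            i += 2
        else:
            merged.append(a)
            i += 1
    return merged
```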

V-schedules:

V-schedules have a special case where the last rank has 2 adjacent
stages.

E.g. if rank 3 has stage 3 and stage 4, then stage 3's outputs should be transferred directly to stage 4's inputs without a send/recv.

In the scheduling logic, we also must allow scheduling the stage 4 forward after running the stage 3 forward, without expecting a stage 4 RECV_F.

In the runtime, we pass activations between adjacent stages without using SEND/RECV ops since the stages are on the same rank/process. We add new APIs to the PipelineStage abstraction for passing the activations both during forward and backward. Currently the implementation directly modifies the 'recv buffers' the stage is managing, so the forward/backward execution logic does not need to know the difference.
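A minimal sketch of that idea, with hypothetical names (the actual PipelineStage APIs added here may differ): the producer's tensors are written straight into the buffers the consumer stage already manages, in both directions.

```python
class _RecvSlot:
    """Stand-in for one entry of the 'recv buffers' a stage manages."""
    def __init__(self, buffer=None):
        self.buffer = buffer

def set_local_fwd_input(next_stage_slots, prev_stage_outputs):
    # Forward direction: detach so the next stage builds a fresh autograd graph
    # rooted at its own leaf inputs, exactly as if the tensor arrived via RECV_F.
    for slot, out in zip(next_stage_slots, prev_stage_outputs):
        slot.buffer = out.detach().requires_grad_(True)

def set_local_bwd_input(prev_stage_grad_slots, next_stage_input_grads):
    # Backward direction: hand the next stage's grads w.r.t. its inputs to the
    # previous stage as its grad_outputs, with no copy, detach, or send/recv.
    for slot, grad in zip(prev_stage_grad_slots, next_stage_input_grads):
        slot.buffer = grad
```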

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o


pytorch-bot bot commented Jul 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131762

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit dda57c0 (merge base could not be retrieved; please contact dev infra):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wconstab added a commit that referenced this pull request Jul 25, 2024
TODO- consider redoing the IR names so that BW is just 'B' and B, W, are
Bx and Bw or something

ghstack-source-id: 4d23ad4
Pull Request resolved: #131762
This was referenced Aug 1, 2024
wconstab added a commit that referenced this pull request Aug 1, 2024
TODO- consider redoing the IR names so that BW is just 'B' and B, W, are
Bx and Bw or something

ghstack-source-id: 22f0592
Pull Request resolved: #131762
wconstab added a commit that referenced this pull request Aug 1, 2024
TODO- consider redoing the IR names so that BW is just 'B' and B, W, are
Bx and Bw or something

ghstack-source-id: 2367da3
Pull Request resolved: #131762
wconstab added a commit that referenced this pull request Aug 1, 2024
TODO- consider redoing the IR names so that BW is just 'B' and B, W, are
Bx and Bw or something

ghstack-source-id: be6b293
Pull Request resolved: #131762
wconstab added a commit that referenced this pull request Aug 16, 2024
TODO- consider redoing the IR names so that BW is just 'B' and B, W, are
Bx and Bw or something

ghstack-source-id: 30408fa
Pull Request resolved: #131762
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Oct 15, 2024
@wconstab wconstab changed the title from "WIP batch B,W into BW" to "[Pipelining] Merge adjacent B,W into BW" Oct 16, 2024
@wconstab wconstab changed the title from "[Pipelining] Merge adjacent B,W into BW" to "[Pipelining] Runtime support and optimization for separate dI / dW" Oct 25, 2024
self.use_full_backward = False

# Go through two microbatches
# TODO(whc) unify the semantics of the IR for old runtime with new runtime.
Contributor Author

fixed in later PR in this stack

ops.extend(stage.get_fwd_send_ops(mb_index))
elif computation_type == _ComputationType.BACKWARD:

# TODO(whc) for now i'm going with the hopefully backward-compatible position that legacy IR with
Contributor Author

fixed in later PR in this stack

return True
return False
elif action.computation_type == B:
elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
Contributor Author

I think I should change this to just 'FULL_BACKWARD' to be consistent with the rest of this PR, and then add BACKWARD_INPUT back in the later PR where I fix other inconsistencies.

@wconstab wconstab requested a review from H-Huang October 30, 2024 19:44
@@ -0,0 +1,2 @@
0F0,0F1,2F0,,2F1,2I0,2W0,0F2,2I1,2W1,0F3,0I0,0W0,2F2,0I1,0W1,2F3,2I2,2W2,0F4,2I3,2W3,0F5,0I2,0W2,2F4,0I3,0W3,2F5,2I4,2W4,0F6,2I5,2W5,0F7,0I4,0W4,2F6,0I5,0W5,2F7,2I6,2W6,2I7,2W7,0I6,0W6,0I7,0W7
Member

For my understanding, if you load the compute csv without the comms csv, will it error? Or will it automatically determine the comms for you?

Contributor Author

Within test_csv, it is explicit: we load the compute-only one, run the lowering passes, then compare the output of that with the saved comms one.

For real users, it is API'd somewhat: load_csv in PipelineScheduleRuntime accepts a kwarg for whether it's a comms or compute csv. If it's a compute one, then it will run add_send_recv.

I think before we more widely roll this out, we should better define the API around the lowering passes. Probably a function for 'lowering' the schedule and some config flags for any optional passes. For now it's kind of manual: load_csv just hardcodes which passes to run.
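As a side note, the compute-only csv above is simple enough to read directly. An illustrative parser (this is not the runtime's own parser; the cell interpretation, e.g. '2I0' = stage 2, dI, microbatch 0, with empty cells read as idle steps, is inferred from the schedule shown):

```python
import csv
import re

# One letter per compute op: Forward, backward-Input (dI), backward-Weight (dW), full Backward.
ACTION_RE = re.compile(r"(\d+)([FIWB])(\d+)")

def parse_compute_csv(path):
    """Return, per rank (one csv row per rank), a list of (stage, type, microbatch) tuples."""
    ranks = []
    with open(path) as f:
        for row in csv.reader(f):
            actions = []
            for cell in row:
                cell = cell.strip()
                if not cell:
                    continue  # empty cell: idle step / bubble
                stage, kind, mb = ACTION_RE.fullmatch(cell).groups()  # assumes well-formed cells
                actions.append((int(stage), kind, int(mb)))
            ranks.append(actions)
    return ranks
```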

# as the input tensor for a fresh autograd graph, not part of the previous stage's autograd graph.
# TODO: confirm, do we use this activation as the root of the backward call for the previous stage? does
# detach have any effect on that?
info.buffer = tensor.detach().requires_grad_(True)
Contributor Author

@H-Huang I should remove this TODO. It seems to pass the gradient tests, but I wonder if you have any more insight into whether I am doing the best thing here.
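For reference, a tiny standalone check of that question (illustrative tensors, not the stage code): detach only cuts the graph for the consumer stage; the producer's backward is still rooted at the original output and simply consumes the grad accumulated on the detached leaf.

```python
import torch

w = torch.randn(3, requires_grad=True)   # stand-in for the previous stage's parameters
stage3_out = w * 2                        # stand-in for the previous stage's output

# What the buffer update above does: hand the next stage a fresh leaf tensor.
stage4_in = stage3_out.detach().requires_grad_(True)
stage4_loss = (stage4_in ** 2).sum()
stage4_loss.backward()                    # next stage's backward fills stage4_in.grad

# The previous stage's backward is still rooted at the original (pre-detach)
# output and just consumes the grad produced on the detached leaf.
stage3_out.backward(stage4_in.grad)
assert w.grad is not None
```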

assert not self.is_first, "can't get bwd output if this stage is first"

self._check_chunk_id(mb_index)
# TODO(whc) we should be indexing mb_index into self.grads_input, but it appears we are only storing
Contributor Author

Made an issue: #139404

Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv.
Does not detach or set '_requires_grad'.
"""
# TODO(whc) discrepancy between list/tuple type here. need to clean up
Contributor Author

I think this TODO should be removed, as we decided it is expected that users can choose a tuple or a tensor as the return value. The 'normalize' function should be fixing this for us now? (But we aren't calling it?)

Contributor Author

#139405

Made an issue. Looks like we currently don't know what type this will be and aren't normalizing it. Let's fix it in another PR.

torch.export tracing, compiled models may also return a list instead of a Tuple, which we will normalize back to a
tuple for consistency.
TODO: should we be stricter about asserting that stage modules (intermediate and output) all return only Tensor
Contributor Author

@H-Huang wdyt, should we be asserting this? If a stage returned a non-tensor, we'd fail on send/recv, right?
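A sketch of the kind of normalization plus assertion being discussed (hypothetical helper, not an existing code path):

```python
import torch

def normalize_stage_output(output):
    # Hypothetical helper: accept a Tensor, list, or tuple from a stage module and
    # return a tuple of Tensors, asserting nothing non-Tensor reaches send/recv.
    if isinstance(output, torch.Tensor):
        output = (output,)
    output = tuple(output)
    assert all(isinstance(t, torch.Tensor) for t in output), (
        "stage modules are expected to return only Tensors (or lists/tuples of them)"
    )
    return output
```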

stage_idx,
n_stages,
device,
# TODO(whc) shape inference shouldn't have needed to run communications in this 1-rank, 2-stage scenario,
Contributor Author

Removed this; it is fixed now, probably by @H-Huang's PR fixing single-stage schedule usage.

@wconstab wconstab changed the title from "[Pipelining] Runtime support and optimization for separate dI / dW" to "[Pipelining] Support separate dI / dW and V-schedules" Oct 31, 2024
pytorchmergebot pushed a commit that referenced this pull request Oct 31, 2024
Used in both simulator and add_send_recv pass, the ready_to_schedule
logic works by looking at all the previously scheduled ops on a rank to
see if any of them 'unblocks' the current op to be scheduled.  For example,
to schedule a FORWARD op, a previous RECV_F op is needed, unless this is
stage 0 or there is a previous stage on the same rank that ran FORWARD
already.

The old implementation iteratively compared the candidate op to the
previous ops.  The new implementation uses set lookups to reduce
complexity.  It also maintains the set of previous ops as ops are
scheduled rather than constructing a set on demand.

I did not save benchmark results, but this results in a 10-100x speedup
which is most noticeable for unit tests with artificially huge schedule
IR, the largest of which took longer than 20m before (I never let it
finish) but now takes less than 14s.  Most schedules take less than
10ms.

Pull Request resolved: #138924
Approved by: https://github.com/H-Huang
ghstack dependencies: #138928, #131762
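An illustrative version of that set-based readiness check (simplified names, and only the FORWARD case; the real pass in #138924 handles all op types and maintains the set incrementally as ops are scheduled):

```python
FORWARD, RECV_F = "F", "RECV_F"

def ready_to_schedule(op, prev_ops, stage_to_rank, rank):
    """op is an illustrative (stage, type, microbatch) tuple; prev_ops is the set
    of ops already scheduled on this rank, so each check is an O(1) lookup."""
    stage, kind, mb = op
    if kind == FORWARD:
        if stage == 0:
            return True  # the first stage needs no received inputs
        # Ready if the activation for this microbatch was already received...
        if (stage, RECV_F, mb) in prev_ops:
            return True
        # ...or the previous stage lives on this rank and already ran its FORWARD.
        if stage_to_rank[stage - 1] == rank and (stage - 1, FORWARD, mb) in prev_ops:
            return True
        return False
    # Backward and comm ops have analogous conditions, omitted in this sketch.
    return True
```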
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
Pull Request resolved: pytorch#131762
Approved by: https://github.com/H-Huang
ghstack dependencies: pytorch#138928
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
Pull Request resolved: pytorch#138924
Approved by: https://github.com/H-Huang
ghstack dependencies: pytorch#138928, pytorch#131762
@github-actions github-actions bot deleted the gh/wconstab/327/head branch December 1, 2024 02:21
Esquains pushed a commit to Esquains/study1 that referenced this pull request Dec 15, 2024
ghstack-source-id: 16998f2
Pull Request resolved: pytorch/pytorch#131762

Labels

Merged · module: pipelining · oncall: distributed · release notes: distributed (pipeline)
