[PP] Add eval() API to schedule #157795
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157795
Note: links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 unrelated failure.) As of commit eb5e0a8 with merge base 9345279. UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This change allows PP schedules to run under `with torch.no_grad():`
## Context
Currently, you can run "Forward only" for a schedule in two ways:
1. Use a custom schedule `_ScheduleForwardOnly`
2. Do not pass `loss_fn` to the schedule constructor, in which case no backward computations will be executed (a sketch of this option appears after the example below).
However, this is still limiting because we may want to run forward through the pipeline and calculate the loss without running backward, e.g. during validation. These changes allow for this.
```python
with torch.no_grad():
if self.rank == 0:
schedule.step(x)
elif self.rank == self.world_size - 1:
losses = []
schedule.step(target=target, losses=losses)
else:
schedule.step()
```
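For reference, a minimal sketch of the second existing approach, assuming the standard `torch.distributed.pipelining` constructors; `stage_module`, `stage_index`, `num_stages`, `device`, and `x` are placeholders:
```python
from torch.distributed.pipelining import PipelineStage, Schedule1F1B

# Build a stage for this rank's portion of the model (placeholders below).
stage = PipelineStage(stage_module, stage_index, num_stages, device)

# No loss_fn passed: the schedule runs forward only and skips backward.
schedule = Schedule1F1B(stage, n_microbatches=4)
schedule.step(x)  # forward through the pipeline, no loss, no backward
```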
(And we may want to deprecate `_ScheduleForwardOnly`? Open to discussion.)
cc awgu wanchaol fegin fduwjj wz337 wconstab d4l3k
Question on the approach: is your goal to make it so that `schedule.step()` will behave differently under `torch.no_grad()`? If so, I wonder if such polymorphic behavior of `.step()` is really needed. Should we instead just add an API like `eval()`?
Yep, that's correct.
This would also work. We could do this instead, but the minor API difference is that you would need to replace all `step()` calls. So I guess we have 3 options:
Any preferences for these options?
I think my major preference would be to have an explicit eval() API. Mainly because it's clear what it does, we can document it, and we can also remove the forward-only schedule. As for what to do with no_grad and step, I think semantically what 'step' means is "run a train step", and doing a train step with no grad is kind of nonsense, so I might lean towards raising a nice error there.
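A minimal sketch of that error-raising idea, assuming a hypothetical guard in the base schedule class; the class name and message are illustrative, not the merged code:
```python
import torch

class _PipelineScheduleSketch:
    def step(self, *args, **kwargs):
        # step() means "run a train step" (forward and backward), so running
        # it under no_grad is ill-defined; fail loudly and point to eval().
        if not torch.is_grad_enabled():
            raise RuntimeError(
                "step() performs forward and backward; for forward-only "
                "execution (e.g. validation) use eval() instead."
            )
        # ... regular forward/backward scheduling would follow here ...
```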
These changes add an `eval()` API to PP schedules
## Context
Currently, you can run "Forward only" for a schedule in two ways:
1. Use a custom schedule `_ScheduleForwardOnly`
2. Do not pass `loss_fn` to the schedule constructor, in which case no backward computations will be executed.
However, this is still limiting because we may want to run forward through the pipeline and calculate the loss without running backward, e.g. during validation. These changes allow for this.
```python
if self.rank == 0:
schedule.eval(x)
elif self.rank == self.world_size - 1:
losses = []
schedule.eval(target=target, losses=losses)
else:
schedule.eval()
```
TODO:
- In later PRs, we will deprecate `_ScheduleForwardOnly`.
cc awgu wanchaol fegin fduwjj wz337 wconstab d4l3k
```python
        after the last backward.
        """
        # skip backward computation if backward is not enabled
        if not self.has_backward:
```
Hm, am I right that the way eval() works is that it calls the regular schedule step() but makes backward steps in the schedule into no-ops?
Intuitively, I was thinking what you'd do is just have eval call into the 'schedule_forward_only' logic; however, this approach might be fine too. Is there some advantage to doing the forward operations in a schedule-specific way? And is there a risk here that we still accidentally do some extra backward communication because we forgot to disable it, or something?
> am i right that the way eval() works is that it calls the regular schedule step() but makes backward steps in the schedule into no-ops?

Yep, that's right.
The advantage of sharing is that we can reuse the pre- and post-processing logic from step() (microbatch splitting and output merging), as well as the logging we have in the original step(). Also, because we want to consolidate around the _PipelineScheduleRuntime, having one method that performs the actual "execution" is easier to reason about, I think.
There is a risk, as you mentioned, of accidentally introducing backward comms, but hopefully the tests should cover this.
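A rough sketch of the mechanism described above, assuming `eval()` simply wraps the existing `step()` machinery; the attribute handling and control flow here are illustrative, not the actual implementation:
```python
import torch

class _PipelineScheduleSketch:
    def eval(self, *args, target=None, losses=None, **kwargs):
        # Reuse step()'s microbatch splitting, output merging, and logging,
        # but make the schedule's backward actions no-ops for this call.
        saved_has_backward = self.has_backward
        self.has_backward = False
        try:
            with torch.no_grad():
                return self.step(*args, target=target, losses=losses, **kwargs)
        finally:
            self.has_backward = saved_has_backward
```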
OK, I guess I'd still prefer if we could keep eval and step separate and just share helpers/utils for microbatching and output merging. But I agree this is a good reason to use step for now, and we could do that later as a cleanup if we want.
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 1 check: pull / cuda12.8-py3.10-gcc9-sm75 / test (pr_time_benchmarks, 1, 1, linux.g4dn.metal.nvidia.gpu, unstable). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
With the recent API change to pipeline schedules (pytorch/pytorch#157795), we can now schedule the forward pass and calculate loss, allowing us to use validation and PP together. To test correctness, we train from a seed checkpoint with training.seed and training.determinism set, with varying degrees of parallelism and different pipeline schedules, and compare whether the loss remains the same:

| Parallelism | Loss |
| --- | --- |
| FSDP=2 | <img width="960" height="328" alt="Screenshot 2025-07-29 at 5 12 49 PM" src="https://github.com/user-attachments/assets/3aedc87d-f12c-409c-88da-86b0ac72a1a7" /> |
| FSDP=2, TP=2, PP=2, PP_schedule="1F1B" | <img width="964" height="334" alt="Screenshot 2025-07-29 at 5 17 18 PM" src="https://github.com/user-attachments/assets/b5f8979b-0f44-48fc-aa4d-38e938c5cf43" /> |
| FSDP=2, PP=4, PP_schedule="1F1B" | <img width="973" height="335" alt="Screenshot 2025-07-29 at 5 15 53 PM" src="https://github.com/user-attachments/assets/29636394-b602-4a21-995d-94769771f599" /> |
| FSDP=2, PP=4, PP_schedule="Interleaved1F1B" | <img width="964" height="329" alt="Screenshot 2025-07-29 at 5 39 39 PM" src="https://github.com/user-attachments/assets/de960111-d0ad-4470-a096-493d7f59461e" /> |
| FSDP=2, PP=4, PP_schedule="GPipe" | <img width="971" height="329" alt="Screenshot 2025-07-29 at 5 49 36 PM" src="https://github.com/user-attachments/assets/2100b2a2-2725-43c8-a937-78fb05962247" /> |
| FSDP=2, PP=4, PP_schedule="LoopedBFS" | <img width="963" height="330" alt="Screenshot 2025-07-29 at 5 54 55 PM" src="https://github.com/user-attachments/assets/102df0f7-bd4f-47a6-a94a-a1bf488237ce" /> |
| FSDP=2, PP=4, PP_schedule="InterleavedZeroBubble" | <img width="960" height="343" alt="Screenshot 2025-07-30 at 2 30 53 PM" src="https://github.com/user-attachments/assets/1d2bce1a-0b8c-4d09-85b8-0a0634f68690" /> |
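As a rough illustration of how validation composes with PP here, a sketch of a per-rank validation loop; the dataloader, device handling, and loss averaging are placeholders, not torchtitan's actual trainer code:
```python
import torch

def validate(schedule, val_loader, rank, world_size, device):
    total_loss, steps = 0.0, 0
    for x, target in val_loader:
        x, target = x.to(device), target.to(device)
        if rank == 0:
            schedule.eval(x)  # first stage feeds inputs
        elif rank == world_size - 1:
            losses = []
            schedule.eval(target=target, losses=losses)  # last stage computes loss
            total_loss += sum(l.item() for l in losses)
        else:
            schedule.eval()  # middle stages only relay activations
        steps += 1
    return total_loss / max(steps, 1)
```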