[PP] Add eval() API to schedule by H-Huang · Pull Request #157795 · pytorch/pytorch · GitHub

Conversation

H-Huang (Member) commented Jul 8, 2025

Stack from ghstack (oldest at bottom):

These changes add an `eval()` API to PP schedules.

Context

Currently, you can run "Forward only" for a schedule in two ways:

  1. Use the custom schedule _ScheduleForwardOnly
  2. Do not pass a loss_fn to the schedule constructor, so no backward computations will be executed.

However, this is still limiting because we may want to run forward through the pipeline and calculate the loss, but without backward, e.g. during validation. These changes allow for this:

```python
if self.rank == 0:
    schedule.eval(x)
elif self.rank == self.world_size - 1:
    losses = []
    schedule.eval(target=target, losses=losses)
else:
    schedule.eval()
```
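For example, a full validation loop built on top of `eval()` could look like the sketch below. This is illustrative only: it assumes the schedule was constructed with a `loss_fn`, and `validate`, `val_loader`, `rank`, `world_size`, and `device` are placeholder names, not part of this PR.

```python
import torch

# A minimal validation-loop sketch (illustrative only): assumes `schedule`
# was constructed with a `loss_fn`, and that `val_loader` yields
# (input, target) batches on every rank.
def validate(schedule, val_loader, rank, world_size, device):
    total_loss, num_batches = 0.0, 0
    for x, target in val_loader:
        if rank == 0:
            schedule.eval(x.to(device))            # first stage feeds inputs in
        elif rank == world_size - 1:
            losses = []                            # filled with one loss per microbatch
            schedule.eval(target=target.to(device), losses=losses)
            total_loss += torch.stack(losses).mean().item()
            num_batches += 1
        else:
            schedule.eval()                        # middle stages only relay activations
    # only the last pipeline rank ends up with a meaningful loss
    return total_loss / num_batches if num_batches else None
```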

TODO:

  • In later PRs, we will deprecate the _ScheduleForwardOnly schedule.

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

pytorch-bot bot added the oncall: distributed label Jul 8, 2025
pytorch-bot bot commented Jul 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157795

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit eb5e0a8 with merge base 9345279:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

This change allows PP schedules to run under `with torch.no_grad():`

## Context

Currently, you can run "Forward only" for a schedule in two ways:
1. Use the custom schedule `_ScheduleForwardOnly`
2. Do not pass a `loss_fn` to the schedule constructor, so no backward computations will be executed.

However, this is still limiting because we may want to run forward through the pipeline and calculate the loss, but without backward, e.g. during validation. These changes allow for this:

```python
with torch.no_grad():
    if self.rank == 0:
        schedule.step(x)
    elif self.rank == self.world_size - 1:
        losses = []
        schedule.step(target=target, losses=losses)
    else:
        schedule.step()
```

(And we may want to deprecate `_ScheduleForwardOnly`? Open to discussion.)

cc awgu wanchaol fegin fduwjj wz337 wconstab d4l3k 

[ghstack-poisoned]
H-Huang added a commit that referenced this pull request Jul 8, 2025
ghstack-source-id: 2874b9d
Pull Request resolved: #157795
wconstab (Contributor) commented Jul 9, 2025

question on the approach:

is your goal to make it so that schedule.step() will behave differently under no_grad() and not run backwards?

if so, I wonder if such polymorphic behavior of .step() is really needed. Should we instead just add an API like eval() to the Schedule class, and have this API be equivalent to running ScheduleForwardOnly?

  • also, I would be in favor of having this convenience method on all the schedules and getting rid of the ScheduleForwardOnly class

H-Huang (Member Author) commented Jul 9, 2025

@wconstab

is your goal to make it so that schedule.step() will behave differently under no_grad() and not run backwards?

Yep, that's correct

Should we instead just add an api like eval() to the Schedule class, and have this API be equivalent to running ScheduleForwardOnly?

This would also work. We could do this instead, but the minor API difference is that you would need to replace all step() calls in your code with eval(). I figured that we would also want to support no_grad(), since currently it will raise an exception like "[X] for chunk X has gradients None and is expecting to send gradients to stage X" if run under no_grad(), so this is fixing two issues.

So I guess we have 3 options:

  1. Support step() with torch.no_grad() and add an eval() API (they will both functionally do the same thing)
  2. Add an eval() API, and add a more user-friendly error message in step() when using torch.no_grad()
  3. Only support step() with torch.no_grad() (current implementation)

Any preferences for these options?

H-Huang added a commit that referenced this pull request Jul 9, 2025
ghstack-source-id: 6214292
Pull Request resolved: #157795
wconstab (Contributor) commented:

I think my major preference would be to have an explicit eval() API. Mainly because it's clear what it does and we can document it, and also remove the forward-only schedule.

As for what to do with no_grad and step: I think semantically 'step' means "run a train step", and doing a train step with no grad is kind of nonsense, so I might lean towards raising a nice error there.
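A minimal sketch of that "nice error" idea (illustrative only; the class skeleton below mirrors the discussion, not the actual code in this PR):

```python
import torch

class _PipelineSchedule:
    # Illustrative skeleton only, not the actual PR diff.
    def step(self, *args, **kwargs):
        if not torch.is_grad_enabled():
            raise RuntimeError(
                "step() was called under torch.no_grad(). step() runs a full "
                "train step (forward + backward); use schedule.eval() for "
                "forward-only execution."
            )
        ...  # regular forward + backward execution

    def eval(self, *args, **kwargs):
        ...  # forward-only execution; no backward is scheduled
```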

H-Huang changed the title from [PP] Allow schedules to run under torch.no_grad() to [PP] Add eval() API to schedule on Jul 15, 2025
H-Huang added a commit that referenced this pull request Jul 15, 2025
ghstack-source-id: 9b2bba3
Pull Request resolved: #157795
Review comment on the `has_backward` guard in the schedule execution code:

```python
after the last backward.
"""
# skip backward computation if backward is not enabled
if not self.has_backward:
```

Contributor:

hm, am I right that the way eval() works is that it calls the regular schedule step() but makes backward steps in the schedule into no-ops?

intuitively I was thinking what you'd do is just have eval() call into the 'schedule_forward_only' logic.

however, this approach might be fine too. Is there some advantage to doing the forward operations in a schedule-specific way? And is there a risk here that we still accidentally do some extra backward communication because we forgot to disable it, or something?

Member Author:

am I right that the way eval() works is that it calls the regular schedule step() but makes backward steps in the schedule into no-ops?

Yep, that's right.

The advantage of sharing is that we can reuse the pre- and post-processing logic from step() (microbatch splitting and output merging), as well as the logging we have in the original step(). Also, because we want to consolidate around the _PipelineScheduleRuntime, having one method that performs the actual "execution" is easier to reason about, I think.

There is a risk, as you mentioned, of accidentally introducing backward comms, but hopefully the tests should cover this.
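As a rough illustration of the sharing being described here (a sketch, not the actual diff; `_run()` is a hypothetical name standing in for the shared execution entry point):

```python
class _PipelineSchedule:
    # Illustrative sketch of eval() reusing step()'s execution path;
    # `_run` is a hypothetical stand-in for the real entry point.
    def step(self, *args, target=None, losses=None, **kwargs):
        self.has_backward = True
        return self._run(*args, target=target, losses=losses, **kwargs)

    def eval(self, *args, target=None, losses=None, **kwargs):
        # Reuse the same microbatch splitting / output merging as step(),
        # but make every backward action (and its gradient comms) a no-op.
        prev = self.has_backward
        self.has_backward = False
        try:
            return self._run(*args, target=target, losses=losses, **kwargs)
        finally:
            self.has_backward = prev

    def _run(self, *args, target=None, losses=None, **kwargs):
        ...  # split into microbatches, execute schedule actions, merge outputs
```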

Contributor:

OK. I guess I'd still prefer if we could keep eval and step separate and just share helpers/utils for microbatching and output merging. But I agree this is a good reason to use step for now, and we could do that later as a cleanup if we want.

H-Huang (Member Author) commented Jul 16, 2025

@pytorchbot merge -i

pytorch-bot bot added the ciflow/trunk label Jul 16, 2025
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged while ignoring the following 1 checks: pull / cuda12.8-py3.10-gcc9-sm75 / test (pr_time_benchmarks, 1, 1, linux.g4dn.metal.nvidia.gpu, unstable)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

tianyu-l pushed a commit to pytorch/torchtitan that referenced this pull request Jul 31, 2025
With the recent API change to the pipeline schedule (pytorch/pytorch#157795), we can now schedule the forward pass and calculate the loss, allowing us to use validation and PP together.

To test correctness, we train from a seed checkpoint with training.seed and training.determinism set, under varying degrees of parallelism and different pipeline schedules, and compare whether the loss remains the same:

| Parallelism | Loss |
| --- | --- |
| FSDP=2 | [loss curve](https://github.com/user-attachments/assets/3aedc87d-f12c-409c-88da-86b0ac72a1a7) |
| FSDP=2, TP=2, PP=2, PP_schedule="1F1B" | [loss curve](https://github.com/user-attachments/assets/b5f8979b-0f44-48fc-aa4d-38e938c5cf43) |
| FSDP=2, PP=4, PP_schedule="1F1B" | [loss curve](https://github.com/user-attachments/assets/29636394-b602-4a21-995d-94769771f599) |
| FSDP=2, PP=4, PP_schedule="Interleaved1F1B" | [loss curve](https://github.com/user-attachments/assets/de960111-d0ad-4470-a096-493d7f59461e) |
| FSDP=2, PP=4, PP_schedule="GPipe" | [loss curve](https://github.com/user-attachments/assets/2100b2a2-2725-43c8-a937-78fb05962247) |
| FSDP=2, PP=4, PP_schedule="LoopedBFS" | [loss curve](https://github.com/user-attachments/assets/102df0f7-bd4f-47a6-a94a-a1bf488237ce) |
| FSDP=2, PP=4, PP_schedule="InterleavedZeroBubble" | [loss curve](https://github.com/user-attachments/assets/1d2bce1a-0b8c-4d09-85b8-0a0634f68690) |
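For reference, the training.seed / training.determinism settings mentioned above map onto standard PyTorch determinism controls. A hedged sketch of what such a setup typically involves (not the actual torchtitan code):

```python
import torch

def set_deterministic(seed: int) -> None:
    # Not the actual torchtitan implementation; a sketch of the standard
    # PyTorch knobs that seed/determinism settings generally map onto.
    torch.manual_seed(seed)                    # seeds CPU and CUDA RNGs
    torch.use_deterministic_algorithms(True)   # error out on nondeterministic ops
    torch.backends.cudnn.benchmark = False     # disable autotuned, nondeterministic kernels
```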
bentherien pushed a commit to bentherien/torchtitan_ that referenced this pull request Aug 5, 2025
joellidin pushed a commit to one-covenant/torchtitan that referenced this pull request Aug 8, 2025
joellidin pushed a commit to one-covenant/torchtitan that referenced this pull request Aug 8, 2025
github-actions bot deleted the gh/H-Huang/192/head branch August 16, 2025 02:18

Labels

ciflow/trunk · Merged · oncall: distributed · release notes: distributed (pipeline)