[Pipelining] Allow non-0 stages to accept kwargs by kwen2501 · Pull Request #136416 · pytorch/pytorch

Conversation

@kwen2501 (Contributor) commented Sep 23, 2024

Stack from ghstack (oldest at bottom):

To support a use case in torchchat, all non-0 stages require `input_pos` and `cache_lane`:

```
kwargs = {"input_pos": input_pos, "cache_lane": lane}

if pp_rank == first_pp_rank:
    output = decorder.step(new_token, **kwargs)
elif pp_rank == last_pp_rank:
    output = decorder.step(**kwargs)
else:  # middle pp ranks
    decorder.step(**kwargs)
```

The `forward_one_chunk` code today hard-codes `{}` as the kwargs for non-0 stages, and hence cannot support the above use case.
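For illustration, here is a minimal, self-contained sketch of the behavior change. The real logic lives in the stage's `forward_one_chunk`; the class and method names below are simplified stand-ins, not the actual PyTorch implementation:

```
# Sketch of the kwargs pass-through, assuming a stripped-down stage.
# _StageSketch and _retrieve_recv_activations are illustrative
# stand-ins for the real pipelining stage internals.

class _StageSketch:
    def __init__(self, is_first):
        self.is_first = is_first

    def _retrieve_recv_activations(self):
        # Stand-in for activations received from the previous stage.
        return ("recv_activation",)

    def forward_one_chunk(self, args, kwargs=None):
        if self.is_first:
            composite_args = args
        else:
            composite_args = self._retrieve_recv_activations()
        # Previously, non-0 stages hard-coded composite_kwargs = {};
        # now caller-provided kwargs flow through on every stage.
        composite_kwargs = kwargs or {}
        return composite_args, composite_kwargs


stage = _StageSketch(is_first=False)
print(stage.forward_one_chunk((), {"input_pos": 7, "cache_lane": 0}))
# (('recv_activation',), {'input_pos': 7, 'cache_lane': 0})
```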

cc @XilunWu @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot (bot) commented Sep 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136416

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9d17b53 with merge base 3bc073d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Sep 23, 2024
@kwen2501 kwen2501 added release notes: distributed (pipeline) release notes category module: pipelining Pipeline Parallelism labels Sep 23, 2024
@kwen2501 kwen2501 requested review from H-Huang and wconstab and removed request for H-Huang September 23, 2024 06:53
kwen2501 added a commit that referenced this pull request Sep 23, 2024
ghstack-source-id: aba7fe3
Pull Request resolved: #136416

```
if self.rank == 0:
    with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
        _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))
```
A reviewer (Contributor) asked:

How come the shape checking got disabled even for args? I thought from reading the stage code change that it would only change the validation of kwargs.

@kwen2501 (Author) replied:

I disabled the check on args in the library code when kwargs are present, so I had to remove the `assertRaisesRegex` test here (otherwise it catches nothing).

This is the comment in the library code:

```
# TODO- need a mapping of kwarg to position in self.args_recv_info
# Without it, we are not 100% sure how to match the args and
# expected_args.
```
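Concretely, the guard amounts to skipping the args shape check whenever kwargs are supplied. A simplified, hypothetical sketch (the real check lives in the stage's input-validation path; `validate_fwd_input_sketch` and `expected_shapes` are illustrative names, not the actual API):

```
import torch


class PipeliningShapeError(RuntimeError):
    """Stand-in for the shape-mismatch error used in the test above."""


def validate_fwd_input_sketch(args, kwargs, expected_shapes):
    # With kwargs present, we cannot map each kwarg to its position in
    # args_recv_info (see the TODO above), so the check is skipped.
    if kwargs:
        return
    for tensor, shape in zip(args, expected_shapes):
        if tensor.shape != shape:
            raise PipeliningShapeError(
                f"shape mismatch: got {tensor.shape}, expected {shape}"
            )


expected = [torch.Size([4, 8])]
# Wrong batch size, but kwargs are present -> validation is skipped.
validate_fwd_input_sketch((torch.randn(5, 8),), {"input_pos": 0}, expected)
# The same wrong shape without kwargs would raise PipeliningShapeError.
```

This is also why the `assertRaisesRegex` test had to go: once kwargs are in play, the mismatched shape no longer raises.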

@wconstab (Contributor) left a comment:

I think this is an improvement. Still feels like we haven't quite nailed the args/kwargs handling though (esp. validation).

@kwen2501 (Author) replied:

Yeah, this is a "fast food" PR. Thanks for the review.

@kwen2501 (Author) commented:

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 23, 2024
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

BoyuanFeng pushed a commit to BoyuanFeng/pytorch that referenced this pull request Sep 25, 2024
Pull Request resolved: pytorch#136416
Approved by: https://github.com/wconstab
@github-actions github-actions bot deleted the gh/kwen2501/63/head branch October 25, 2024 02:08