[Pipelining] Allow non-0 stages to accept kwargs #136416
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136416.
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 9d17b53 with merge base 3bc073d. This comment was automatically generated by Dr. CI and updates every 15 minutes.
To support a usage case in torchchat, all non-0 stages require `input_pos` and `cache_lane`:
```
kwargs = {"input_pos": input_pos, "cache_lane": lane}
if pp_rank == first_pp_rank:
    # The first stage takes the new token as a positional arg;
    # all stages share the same kwargs.
    output = decorder.step(new_token, **kwargs)
elif pp_rank == last_pp_rank:
    output = decorder.step(**kwargs)
else:  # middle pp ranks
    decorder.step(**kwargs)
```
The `forward_one_chunk` code today hard-sets `{}` as the kwargs for non-0 stages, hence cannot support the above use case.
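For illustration, here is a minimal conceptual sketch of the change this implies in the scheduling code. The helper names (`is_first`, `recv_activations`, `submod`) are hypothetical stand-ins, not the actual PyTorch internals:
```
# Conceptual sketch only -- not the real forward_one_chunk source.
def forward_one_chunk(stage, args, kwargs):
    if stage.is_first:
        # Stage 0 already receives both user args and kwargs.
        return stage.submod(*args, **kwargs)
    # Before this PR: non-0 stages were called with kwargs hard-set to {}.
    # After: user kwargs are forwarded alongside the received activations,
    # so every stage can consume e.g. input_pos / cache_lane.
    composite_args = stage.recv_activations()  # hypothetical helper
    composite_kwargs = kwargs if kwargs is not None else {}
    return stage.submod(*composite_args, **composite_kwargs)
```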
cc XilunWu H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o
The review thread below refers to this test snippet:
```
if self.rank == 0:
    with self.assertRaisesRegex(PipeliningShapeError, "shape mismatch"):
        _run_step(torch.randn(batch_size + 1, d_hid, device=self.device))
```
How come the shape checking got disabled even for args? I thought, from reading the stage code change, it would only change the validation of kwargs.
I disabled the check on args in the library code when kwargs are present, so I had to remove the `assertRaisesRegex` test here (otherwise it catches nothing). This is the comment in the library code:
# TODO: need a mapping of kwarg to position in self.args_recv_info.
# Without it, we are not 100% sure how to match the args and
# expected_args.
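To make the matching ambiguity concrete, here is a toy sketch; the specs, shapes, and the `validate` helper are made up for illustration and are not the library's actual validation code:
```
import torch

# Hypothetical arg specs recorded at pipeline setup time (made-up shapes).
expected_args = [("x", (8, 16)), ("input_pos", (8,))]

def validate(args, kwargs, expected_args):
    if kwargs:
        # Without a kwarg-name -> position mapping into args_recv_info,
        # we cannot reliably pair each kwarg with its spec, so skip.
        return
    for (name, shape), tensor in zip(expected_args, args):
        if tuple(tensor.shape) != shape:
            raise ValueError(f"shape mismatch for {name}")

# Positional-only call: fully validated.
validate((torch.randn(8, 16), torch.arange(8)), {}, expected_args)
# Mixed call: validation is skipped, matching the new behavior.
validate((torch.randn(8, 16),), {"input_pos": torch.arange(8)}, expected_args)
```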
I think this is an improvement. It still feels like we haven't quite nailed the args/kwargs handling, though (esp. validation).
Yeah, this is a "fast food" PR. Thanks for the review.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#136416
Approved by: https://github.com/wconstab