Allow schedules to run with single stage #138925
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138925. Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 1388eb8 with merge base 2922b9f. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Ran into issues (#138863) when adding a schedule with a single stage for zero bubble, so adding code to support this edge case (mostly for test purposes).

cc @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
Is it possible to fix this another way such that stage 0 still computes input grad separately?
```python
)
grads_input = []
param_groups = []
# Skip the backward for the first stage since we will perform the weight update with
```
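For context, here is a minimal sketch of the weight/input (W/I) backward split under discussion; `split_backward` and its arguments are illustrative names, not the PR's actual implementation:

```python
import torch

def split_backward(stage_output, output_grad, stage_inputs, parameters, is_first_stage):
    """Hypothetical two-phase backward: input grads (I), then weight grads (W)."""
    if is_first_stage:
        # Stage 0's inputs are the raw data batch and don't require grad,
        # so the I phase is a no-op; all real work lands in the W phase.
        grads_input = []
    else:
        grads_input = torch.autograd.grad(
            stage_output, stage_inputs, output_grad, retain_graph=True
        )
    grads_weight = torch.autograd.grad(stage_output, parameters, output_grad)
    return grads_input, grads_weight
```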
Does this mean that stage 0 will never run separate W/I computations, even in multi-stage pipelines?
I think this is a significant problem, since in ZB schedules it is more common to use separate W/I for earlier stages than for later ones: the last stage may use an almost entirely merged full backward, while the first stage may need mostly separated ones to fill bubbles.
Stage 0 still computes W/I, but the I is now effectively a no-op since the real work is done in W. Typically the input grad would not be computed for stage 0 anyway, since the inputs do not require gradients, and this skips the .grad() call entirely.
This only applies to the W/I split case; for the full backward B, the backward execution remains the same.
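To make the "inputs do not require gradients" point concrete, a tiny standalone check:

```python
import torch

batch = torch.randn(8, 16)   # raw data batch entering stage 0
print(batch.requires_grad)   # False -> no input grad to compute, so the
                             # I phase can skip the autograd.grad call
```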
This mostly makes sense; I agree it is pointless to compute dI on stage 0. I need to revisit how the schedules are designed, because I thought a separate I was common for stage 0 of ZB schedules.
```python
    last_backward = self._seen_bwd_chunks == self.chunks - 1  # type: ignore[operator]
else:
    # When backwards are split into weight and input, we will see twice as many
    # bwd_chunks, -1 because we skip the first bwd_chunk's backward
```
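As a toy model of the counting in that comment (an assumed reading of the intent, not the PR's code): a full backward contributes one bwd chunk per microbatch, while the split contributes two, minus the skipped first input backward:

```python
def is_last_backward(seen_bwd_chunks: int, chunks: int, split_backward: bool) -> bool:
    if not split_backward:
        # Full backward B: one bwd chunk per microbatch.
        return seen_bwd_chunks == chunks - 1
    # Split backward: I + W per microbatch gives 2 * chunks chunks, minus the
    # one skipped first-stage input backward, so the last index is 2*chunks - 2.
    return seen_bwd_chunks == 2 * chunks - 2
```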
For another PR, but in case you didn't see my comment on my own PR for merged backward: this logic will have to be rewritten, since any stage may have some mix of I, W, and B operations, so we can't detect the last backward by counting and expecting round numbers.
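A hedged sketch of the kind of rewrite suggested here (illustrative names, not a proposed patch): track the outstanding backward ops per microbatch instead of counting to a fixed total:

```python
def make_last_backward_tracker(actions_per_microbatch):
    # e.g. {0: {"I", "W"}, 1: {"B"}, 2: {"I", "W"}} -- each microbatch may
    # owe a different mix of backward ops, so no single total works.
    outstanding = {mb: set(ops) for mb, ops in actions_per_microbatch.items()}

    def complete(mb, op):
        # Mark one backward op as done; True only once nothing is outstanding.
        outstanding[mb].discard(op)
        return not any(outstanding.values())

    return complete
```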
Got it, we can definitely rewrite this logic!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#138925
Approved by: https://github.com/wconstab