[Pipelining] Make PipelineStage support meta initialization #136243
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136243. Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 74cfb12 (failed to retrieve the merge base; please contact dev infra). This comment was automatically generated by Dr. CI and updates every 15 minutes.
Avoid allocating memory or dry-running the submodule during stage init. Save user-provided input/output metadata during stage init, to allow lazily initializing the buffers before the first step call. Later, we plan to build on top of this to add lazy shape inference (#130856) so that no input/output shapes are required at stage init. For now, we require input/output tensors for stage init, but these should be on meta device and the stage should not allocate any real memory. Note: this needs more thorough testing and review, but it worked on the torchtitan 3d test. TODO: (1) delete the 'device' arg from the PipelineStage ctor? (move to inferring it from the args tensors passed to the first step call? separate PR); (2) delete 'output_args' from the PipelineStage ctor? We don't actually need it, but we use it to do shape validation, which is why I didn't remove it in this PR. Proposal: leave it until we add lazy shape inference? Fixes #136225, #136226. ghstack-source-id: 8a359b5. Pull Request resolved: #136243
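As a hedged illustration of the intended usage (not code from this PR), a stage could be constructed entirely from meta-device tensors so that nothing real is allocated until the first step. The module, shapes, and stage/device values below are assumptions for the sketch, and the keyword arguments follow the PipelineStage signature as of this PR; a process group must already be initialized for the construction to succeed.

import torch
import torch.nn as nn
from torch.distributed.pipelining import PipelineStage

# Build this rank's stage submodule on meta: parameters carry only shape/dtype.
with torch.device("meta"):
    stage_mod = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())

# Example per-microbatch input/output tensors, also on meta (shapes are made up).
input_args = torch.empty(8, 1024, device="meta")
output_args = torch.empty(8, 1024, device="meta")

stage = PipelineStage(
    stage_mod,
    stage_index=0,                    # assumed: this rank owns stage 0
    num_stages=2,                     # assumed: 2-stage pipeline
    device=torch.device("cuda", 0),   # real device; buffers are created lazily
    input_args=input_args,
    output_args=output_args,
)
# Real parameters are materialized later (e.g. an existing .to(device) /
# init-weights path) before the first schedule step.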
Uses meta device for tensors/model used before pipeline splitting. *Important:* Relies on pytorch/pytorch#136243 to make PipelineStage avoid materializing the model and the input/output buffers eagerly. Relies on existing .to(device) in train.py to finally materialize the model. ghstack-source-id: 66fa9f1 Pull Request resolved: #582
Thanks for the quick fix!
self.inputs_meta = (
    (input_args,) if isinstance(input_args, torch.Tensor) else input_args
)
self._configure_outputs_meta(
This will make output_args required (right now it is optional); otherwise initialization will fail, right? For a smoother transition, can we still run self.submod() with the input args and leave output_args optional? This will fail if the model is not on the same device as the input, but that is okay; we can say we expect them both to be on meta device.
You're right that this makes output_args required. I like your suggestion. If you're willing to let it be an error when users pass real inputs, I would propose asserting that the module and inputs are on the same device, and then computing 'outputs_meta' based on the inputs.
I was weighing whether to require passing the input/module on meta device, but I suppose that is too restrictive. I do want to ensure that outputs_meta is stored on meta to avoid wasting memory, so maybe if the user passes a cuda model/inputs I will convert the output to meta. Wdyt?
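A minimal sketch of that proposal, with a hypothetical helper name (this is not the PR's actual code): dry-run the submodule on the provided inputs under no_grad, then convert any resulting output tensors to meta so only shape/dtype information is kept.

import torch

def infer_outputs_meta(submod: torch.nn.Module, inputs_meta):
    # Run the module once on whatever device the inputs/module share
    # (ideally meta), without building an autograd graph.
    with torch.no_grad():
        outputs = submod(*inputs_meta)
    if isinstance(outputs, torch.Tensor):
        outputs = (outputs,)
    # Keep only shape/dtype information: convert any real outputs to meta.
    return tuple(
        t.to("meta") if isinstance(t, torch.Tensor) else t for t in outputs
    )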
OK, I added output-shape inference back in, but I made it refuse to do inference on a non-meta device, so some user code might still have to switch their inputs to meta. Is this OK, or should I make it work for cuda?
It might be easier to just let it work for cuda if users already have things on cuda; then you wouldn't have to fix all the tests you mentioned above. I think the main fix of this PR can just be removing the .to(device) in init, which you did, and with that removal it is now up to users to make sure that model(input_args) works if they pass it in.
Hmm, can I say that I prefer the previous version, i.e. removal?
(See also my comments about relaxing the device constraint, which requires removing the dry run.)
Hmm, I'm in the other boat: I think keeping the dry run makes more sense (as a temporary measure) until we get lazy init working; otherwise we have to update all the tests to pass in both input_args and output_args, and require users to do so as well.
Yeah, that's hard. Maybe at least add with torch.no_grad()?
I am actually not sure what side effects running a module has. Saving grad context, marking requires_grad?
Also, in forward_maybe_with_... we have special logic for DDP/FSDP modules; would this dry run escaping that logic be an issue? (Or are we just lucky?)
I'm going to leave this for later. It's not any worse with this PR, and we plan to remove it when adding lazy shape inference. Sound OK?
Yep, sounds good.
Lgtm.
We have modified the single-stage schedule at this point, but not the multi-stage ones?
Oh, I should include the multi-stage schedules in this PR; I'll fix it before landing.
Confirm fixes #136225
    target: target for the loss function.
    losses: a list to store the losses for each microbatch.
"""
if not self._stages_initialized:
Hm, should I move this logic inside _step_microbatches? It seems some of our tests call _step_microbatches directly. Do we allow this, or do we require calling step()?
If we don't require calling step(), then should we also move this code inside _step_microbatches?
self._stage.clear_runtime_states()
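For illustration only, a hypothetical skeleton of the guard being discussed; the class, method, and attribute names mirror the snippet above but are assumptions, not the actual schedule code.

class _ScheduleSketch:
    def __init__(self, stage):
        self._stage = stage
        self._stages_initialized = False

    def _initialize_stages(self, args, kwargs):
        # Lazily set up the stage's send/recv buffers from the first microbatch.
        self._stage.prepare_infra(args, kwargs)  # hypothetical stage hook
        self._stages_initialized = True

    def step(self, *args, **kwargs):
        # step() splits inputs into microbatches and delegates; the guard could
        # live here...
        self._step_microbatches([args], [kwargs])

    def _step_microbatches(self, arg_mbs, kwarg_mbs):
        # ...or here, so tests that call _step_microbatches directly still
        # trigger lazy initialization.
        if not self._stages_initialized:
            self._initialize_stages(arg_mbs[0], kwarg_mbs[0])
        self._stage.clear_runtime_states()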
Why don't I just do this init inside stage._forward_one_chunk...?
Yeah. That would be a good place.
We may have to ask CUDA Graph to come capture after the first step(...) though.
But that may not be a big deal? And the current change moved the prepare from init to step anyway.
Annoying: the Stage class doesn't know "n_microbatches", which may have been intentional?
I am thinking about whether it would be bad to let the schedule 'register' the number of microbatches during Schedule.__init__, so the stage can use this value later when it performs initialization inside forward_one_chunk / backward_one_chunk.
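A toy sketch of that idea (class and method names are hypothetical, not the actual Stage/Schedule classes): the schedule records its microbatch count on the stage at construction time, and the stage reads it later when it lazily sizes its buffers.

class _StageSketch:
    def __init__(self):
        self.num_microbatches = None  # unknown until a schedule registers it

    def _lazy_init_buffers(self):
        assert self.num_microbatches is not None, "no schedule registered"
        # ...allocate per-microbatch recv/grad buffers here...

class _ScheduleSketch:
    def __init__(self, stage, n_microbatches):
        self._stage = stage
        stage.num_microbatches = n_microbatches  # the 'registration' step

stage = _StageSketch()
_ScheduleSketch(stage, n_microbatches=8)
stage._lazy_init_buffers()  # now knows to size buffers for 8 microbatches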
group: Optional[dist.ProcessGroup] = None,
dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
):
assert submodule.device == torch.device(
This assertion breaks tests; I guess it'll break some existing usages. Is it a good thing to do, to prevent bad practices going forward, or should I relax it and make it work if the model and inputs are on cuda?
I think we should relax this. "meta" device is still a high-end thing for most users.
And, is nn.Module.device a real thing?
Hmm, it's not always possible to assert this. Maybe there is another way to reliably check the device of the module?
AttributeError: 'FSDPSequential' object has no attribute 'device'
Yeah, maybe I should just drop the checks and perform the shape inference blindly, but then do a .to(meta) on the outputs. That should put things on par with what Howard mentioned, and then we can delete it once we add lazy inference.
FSDPSequential is a superset of 'nn.Sequential', and even vanilla nn.Sequential does not have a .device attribute (see below for a repro). I guess what we really want to check is whether model.parameters() are on meta? That should cover all cases.
>>> import torch
>>> model = torch.nn.Sequential(torch.nn.Linear(2,2))
>>> model.device
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1931, in __getattr__
raise AttributeError(
AttributeError: 'Sequential' object has no attribute 'device'
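A sketch of that suggested check, with a hypothetical helper name; it inspects parameters and buffers rather than a nonexistent .device attribute, which also covers wrappers such as FSDPSequential.

import torch

def module_on_meta(mod: torch.nn.Module) -> bool:
    # nn.Module has no .device attribute, so inspect parameters and buffers;
    # note this returns True for modules with no parameters at all.
    return all(
        t.device.type == "meta"
        for t in list(mod.parameters()) + list(mod.buffers())
    )

with torch.device("meta"):
    m = torch.nn.Sequential(torch.nn.Linear(2, 2))
print(module_on_meta(m))  # True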
# TODO, (1) are we deleting output validation when we move to shape inference?
# (2) if not, we should support multiple outputs
assert (
    len(outputs_meta) == 1
), f"validation logic assumes single output, got {len(outputs_meta)} outputs "
Oh, then should we disable the validation logic (which is an add-on protection) until it supports the multi-output case? Multi-output is pretty common; this assert may break a few tracer cases.
If this is a new regression, I can remove the assert, but I thought the assert would just make an existing error more explicit: the old code assumes 'output' is a single value and tries to validate against it. Let's see if CI fails, and if not, make another PR to fix validation?
I see. Yep, if it is not a new assert, then let's see what CI says. :)
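If the validation does stay, a follow-up could generalize it along these lines; this is a hypothetical helper, not part of this PR.

import torch

def validate_outputs(outputs, outputs_meta):
    # Compare each produced output against the recorded meta tensor instead of
    # assuming a single output.
    if isinstance(outputs, torch.Tensor):
        outputs = (outputs,)
    assert len(outputs) == len(outputs_meta), (
        f"expected {len(outputs_meta)} outputs, got {len(outputs)}"
    )
    for i, (out, meta) in enumerate(zip(outputs, outputs_meta)):
        assert out.shape == meta.shape, (
            f"output {i}: shape {out.shape} does not match expected {meta.shape}"
        )
        assert out.dtype == meta.dtype, (
            f"output {i}: dtype {out.dtype} does not match expected {meta.dtype}"
        )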
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Uses meta device for tensors/model used before pipeline splitting. *Important:* Relies on pytorch/pytorch#136243 to make PipelineStage avoid materializing the model and the input/output buffers eagerly. Relies on existing .to(device) in train.py to finally materialize the model. ghstack-source-id: c15282c Pull Request resolved: #588
Stack from ghstack (oldest at bottom):
Avoid allocating memory or dry-running the submodule during stage init.
Save user-provided input/output metadata during stage init, to allow
lazily initializing the buffers before the first step call.
Later, we plan to build on top of this to add lazy shape inference
(#130856) so that no input/output shapes are required at stage init.
For now, we require input/output tensors for stage init, but these
should be on meta device and stage should not allocate any real memory.
Note: this needs more thorough testing and review, but it worked on the
torchtitan 3d test.
TODO:
- delete 'device' arg from PipelineStage ctor? (move to inferring it from the
args tensors passed to the first step call? separate PR)
- delete 'output_args' from PipelineStage ctor? We don't actually need
it, but we use it to do shape validation, which is why I didn't remove
it in this PR. Proposal: leave it until we add lazy shape inference?
Fixes #136225, #136226
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o