[Pipelining] Make PipelineStage support meta initialization by wconstab · Pull Request #136243 · pytorch/pytorch · GitHub

Conversation

@wconstab
Contributor

@wconstab wconstab commented Sep 18, 2024

Stack from ghstack (oldest at bottom):

Avoid allocating memory or dry-running the submodule during stage init.

Save user-provided input/output metadata during stage init, to allow
lazily initializing the buffers before the first step call.

Later, we plan to build on top of this to add lazy shape inference
(#130856) so that no input/output shapes are required at stage init.

For now, we require input/output tensors for stage init, but these
should be on meta device and stage should not allocate any real memory.
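
As a rough illustration of that requirement (a minimal sketch, not code from this PR or torchtitan; argument names follow the PipelineStage ctor discussed below and may differ in detail), the stage module and the example input/output tensors all live on the meta device, so constructing the stage records only shapes and dtypes:

import torch
from torch.distributed.pipelining import PipelineStage

# Build the stage submodule under the meta device so no real weights are allocated.
with torch.device("meta"):
    submodule = torch.nn.Linear(1024, 1024)  # stand-in for one pipeline stage

example_input = torch.empty(8, 1024, device="meta")
example_output = torch.empty(8, 1024, device="meta")

stage = PipelineStage(
    submodule,
    stage_index=0,
    num_stages=4,
    device=torch.device("cuda"),
    input_args=example_input,    # only metadata is saved at init
    output_args=example_output,  # still required here, used for shape validation
)
# Real communication buffers are materialized lazily, before the first step() call.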

Note: this needs more thorough testing and review, but it worked on the
torchtitan 3d test.

TODO:

  • delete the 'device' arg from the PipelineStage ctor? (Infer it from the args
    tensors passed to the first step call instead? Separate PR.)
  • delete 'output_args' from the PipelineStage ctor? We don't actually need
    it, but we use it to do shape validation, which is why I didn't remove
    it in this PR. Proposal: leave it until we add lazy shape inference?

Fixes #136225, #136226

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o

@pytorch-bot pytorch-bot bot added the oncall: distributed label Sep 18, 2024
@pytorch-bot

pytorch-bot bot commented Sep 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136243

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 74cfb12 (merge base could not be retrieved; please contact dev infra):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wconstab added a commit that referenced this pull request Sep 18, 2024
ghstack-source-id: 8a359b5
Pull Request resolved: #136243
wconstab added a commit to pytorch/torchtitan that referenced this pull request Sep 18, 2024
Uses meta device for tensors/model used before pipeline splitting.

*Important:*
Relies on pytorch/pytorch#136243 to make PipelineStage avoid
materializing the model and the input/output buffers eagerly.

Relies on existing .to(device) in train.py to finally materialize the
model.

ghstack-source-id: 66fa9f1
Pull Request resolved: #582
@wconstab wconstab added the release notes: distributed (pipeline) label Sep 18, 2024
Member

@H-Huang H-Huang left a comment


Thanks for the quick fix!

        self.inputs_meta = (
            (input_args,) if isinstance(input_args, torch.Tensor) else input_args
        )
        self._configure_outputs_meta(
Member


This will make output_args required (right now it is optional); otherwise initialization will fail, right? For a smoother transition, can we still run self.submod() with the input args and leave output_args optional? That will fail if the model is not on the same device as the input, but that is okay, and we can say we expect them both to be on the meta device.

Contributor Author


You're right that this makes output_args required. I like your suggestion. If you're willing to let it be an error when users pass real inputs, I would propose asserting that the module and inputs are on the same device, and then computing 'outputs_meta' from the inputs.

I was weighing whether to require that the input/module be on the meta device, but I suppose that is too restrictive. I do want to ensure that outputs_meta is stored on meta to avoid wasting memory. So maybe if the user passes a CUDA model/inputs I will convert the output to meta. Wdyt?

Contributor Author


OK, I added output-shape inference back in, but I made it refuse to run on a non-meta device, so some user code might still have to switch its inputs to meta. Is this OK, or should I make it work for CUDA?

Member


It might be easier to just let it work for CUDA if users already have things on CUDA; then you wouldn't have to fix all the tests you mentioned above. I think the main fix of this PR can just be removing the .to(device) in init, which you did, and with that removal it is now up to the users to make sure that model(input_args) works if they pass it in.

Contributor


Hmm, can I say that I prefer the previous version, i.e. removal?
(See also my comments about relaxing the device constraint, which requires removing the dry run.)

Member


Hmm, I'm in the other boat: I think keeping the dry run makes more sense (as a temporary measure) until we get lazy init working; otherwise we have to update all the tests to pass in both input_args and output_args, and require users to do so as well.

Contributor


Yeah, that's hard. Maybe at least wrap it in torch.no_grad()?
I am actually not sure what side effects running a module has. Saving autograd context, marking requires_grad?
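
A minimal sketch of that no_grad dry run (the helper name and pytree usage here are assumptions, not code from this PR): run the submodule once without recording autograd state, then keep only meta copies of the outputs so nothing real is retained.

import torch
from torch.utils._pytree import tree_map

def _infer_outputs_meta(submod, *input_args):
    # Dry run without building an autograd graph or marking requires_grad.
    with torch.no_grad():
        outputs = submod(*input_args)
    # Keep only shape/dtype metadata, regardless of the input device.
    return tree_map(
        lambda t: t.to("meta") if isinstance(t, torch.Tensor) else t, outputs
    )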

Contributor


Also, in forward_maybe_with_... we have special logic for DDP/FSDP modules; would the dry run bypassing it be an issue? (Or are we just lucky?)

Contributor Author


I'm going to leave this for later. It's not any worse with this PR, and we plan to remove it when adding lazy shape inference. Sound OK?

Contributor


Yep, sounds good.

Contributor

@kwen2501 kwen2501 left a comment


LGTM.
We have modified the single-stage schedule at this point, but not the multi-stage ones?

@wconstab
Contributor Author

Oh, I should include the multi-stage schedules in this PR; I'll fix it before landing.

@kwen2501
Contributor

Confirmed, this fixes #136225.

wconstab added a commit that referenced this pull request Sep 20, 2024
ghstack-source-id: 0a452fc
Pull Request resolved: #136243
target: target for the loss function.
losses: a list to store the losses for each microbatch.
"""
if not self._stages_initialized:
Contributor Author


Hmm, should I move this logic inside _step_microbatches? It seems some of our tests call _step_microbatches directly. Do we allow this, or do we require calling step()?

If we don't require calling step(), then should we also move this code inside _step_microbatches?
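
For illustration, a sketch of where that guard could live if _step_microbatches must also work standalone (attribute and helper names here are assumptions, not the final code):

def _step_microbatches(self, arg_mbs=None, kwarg_mbs=None, target_mbs=None, losses=None):
    # Lazily initialize stage buffers on the first call, whether we got here
    # via step() or a test calling _step_microbatches directly.
    if not self._stages_initialized:
        self._initialize_stages(arg_mbs[0], kwarg_mbs[0])
    ...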

        self._stage.clear_runtime_states()

Contributor Author


Why don't I just do this init inside stage._forward_one_chunk...?

Contributor

@kwen2501 kwen2501 Sep 20, 2024


Yeah, that would be a good place.
We may have to ask CUDA Graphs to capture after the first step(...) though.
But that may not be a big deal? And the current change moves the prepare work from init to step anyway.

Contributor Author


Annoying: the Stage class doesn't know "n_microbatches", which may have been intentional?

I am thinking about whether it would be bad to let the schedule 'register' the number of microbatches during Schedule.__init__, so that the stage can use this value later when it performs initialization inside forward_one_chunk / backward_one_chunk.
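
A rough sketch of that registration idea, with hypothetical method names, just to show the shape of the handshake rather than a final API:

class _StageSketch:
    def __init__(self):
        self._n_microbatches = None

    def _register_num_microbatches(self, n):
        # Called by the schedule during its __init__.
        self._n_microbatches = n

    def forward_one_chunk(self, *args, **kwargs):
        assert self._n_microbatches is not None, "stage must be attached to a schedule first"
        # ... lazily allocate send/recv buffers sized using self._n_microbatches ...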

group: Optional[dist.ProcessGroup] = None,
dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
):
assert submodule.device == torch.device(
Contributor Author


This assertion breaks tests, and I guess it'll break some existing usages. Is it a good thing to do, to prevent bad practices in the future, or should I relax it and make it work when the model and inputs are on CUDA?

Contributor


I think we should relax this. "meta" device is still a high-end thing for most users.
And, is nn.Module.device a real thing?

Contributor Author


Hmm, it's not always possible to assert this; maybe there is another way to reliably check the device of the module?

AttributeError: 'FSDPSequential' object has no attribute 'device'

Contributor Author


Yeah, maybe I should just drop the checks and perform the shape inference blindly, but then do a .to(meta) on the outputs. That should put things on par with what Howard mentioned, and then we can delete it once we add lazy inference.

Contributor

@weifengpy weifengpy Sep 20, 2024


FSDPSequential is a superset of nn.Sequential, which means vanilla nn.Sequential does not have a device attribute either (see below for a repro). I guess what we really want is to check whether model.parameters() are on meta? That should cover all cases.

>>> import torch
>>> model = torch.nn.Sequential(torch.nn.Linear(2,2))
>>> model.device
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1931, in __getattr__
    raise AttributeError(
AttributeError: 'Sequential' object has no attribute 'device'
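
One way to express that check, as a hypothetical helper rather than code from this PR: inspect the parameters (and buffers) instead of a nonexistent nn.Module.device attribute.

import itertools
import torch

def _module_is_on_meta(module: torch.nn.Module) -> bool:
    # True if every parameter and buffer is a meta tensor (trivially True for
    # modules that hold no tensors at all).
    tensors = itertools.chain(module.parameters(), module.buffers())
    return all(t.is_meta for t in tensors)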


Comment on lines +1461 to +1465
# TODO, (1) are we deleting output validation when we move to shape inference?
# (2) if not, we should support multiple outputs
assert (
len(outputs_meta) == 1
), f"validation logic assumes single output, got {len(outputs_meta)} outputs "
Contributor


Oh, then should we disable the validation logic (which is an add-on protection) until it supports the multi-output case? Multi-output is pretty common. This assert may break a few tracer cases.

Contributor Author


If this is a new regression, I can remove the assert, but I thought the assert would just make an existing error more explicit. The old code assumes 'output' is a single value and tries to validate against it. Let's see if CI fails, and if not, make another PR to fix validation?

Contributor


I see. Yep, if it is not a new assert, then let's see what CI says. :)

wconstab added a commit that referenced this pull request Sep 20, 2024
ghstack-source-id: 955df68
Pull Request resolved: #136243

@kwen2501
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Sep 20, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

wconstab added a commit that referenced this pull request Sep 21, 2024
ghstack-source-id: b05610d
Pull Request resolved: #136243
@wconstab
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@kwen2501
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

wconstab added a commit to pytorch/torchtitan that referenced this pull request Sep 26, 2024
ghstack-source-id: c15282c
Pull Request resolved: #588
@github-actions github-actions bot deleted the gh/wconstab/335/head branch October 22, 2024 20:34