[RFC] Don't materialize ignored modules for FSDP by rohan-varma · Pull Request #108032 · pytorch/pytorch

Conversation

Contributor

@rohan-varma rohan-varma commented Aug 28, 2023

Stack from ghstack (oldest at bottom):

Per title. This is needed for cases where I have a large embedding that I want
to manage separately, but FSDP would initialize it and thus consume the
memory.

Currently, the interaction with torchdistX materialize_module is not tested;
this can be done as follow-up work.

Differential Revision: [D48722046](https://our.internmc.facebook.com/intern/diff/D48722046/)
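
For context, here is a rough sketch of the intended usage pattern. The model and module names below are made up for illustration, and the sketch assumes a recent PyTorch (meta-device context manager, `Module.to_empty` with `recurse`) and an already-initialized process group:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class ToyModel(nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        # Large embedding that the user wants to manage outside of FSDP.
        self.big_embedding = nn.Embedding(10_000_000, 256)
        self.linear = nn.Linear(256, 256)

# Build on the meta device so nothing is materialized yet.
with torch.device("meta"):
    model = ToyModel()

# With this change, FSDP should leave `big_embedding` on the meta device
# (it is in `ignored_modules`) instead of materializing it and consuming
# memory; only the non-ignored submodules go through `param_init_fn`.
fsdp_model = FSDP(
    model,
    ignored_modules=[model.big_embedding],
    device_id=torch.cuda.current_device(),
    param_init_fn=lambda mod: mod.to_empty(
        device=torch.device("cuda"), recurse=False
    ),
    use_orig_params=True,
)
```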


pytorch-bot bot commented Aug 28, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108032

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Merge Blocking SEV

There is 1 active merge blocking SEV. Please view it below:

If you must merge, use @pytorchbot merge -f.

❌ 4 New Failures, 1 Unrelated Failure

As of commit 017d4ae with merge base a20fac8:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Collaborator

@awgu awgu left a comment


SGTM!

device_id=self.rank,
ignored_modules=[m.a],
use_orig_params=True,
param_init_fn=lambda m: m.cuda(),
Collaborator


In general, this lambda m: m.cuda() would do some repeated checks trying to move modules to CUDA since param_init_fn would be called on every module. This should just lead to some CPU overhead since copying to CUDA is a no-op if already on CUDA.

As a tiny nit, the variable shadowing of m is also a bit precarious.
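
(As an aside, a named function makes the per-module call pattern explicit and avoids the shadowing; the name `move_to_cuda` below is made up:)

```python
import torch

def move_to_cuda(mod: torch.nn.Module) -> None:
    # FSDP calls param_init_fn once per non-ignored submodule, so this runs
    # repeatedly; once the parameters are already on CUDA, .cuda() is a no-op
    # apart from the CPU-side checks mentioned above.
    mod.cuda()
```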

Collaborator

@awgu awgu left a comment


Sorry, I should cancel the approval. The unit test is failing:

return self._apply(lambda t: t.cuda(device))
NotImplementedError: Cannot copy out of meta tensor; no data!

I think you need a different param_init_fn.
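
(For reference, one possible meta-device-friendly initializer, sketched under the assumption that `Module.to_empty` accepts `recurse` and that the wrapped submodules define `reset_parameters` where real initialization matters:)

```python
import torch

def meta_param_init_fn(mod: torch.nn.Module) -> None:
    # Allocate real (uninitialized) CUDA storage for this module's own
    # parameters/buffers, which cannot be copied off the meta device directly...
    mod.to_empty(device=torch.device("cuda"), recurse=False)
    # ...then re-run the module's own initialization if it defines one.
    if callable(getattr(mod, "reset_parameters", None)):
        mod.reset_parameters()
```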

@rohan-varma
Contributor Author

That's weird, I feel like I ran the test before sending the PR...

@rohan-varma rohan-varma requested a review from awgu September 1, 2023 00:43
pytorchmergebot pushed a commit that referenced this pull request Sep 5, 2023
Since these are ignored by FSDP, don't move them.

Differential Revision: [D48727044](https://our.internmc.facebook.com/intern/diff/D48727044/)
Pull Request resolved: #108033
Approved by: https://github.com/awgu
ghstack dependencies: #108032
pytorchmergebot pushed a commit that referenced this pull request Sep 5, 2023
…NTRANT (#108435)

We should use the non-reentrant implementation. There are a lot of users of this API, but
it is in a prototype state, so it should be fine to change.

Differential Revision: [D48898148](https://our.internmc.facebook.com/intern/diff/D48898148/)
Pull Request resolved: #108435
Approved by: https://github.com/awgu
ghstack dependencies: #108032, #108033
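
(For reference, a minimal sketch of explicitly opting into non-reentrant checkpointing via the prototype `checkpoint_wrapper` API, which is my reading of the change referenced above; it is not a restatement of that PR's diff:)

```python
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)

# Explicitly request the non-reentrant implementation rather than relying on
# the wrapper's default.
wrapped = checkpoint_wrapper(
    torch.nn.Linear(16, 16),
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

x = torch.randn(4, 16, requires_grad=True)
wrapped(x).sum().backward()
```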
@facebook-github-bot facebook-github-bot deleted the gh/rohan-varma/734/head branch September 9, 2023 14:22