[dynamo] FSDP + AC + torch.compile by anijain2305 · Pull Request #103953 · pytorch/pytorch · GitHub

Conversation

@anijain2305 (Contributor) commented Jun 21, 2023

@pytorch-bot (bot) commented Jun 21, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103953

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7c3c297:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

anijain2305 added a commit that referenced this pull request Jun 21, 2023
ghstack-source-id: 0632b73
Pull Request resolved: #103953
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78

@anijain2305 changed the title from "[WIP][DONT REVIEW] FSDP + AC + torch.compile" to "[dynamo] FSDP + AC + torch.compile" on Jun 21, 2023
@skip_if_lt_x_gpu(1)
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_fsdp_activation_checkpointing(self):
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        checkpoint_wrapper,
        CheckpointImpl,
    )
Contributor:

Do we want to also test the native torch.utils one?

Contributor Author:

That's already tested in test/dynamo/test_activation_checkpointing.py.

They are also tested heavily on HF models; more data here: #102935.
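
For reference, a minimal sketch of what that native-path coverage looks like (illustrative only, not the actual test; the Toy module is made up):

import torch
from torch.utils.checkpoint import checkpoint

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def _inner(self, x):
        return torch.relu(self.linear(x))

    def forward(self, x):
        # Recompute _inner in backward instead of saving its activations.
        return checkpoint(self._inner, x, use_reentrant=False)

model = torch.compile(Toy())
model(torch.randn(2, 8)).sum().backward()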

with _dynamo_dist_per_rank_init(self.rank, self.world_size):
    model, inputs = get_toy_model_for_activation_checkpointing(f"cuda:{self.rank}")
    is_inner = lambda module: isinstance(module, ToyInnerModel)  # noqa: E731
    wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=is_inner)
Contributor:

Are there any valid cases where we don't use the same policy for FSDP/AC that we need to test? cc @awgu
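
For context, FSDP's auto_wrap_policy and activation checkpointing's check_fn are independent knobs, so they can legitimately diverge. A hypothetical sketch, with made-up module names, of checkpointing only a subset of the FSDP-wrapped blocks:

import functools
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

# FSDP shards every inner block (ToyInnerModel is hypothetical)...
fsdp_policy = functools.partial(
    lambda_auto_wrap_policy, lambda_fn=lambda m: isinstance(m, ToyInnerModel)
)
model = FSDP(model, auto_wrap_policy=fsdp_policy, use_orig_params=True)

# ...while AC recomputes only some of them (is_expensive is hypothetical).
apply_activation_checkpointing(
    model, check_fn=lambda m: isinstance(m, ToyInnerModel) and m.is_expensive
)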


gm.__name__ = next_name
src = NNModuleSource(GetItemSource(self.source, next_name))
if self.source.guard_source().is_fsdp_module():
Contributor:

What does add_subgraph do generally? Can there be more than one subgraph per higher-order operator?

Contributor Author:

It's just a util that inserts the subgraph into Dynamo's output graph module. And yes, there is no limit on how many subgraphs a higher-order op can have; torch.cond, for example, has two: one for the true branch and one for the false branch.

In the case of activation checkpointing, there is only one.
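
A minimal sketch of the torch.cond case, using the current public torch.cond entry point (at the time of this PR the op lived under functorch.experimental.control_flow):

import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

@torch.compile(fullgraph=True)
def f(pred, x):
    # Dynamo captures true_fn and false_fn as two separate subgraphs
    # hanging off the single cond higher-order op.
    return torch.cond(pred, true_fn, false_fn, (x,))

f(torch.tensor(True), torch.randn(4))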

@wconstab (Contributor) left a comment:

I think it looks good. Would like to get @wanchaol to also review.

@wanchaol (Collaborator) left a comment:

model = FSDP(
    copy.deepcopy(model),
    auto_wrap_policy=wrap_policy,
    use_orig_params=True,
)
Collaborator:

Shall we parametrize the tests and test both use_orig_params == True/False, similar to https://github.com/pytorch/pytorch/blob/main/test/distributed/fsdp/test_fsdp_checkpoint.py#L137? A rough sketch of what that could look like follows.
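
A rough sketch of that suggestion, using the parametrize helper from torch.testing._internal.common_utils; the class, test name, and helper are illustrative:

import copy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    TestCase,
)

class TestFSDPActivationCheckpointing(TestCase):
    @parametrize("use_orig_params", [True, False])
    def test_fsdp_activation_checkpointing(self, use_orig_params):
        model, wrap_policy = build_toy_model_and_policy()  # hypothetical helper
        wrapped = FSDP(
            copy.deepcopy(model),
            auto_wrap_policy=wrap_policy,
            use_orig_params=use_orig_params,
        )
        # ... run the compiled model and compare against eager

# Generates ..._use_orig_params_True and ..._use_orig_params_False variants.
instantiate_parametrized_tests(TestFSDPActivationCheckpointing)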

anijain2305 added a commit that referenced this pull request Jun 22, 2023
ghstack-source-id: eae969f
Pull Request resolved: #103953
anijain2305 added a commit that referenced this pull request Jun 22, 2023
ghstack-source-id: c9b33ee
Pull Request resolved: #103953
partial_kwargs = {
    k: variables.ConstantVariable(v) for k, v in self.value.keywords.items()
}
partial_kwargs.update(kwargs)
if requires_higher_order_op(self.value.func):
Contributor:

This seems like a good opportunity to use the walrus operator.
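
To illustrate the suggestion, a generic sketch rather than the actual diff; handle() is a stand-in, and this only pays off if the helper returns something worth reusing rather than a plain bool:

# Without the walrus operator: bind, then test.
result = requires_higher_order_op(self.value.func)
if result:
    handle(result)

# With the walrus operator: bind and test in one expression.
if result := requires_higher_order_op(self.value.func):
    handle(result)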

@anijain2305 (Contributor Author):

@pytorchbot merge

@pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Jun 23, 2023
@pytorchmergebot (Collaborator):

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job.

anijain2305 added a commit that referenced this pull request Jun 23, 2023
ghstack-source-id: 4d2437d
Pull Request resolved: #103953
@anijain2305 added the topic: not user facing (topic category) label on Jun 23, 2023
anijain2305 added a commit that referenced this pull request Jun 23, 2023
ghstack-source-id: f2716de
Pull Request resolved: #103953
@anijain2305 (Contributor Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@voznesenskym (Collaborator):

Do you mind testing this against #103711? I know it's an onerous ask, but I want to make sure we also keep in mind the relative future direction of PT2 + FSDP for things like this.

@anijain2305 (Contributor Author):

> Do you mind testing this against #103711? I know it's an onerous ask, but I want to make sure we also keep in mind the relative future direction of PT2 + FSDP for things like this.

Oh, missed this message. Will certainly do in the next couple of weeks. Trying to wrap up some work, and will pick this up then.

@facebook-github-bot facebook-github-bot deleted the gh/anijain2305/65/head branch June 27, 2023 14:16