Refactor optional graph module into CompiledFxGraphConstants by jamesjwu · Pull Request #141897 · pytorch/pytorch · GitHub

Conversation

@jamesjwu
Contributor

@jamesjwu jamesjwu commented Dec 2, 2024

Stack from ghstack (oldest at bottom):

FXGraphCache supports freezing, but AOTAutogradCache does not. This is because, when freezing is turned on, instead of using the constants from the graph module that was saved on cache miss, we have to take the constants from the AOTAutograd-generated graph module. This PR does two things:

  • It bypasses AOTAutogradCache when freezing is turned on. We should have always been doing this.

  • It refactors the code to be much clearer about which constants we're using and when we're using them.

Basically, there are two possible sets of constants we can grab from the compiled fx graph.

  1. If freezing is turned off, we save the constants directly in CompiledFxGraph.
  2. If freezing is turned on, we save the names of the constants in CompiledFxGraph, and use the runtime GraphModule's actual constant values: we reconstruct them from the saved names + the new graph module from AOTDispatch.

We implement two different classes for doing just this: one that has access to the post-AOTDispatch gm, which supports freezing, and one that doesn't have it, which does not support freezing. Then we construct the wrappers and unwrap the result as needed.
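The two constant-resolution strategies described above can be sketched roughly as follows. This is a minimal illustration, not the actual PyTorch internals: the class names echo the PR title, but the `constants` and `constant_names` attributes and the `unwrap` method are assumptions for the sketch.

```python
class CompiledFxGraphConstants:
    """Resolve constants saved directly on the compiled fx graph.

    Used when freezing is off: the constant values themselves were
    serialized with the graph, so no runtime GraphModule is needed.
    """

    def unwrap(self, compiled_graph):
        # `constants` is assumed to be a name -> value mapping saved
        # with the compiled graph on cache miss.
        return compiled_graph.constants


class CompiledFxGraphConstantsWithGm(CompiledFxGraphConstants):
    """Resolve constants by name against the post-AOTDispatch GraphModule.

    Used when freezing is on: only the constant *names* were saved, and
    the live values must be read off the freshly produced runtime gm.
    """

    def __init__(self, gm):
        self.gm = gm

    def unwrap(self, compiled_graph):
        # `constant_names` is assumed to be the list of saved names;
        # the actual values come from the runtime graph module.
        return {name: getattr(self.gm, name) for name in compiled_graph.constant_names}
```

The key point the types encode: the frozen path cannot produce values at all without a runtime gm, which is why freezing must bypass a cache that only has the saved graph.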

This makes it clear that the gm passed to AOTAutogradCache is not part of post compile, only the cache key generated from it is.

The whole flow is pretty confusing, but hopefully this gives us better types and static information for understanding what the different codepaths are doing.

Will add a specific AOTAutogradCache test to confirm we bypass freezing.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Dec 2, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141897

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit a6f565d with merge base 920e436:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jamesjwu jamesjwu marked this pull request as ready for review December 2, 2024 19:18
@jamesjwu jamesjwu added the topic: not user facing topic category label Dec 2, 2024
gm = make_boxed_func(gm)
return gm, {}

def post_compile(self, gm, inputs, cudagraphs):

This is no longer called after ed's previous refactor

compiled_graph: CompiledFxGraph,
cudagraphs: BoxedBool,
gm: Optional[torch.fx.GraphModule],
constants: Dict[str, torch.Tensor],
this is good



class CompiledFxGraphConstants:
"""Wrapper class that gets constants from a compiled fx graph"""
It would be nice if this parent class explained the subclass inheritance situation and when this one versus the other got used



@dataclasses.dataclass
class MockFXGraphCacheOutput(OutputCode):
So what exactly is this

oh this is a one off mock ig

I think I'd prefer for this to live in _inductor/output_code.py, as the interface for OutputCode is not settled and likely will change some more as we keep working on it.

config.patch(get_cpp_wrapper_config())
if config.cpp_wrapper
else contextlib.nullcontext()
):
all this reformatting very annoying lol

@ezyang ezyang left a comment

None of the comments are blocking, feel free to do them separately

raise BypassAOTAutogradCache(
"Cannot cache a graph with compiled autograd enabled"
)
if torch._inductor.config.freezing:
In the original freezing diff, I checked to see if the gm actually had any frozen params created. Maybe that's a little better? I believe that when the config option is set, freezing is applied unconditionally currently, but maybe there's a future where it's not?

def has_frozen_params(gm: torch.fx.GraphModule) -> bool:

Thing is, we can't really find that out unless we run AOTAutograd. So at this stage, the best we can do is look at the config.
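The bypass discussed in this exchange amounts to a config check before caching. A simplified sketch (the exception name comes from the diff above; the function name and flag parameter are illustrative, and the real check reads `torch._inductor.config.freezing`):

```python
class BypassAOTAutogradCache(Exception):
    """Raised to skip the cache and fall back to a normal compile."""


def check_cacheable(freezing_enabled: bool) -> None:
    # At this stage AOTAutograd has not run yet, so we cannot inspect
    # the gm for actually-frozen params; the config flag is the best
    # available signal, even if it is conservative.
    if freezing_enabled:
        raise BypassAOTAutogradCache(
            "Cannot cache a graph with freezing enabled"
        )
```

Raising rather than silently skipping keeps the bypass reason visible in cache-miss logging.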

)

# TODO: How come cudagraphs could be None here?
# TODO: How come gm is None here?
Can we remove these TODOs now?

jamesjwu added a commit that referenced this pull request Dec 2, 2024
@jamesjwu jamesjwu added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 3, 2024
@eellison eellison removed their request for review December 3, 2024 23:06
jamesjwu added a commit that referenced this pull request Dec 4, 2024
jamesjwu added a commit that referenced this pull request Dec 4, 2024
@jamesjwu
Contributor Author

jamesjwu commented Dec 5, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…#141897)

(commit message duplicates the PR description above)

Pull Request resolved: pytorch#141897
Approved by: https://github.com/ezyang, https://github.com/masnesral
AmdSampsa pushed a commit to AmdSampsa/pytorch that referenced this pull request Dec 9, 2024
…#141897)

(commit message duplicates the PR description above)

Pull Request resolved: pytorch#141897
Approved by: https://github.com/ezyang, https://github.com/masnesral
Esquains pushed a commit to Esquains/study1 that referenced this pull request Dec 15, 2024
@leslie-fang-intel
Collaborator

leslie-fang-intel commented Dec 19, 2024

Hi @jamesjwu @masnesral, thanks for your PR. We recently hit issue #143144, which seems related to this PR (or to some related changes before it).

If freezing is turned on, we save the names of the constants in CompiledFxGraph, and use the runtime GraphModule's actual constant values: we reconstruct them from the saved names + the new graph module from AOTDispatch.

This does not seem to be a correct assumption. As an example, in the Inductor lowering phase we may re-layout some constants, since a different kernel might be chosen by max-autotune, as in:

W_packed_constant = V.graph.add_tensor_constant(W_packed)

In this case, we add a new constant to the CompiledFxGraph, but it may not be in the GraphModule (we also delete the original constant, which is no longer used, to save memory). Looking forward to your suggestions for how to resolve this issue. cc @frost-intel @jgong5
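The failure mode described in this comment can be reproduced in miniature: if lowering registers a new constant (e.g. a re-packed weight) whose name was saved but which never existed as an attribute on the runtime gm, the name-based lookup has nothing to resolve. The objects and names here are stand-ins, not the real Inductor structures:

```python
from types import SimpleNamespace

# The compiled graph saved the name of a constant created during
# lowering ("w_packed"), but the runtime GraphModule only ever had
# the original constant "w" (and the packed one was never set on it).
compiled_graph = SimpleNamespace(constant_names=["w_packed"])
runtime_gm = SimpleNamespace(w=[1.0, 2.0])


def resolve(gm, graph):
    # Name-based resolution as in the freezing path: look each saved
    # name up on the runtime graph module.
    return {name: getattr(gm, name) for name in graph.constant_names}


try:
    resolve(runtime_gm, compiled_graph)
except AttributeError as exc:
    print("lookup failed:", exc)
```

This is why constants minted inside lowering break the "names saved, values from the runtime gm" scheme: the value's only home is the compiled graph itself.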

@github-actions github-actions bot deleted the gh/jamesjwu/83/head branch January 19, 2025 02:07