Introduce stages to aot_dispatch by ezyang · Pull Request #158213 · pytorch/pytorch · GitHub

Conversation

Contributor

@ezyang ezyang commented Jul 14, 2025

Stack from ghstack (oldest at bottom):

The starting point for this refactor is that I need access to the fully
general joint graph representation in an export-like interface, but I
then subsequently need a way to feed this joint graph into the rest of
the compilation pipeline so I can get an actual callable that I can run
once I've finished modifying it. Previously, people had added export
capabilities to AOTAutograd with an export flag that toggled exactly
what the functions return and routed aot_dispatch to a different
"export" implementation, but I've found this difficult to understand,
and it has led to a fair amount of duplicated code for the export path.

So the idea here is to reorganize the structure of the function calls in AOTAutograd. Here, it is helpful to first describe how things used to work:

  • Start with aot_autograd.py top level functions like aot_function, _aot_export_function and aot_module_simplified. These call:
    • create_aot_dispatcher_function. This does a bunch of stuff (forward metadata collection) and adds many context managers. This calls:
      • One of aot_dispatch_base, aot_dispatch_export or aot_dispatch_autograd, which:
        • Call aot_dispatch_autograd_graph or aot_dispatch_base_graph to actually do the graph capture
        • Do some base/export/autograd specific post-processing on the graph

Notice that this pattern of nested function invocations means there is no easy way to get the graph capture result out in the autograd case; furthermore, the export path is "bolted on", forcing the entire chain of functions to have a different return type than normal, with no way to resume the rest of the post-processing to actually get a callable.
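Schematically, the old flow looked something like this (a heavily simplified sketch with hypothetical stand-in bodies and signatures, not the real PyTorch internals):

```python
from contextlib import ExitStack, nullcontext

# Simplified sketch of the OLD control flow: graph capture happens at the
# bottom of a nested call chain, and the `export` flag changes what every
# layer returns.

def aot_dispatch_base_graph(fn):
    return f"graph({fn.__name__})"  # stand-in for actual graph capture

def aot_dispatch_base(fn):
    graph = aot_dispatch_base_graph(fn)
    # base-specific post-processing produces a callable
    return lambda *args: ("compiled", graph, args)

def create_aot_dispatcher_function(fn, export=False):
    with ExitStack() as stack:
        stack.enter_context(nullcontext())  # stand-in for the many context managers
        if export:
            # export path: the whole chain now returns a graph, not a callable,
            # and there is no way to resume post-processing afterwards
            return aot_dispatch_base_graph(fn)
        return aot_dispatch_base(fn)

def aot_function(fn):
    return create_aot_dispatcher_function(fn)
```

The sketch makes the problem visible: once `aot_function` returns, the intermediate graph is gone unless `export=True`, and with `export=True` you never get the callable.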

Here is the new structure:

  • Start with aot_autograd.py top level functions like aot_function, _aot_export_function and aot_module_simplified. These now orchestrate this top level flow:
    • Start a context manager (stack); this stateful context block takes care of all of the nested context managers which originally necessitated the nested call structure
    • Call create_aot_state to do initial setup and set up all the context managers on stack. These context managers do NOT exit when this function returns.
    • Call aot_stage1_graph_capture to do the graph capture
    • Call aot_stage2_compile or aot_stage2_export depending on what postprocessing you want

With this new structure, it's now possible (although not done in this PR) to return the graph after aot_stage1_graph_capture and do something with it, before running aot_stage2_compile to finish the job.
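The new flow can be sketched the same way (again with hypothetical stand-in bodies, not the real signatures):

```python
from contextlib import ExitStack, nullcontext

# Simplified sketch of the NEW staged flow: the ExitStack is owned by the
# top-level orchestrator, so each stage returns normally and the caller can
# inspect or modify the captured graph between stages.

def create_aot_state(fn, stack):
    # context managers are entered on the caller's stack; they do NOT
    # exit when this function returns
    stack.enter_context(nullcontext())
    return {"fn": fn}

def aot_stage1_graph_capture(aot_state):
    return f"graph({aot_state['fn'].__name__})"  # stand-in for capture

def aot_stage2_compile(aot_state, graph):
    return lambda *args: ("compiled", graph, args)

def aot_function(fn):
    with ExitStack() as stack:
        aot_state = create_aot_state(fn, stack)
        graph = aot_stage1_graph_capture(aot_state)
        # a caller could hold onto / rewrite `graph` here before compiling
        return aot_stage2_compile(aot_state, graph)
```

The key design choice is that the `ExitStack` lives in the orchestrator rather than in a nested frame, so staging the pipeline does not change when the context managers unwind.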

Signed-off-by: Edward Z. Yang ezyang@meta.com

[ghstack-poisoned]
@ezyang ezyang requested a review from bdhirsh as a code owner July 14, 2025 03:45
ezyang added a commit that referenced this pull request Jul 14, 2025
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: 2cbfa33
Pull-Request: #158213
@pytorch-bot

pytorch-bot bot commented Jul 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158213

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Cancelled Job, 1 Unrelated Failure

As of commit 7cfc200 with merge base 4b9a6f7 (image):

NEW FAILURE - The following job has failed:

CANCELLED JOB - The following job was cancelled. Please retry:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

"""

# Old name for now to avoid messing with stats. Also, note this is pushed
# on the stack, so it extends BEYOND this function
Contributor

This will, in essence, increase the average compile time logged to aot_dispatcher_function right? In the sense that, if there was significant compile time after this function returned, it would now be recorded under create_aot_dispatcher_function.

To be clear, I don't think there's actually any work being done after this function returns, but that's just in theory.

Contributor Author

I believe that I end the ExitStack context manager at the same point in time it would have previously ended when we returned from the recursive calls. You should be able to verify this.

Contributor

@jamesjwu jamesjwu left a comment

Nit: The naming of these files really doesn't make much sense to me anymore; I don't really know what a jit compile runtime wrapper is, or how it's different from runtime_wrappers.py.

I feel like it would have made more sense to call the files
_aot_autograd/graph_capture.py, compile.py, export.py

The wrapper nomenclature made sense when we had a single big function that made a million closures, but now that we've pipelined it into inputs and outputs we might just want to split the stages.



@dataclass
class AOTState:
Contributor

It almost feels like AOTState itself should either be a context manager that you can grab information from, or a functional state that you return from each stage like:

(aot_state, aot_graph_capture) = aot_stage1()

This in between where aot_state is mutable as a side effect of calling the function is... okay I guess lol

Contributor Author

I don't think it's a huge antipattern to pass in AOTState as an argument to functions. Actually, I wanted to make it a proper object with methods on it for stages (so you access the state via self) but this would have required a lot of reindenting and code motion, so I didn't do it to make the PR easier to review.

A functional AOTState will be error prone, because there are a lot of places where we were previously relying on variable shadowing to ensure you get the "latest" version of any given variable. So I actually had a more functional version of this class that I replaced with the mutating one to stop people from accidentally using stale versions of variables.
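The design point can be illustrated with a toy version (hypothetical fields and stage bodies, not the real AOTState): mutating one shared state object means later stages automatically see updates, where a functional hand-off would risk a stage reading a stale copy.

```python
from dataclasses import dataclass, field

@dataclass
class AOTState:  # toy stand-in; the real AOTState carries much more
    fw_metadata: dict = field(default_factory=dict)

def aot_stage1_graph_capture(state):
    # updates are made in place, so every later stage sees the latest value
    state.fw_metadata["num_inputs"] = 2
    return "graph"

def aot_stage2_compile(state, graph):
    # no explicit hand-off needed: reads the mutated state directly
    return (graph, state.fw_metadata["num_inputs"])

state = AOTState()
graph = aot_stage1_graph_capture(state)
result = aot_stage2_compile(state, graph)
```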



@dataclass
class AOTGraphCapture: # Produced by aot_stage1_graph_capture
Contributor

I would expect fw_metadata to be in here, no? Isn't it generated during graph capture?

Contributor Author

Weeeell it's already in AOTState, so I didn't put it in this struct lol. We could have this struct have a pointer to AOTState so you have everything available in the same place.

Contributor

Actually I change my mind, fw_metadata gets changed by wrappers, so it being in AOTState makes more sense than putting it in AOTGraphCapture, because it's not a pure result of AOTGraphCapture


@dataclass
class AOTGraphCapture: # Produced by aot_stage1_graph_capture
# AOTAutograd typically operates by taking complicated graphs and
Contributor

This comment is not completely true, in that there are wrappers that happen after graph capture too; there are just some wrappers that happen even before graph capture. Hmm, maybe in the new model these wrappers should live in AOTState instead?

AOTDedupWrapper and AOTSyntheticBaseWrapper used to happen right before calling aot_dispatch_{base|autograd}, but also come with a post_compile after the inner compiler returns. So they're called as the first step in aot_dispatch_base. Not sure exactly where they belong in this model.

Other wrappers, like FunctionalizedRngWrapper and RuntimeWrapper itself, are created only for aot_stage2_compile. I don't think anything needs to track these, since they're all compile specific, so they can just be handled in aot_stage2_compile_{base|autograd}

Contributor Author

While I'm happy to delete this comment for clarity, I think this comment is directionally correct (and it is certainly how @bdhirsh and I conceptualized AOTAutograd when we originally worked on v2 of the implementation). For example, you focus on "when" AOTDedupWrapper happens. But intuitively, what it actually does is take a graph where some inputs may occur multiple times, and give us a new graph where each input occurs only once, so that it is safe to run autograd on it. There are things that have to happen before and after here: before, we have to create the new function; after, we have to apply a runtime wrapper that will eliminate the duplicate arguments.

Contributor

Sorry I actually meant to make this comment on the wrapper itself and was reasoning about where these wrappers belong.

I think the reason I'm a bit hung up on the timing is precisely because wrappers execute code both before and after the graph capture, so it was a little confusing that it was being returned by AOTGraphCapture itself. But it makes sense why it's attached to the return value; you need to track the post compiles of each of the wrappers. I think the awkward part is, because they are wrappers, the actual precise ordering of wraps is like:

AOTDedupWrapper -> aot_dispatch wrappers (pre_compile) -> fw/bw_compiler output -> aot_dispatch wrappers (post_compile) -> AOTDedupWrapper (post compile)

So in essence, even though stage2 is taking in AOTGraphCapture and these wrappers are part of AOTGraphCapture, it's basically just passing them through without applying them until the very end.

But honestly that level of detail of understanding is probably not that important to fully encapsulate in the data model or this comment, I just needed to wrap my head around it. It's probably fine as is.
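The onion-style ordering described above can be made concrete with a toy sketch (hypothetical wrapper classes; the real wrapper names are borrowed only as labels):

```python
order = []

class Wrapper:
    # toy wrapper: real AOTAutograd wrappers carry graph and runtime transforms
    def __init__(self, name):
        self.name = name

    def pre_compile(self):
        order.append(f"{self.name}.pre_compile")

    def post_compile(self):
        order.append(f"{self.name}.post_compile")

def compile_with_wrappers(wrappers):
    for w in wrappers:            # pre_compile hooks run in registration order
        w.pre_compile()
    order.append("fw/bw_compiler")
    for w in reversed(wrappers):  # post_compile hooks run in reverse: onion-style
        w.post_compile()

compile_with_wrappers([Wrapper("AOTDedupWrapper"), Wrapper("RuntimeWrapper")])
```

Running this appends pre_compile hooks outermost-first and post_compile hooks outermost-last, matching the chain in the comment above.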

@ezyang
Contributor Author

ezyang commented Jul 15, 2025

Nit: The naming of these files really doesn't make much sense to me anymore; I don't really know what a jit compile runtime wrapper is, or how it's different from runtime_wrappers.py.

The naming never made sense to me even before this refactor 🤣 . But I can only do one thing at a time. I can queue up file renames after this PR... if you approve it :P

[ghstack-poisoned]
@ezyang ezyang added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 15, 2025
@ezyang
Contributor Author

ezyang commented Jul 15, 2025

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 3 checks: inductor-rocm / rocm-py3.10-inductor / test (inductor, 2, 2, linux.rocm.gpu.2), pull / linux-jammy-py3-clang12-mobile-build / build, pull / cuda12.8-py3.10-gcc9-sm75 / test (pr_time_benchmarks, 1, 1, linux.g4dn.metal.nvidia.gpu, unstable)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #158251

@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #158319

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #158319

pytorchmergebot pushed a commit that referenced this pull request Jul 16, 2025
…nd functions to frontend_utils (#158251)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #158251
Approved by: https://github.com/jamesjwu
ghstack dependencies: #158149, #158150, #158173, #158176, #158213
pytorchmergebot pushed a commit that referenced this pull request Jul 16, 2025
Also a small amount of extra code cleanup.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #158319
Approved by: https://github.com/jingsh
ghstack dependencies: #158149, #158150, #158173, #158176, #158213, #158251
@github-actions github-actions bot deleted the gh/ezyang/3101/head branch August 16, 2025 02:18