Track descriptors for all inputs/outputs of AOTAutograd traced graph by ezyang · Pull Request #158624 · pytorch/pytorch

Conversation

@ezyang
Contributor

@ezyang ezyang commented Jul 18, 2025

Stack from ghstack (oldest at bottom):

One of the recurring challenges of working with FX graphs produced by
AOTAutograd is that there is a very intricate input/output calling
convention that is essentially impossible to understand without actually
reverse engineering the AOTAutograd code. It is so bad that there
is a bit of logic for stashing indices of relevant arguments/outputs
in TracingContext so Inductor can figure out what the correct arguments
are.

This PR introduces the necessary scaffolding to keep track of
"descriptors" of every input/output to a (joint) FX graph produced
by AOTAutograd. First read through descriptors.py to get a sense for
what is available: for inputs, you can figure out if you have
a plain input, tangent, parameter, or something more exotic like
one of the fields of a subclass or view base. For outputs, you can
determine if you have a plain output or grad, or something more exotic
like the contents of a mutated input or an intermediate base of several
views that were returned.
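
To give a rough feel for the shape of these descriptors, here is a minimal illustrative sketch; the class names and fields below are assumptions based on the description above, and the real taxonomy in descriptors.py is considerably richer:

```python
# Illustrative sketch only -- the actual classes live in descriptors.py and
# cover many more cases (subclass fields, view bases, mutated inputs, ...).
from dataclasses import dataclass


@dataclass(frozen=True)
class AOTInput:
    """Describes where one input of the traced graph came from."""


@dataclass(frozen=True)
class PlainAOTInput(AOTInput):
    idx: int  # position in the user's flat input list


@dataclass(frozen=True)
class ParamAOTInput(AOTInput):
    target: str  # fully qualified parameter name, e.g. "linear.weight"


@dataclass(frozen=True)
class AOTOutput:
    """Describes what one output of the traced graph is."""


@dataclass(frozen=True)
class PlainAOTOutput(AOTOutput):
    idx: int  # position in the user's output list


@dataclass(frozen=True)
class GradAOTOutput(AOTOutput):
    grad_of: AOTInput  # the input this gradient corresponds to
```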

There are two distinct parts of this patch: AOTInput tracking, and
AOTOutput tracking.

**AOTInput tracking.** The way this works is that AOTAutograd starts off
with some Tensor `flat_args` that are the inputs to the graph being
traced, and then updates these arguments as it modifies the input
calling convention. Anywhere these `args` are passed around, we now add a
new argument `args_descs`, which is updated in synchrony with `args`. Add
a new arg? Add a new AOTInput to `args_descs`.
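
As a sketch of that invariant (not the actual AOTAutograd wrapper code), every step that changes the input calling convention appends to both lists in lockstep; this reuses the illustrative ParamAOTInput from the sketch above:

```python
# Sketch of the invariant only: args and args_descs always have the same
# length, and args_descs[i] describes args[i].
def lift_params_to_args(args, args_descs, named_params):
    # Hypothetical wrapper step: make module parameters explicit graph inputs.
    new_args, new_descs = list(args), list(args_descs)
    for name, param in named_params.items():
        new_args.append(param)                        # add a new arg...
        new_descs.append(ParamAOTInput(target=name))  # ...add a matching AOTInput
    assert len(new_args) == len(new_descs)
    return new_args, new_descs
```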

**AOTOutput tracking.** Originally, I wanted to also add an `outs_descs`
analogous to `args_descs` tracking output metadata. However, it is
often difficult to compute what the output will be until you're actually
tracing the function for real (and are able to peek at the real
outputs). So we only compute `outs_descs` when we actually trace. To do
this, we change the calling convention of the function we trace to
return not just outputs, but a tuple of `outs` and `outs_descs`. Before
we bottom out at the `make_fx` invocation, we save `outs_descs` to a
nonlocal and bottom out.
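
A sketch of that calling-convention change, with illustrative names (this is not the literal AOTAutograd code):

```python
from torch.fx.experimental.proxy_tensor import make_fx

# Sketch only: the function being traced now returns (outs, outs_descs);
# right before make_fx we peel the descriptors off into a nonlocal so the
# traced graph itself still just returns the plain outputs.
def trace_with_descriptors(fn_with_descs, flat_args):
    saved_outs_descs = None

    def bottom(*args):
        nonlocal saved_outs_descs
        outs, outs_descs = fn_with_descs(*args)
        saved_outs_descs = outs_descs  # captured while tracing
        return outs

    gm = make_fx(bottom)(*flat_args)
    return gm, saved_outs_descs
```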

To actually make use of this information, see the next PR. Potentially the two PRs could be combined, but I think it's clearer for them to be separate.

Signed-off-by: Edward Z. Yang ezyang@meta.com

[ghstack-poisoned]
@ezyang ezyang requested a review from bdhirsh as a code owner July 18, 2025 03:29
@pytorch-bot

pytorch-bot bot commented Jul 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158624

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6b21fde with merge base 85ee2fb:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request Jul 18, 2025
One of the recurring challenges of working with FX graphs produced by
AOTAutograd is that there is a very intricate input/output calling
convention that is essentially impossible to understand without actually
reverse engineering the AOTAutograd code.  It is so bad that there
is a bit of logic for stashing indices of relevant arguments/outputs
in TracingContext so Inductor can figure out what the correct arguments
are.

This PR introduces the necessary scaffolding to keep track of
"descriptors" of every input/output to a (joint) FX graph produced
by AOTAutograd.  First read through descriptors.py to get a sense for
what is available: for inputs, you can figure out if you have
a plain input, tangent, parameter, or something more exotic like
one of the fields of a subclass or view base.  For outputs, you can
determine if you have a plain output or grad, or something more exotic
like the contents of a mutated input or an intermediate base of several
views that were returned.

There are two distinct parts of this patch: AOTInput tracking, and
AOTOutput tracking.

**AOTInput tracking.**  The way this works is that AOTAutograd starts off
with some Tensor `flat_args` that are the inputs to the graph being
traced, and then updates these arguments as it modifies the input
calling convention.  Anywhere these `args` are passed around, we now add a
new argument `args_descs` which is updated in synchrony with `args`.  Add
a new arg?  Add a new AOTInput to `args_descs`.

**AOTOutput tracking.**  Originally, I wanted to also add an `outs_descs`
analogous to `args_descs` tracking output metadata.  However, it is
often difficult to compute what the output will be until you're actually
tracing the function for real (and are able to peek at the real
outputs).  So we only compute `outs_descs` when we actually trace.  To do
this, we change the calling convention of the function we trace to
return not just outputs, but a tuple of `outs` and `outs_descs`.  Before
we bottom out at the `make_fx` invocation, we save `outs_descs` to a
nonlocal and bottom out.

TODO: Obviously, we need to store these results somewhere and then test
them.  My general strategy will be to put it on the meta of the
input/output nodes and do some expect tests for them in the AOTAutograd
tests.  However, this current PR is a good litmus test for whether or
not I've nailed all the plumbing correctly, since I made the changes in
a maximally BC breaking way (new input arguments / new output return
types) which means that if I did anything wrong I should fail quickly.
We potentially should introduce some abstractions so that, for example,
we don't have to keep passing both `args` and `args_descs` around, but
this would have been a more involved refactor, better to do this
incrementally later.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: dc941a2
Pull-Request: #158624
@ezyang ezyang requested a review from jamesjwu July 18, 2025 03:29
@ezyang
Contributor Author

ezyang commented Jul 18, 2025

There are still some bugs to work out but many tests are passing

[ghstack-poisoned]
@ezyang ezyang requested a review from Chillee as a code owner July 18, 2025 14:26
ezyang added a commit that referenced this pull request Jul 18, 2025
ghstack-source-id: 2b7c5e2
Pull-Request: #158624
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2]"):
sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1)
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
Contributor Author

I am not sure why this happened lol

@ezyang
Contributor Author

ezyang commented Jul 18, 2025

Substantially fewer bugs now!

@albanD albanD removed their request for review July 18, 2025 14:27
Intuitively, suppose we have:
def wrapped_graph(*args):
Contributor

to make sure i follow

if i started with tracing module 'foo' that takes (tensor_x, tensor_y) as inputs and returns one tensor_z output

`graph()` would potentially take other inputs too, like lifted parameters (param_0, param_1, tensor_x, tensor_y), and it would also return other stuff like save-for-backward

in_transform(args) would take care to insert parameters and move user inputs accordingly

fin_0(args) would produce input_x?

and there would be no 'fin' that produced param_0 bc that does not correspond to a user input and we don't care about non-user-things?

i think this makes sense!

Contributor Author

Careful with the indexing! The index i refers to the FX graph arguments (of which there are four). fin_0(args) ignores args and produces param_0. fin_2(args) == args[0]. But yes, you've got the right idea!
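
To make the indexing concrete, here is a purely hypothetical illustration of the example being discussed (the fin_i names and the four-argument graph are taken from the thread above; the parameter and tensor values are made up):

```python
import torch

# Hypothetical illustration of the indexing discussed above: the FX graph
# takes (param_0, param_1, tensor_x, tensor_y), while the user only passed
# (tensor_x, tensor_y).  fin_i(user_args) produces the i-th graph input.
param_0 = torch.nn.Parameter(torch.zeros(3))
param_1 = torch.nn.Parameter(torch.zeros(3))

def fin_0(user_args):
    return param_0          # ignores user_args: a lifted parameter

def fin_1(user_args):
    return param_1          # ignores user_args: a lifted parameter

def fin_2(user_args):
    return user_args[0]     # tensor_x, the first user input

def fin_3(user_args):
    return user_args[1]     # tensor_y, the second user input

tensor_x, tensor_y = torch.randn(3), torch.randn(3)
assert fin_2((tensor_x, tensor_y)) is tensor_x
```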

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 18, 2025
ghstack-source-id: fae07ca
Pull-Request: #158624
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 20, 2025
ghstack-source-id: 603c0d1
Pull-Request: #158624
@ezyang ezyang added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 21, 2025
[ghstack-poisoned]
ezyang added 2 commits July 23, 2025 11:59
[ghstack-poisoned]
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 23, 2025
ghstack-source-id: d08f95b
Pull-Request: #158624
@ezyang
Contributor Author

ezyang commented Jul 23, 2025

Note for reviewers: there are no tests in this PR; they're all in #158708. I have them split for easier review, but I don't mind merging them if people prefer it.

out, out_descs = call_and_expect_output_descs(f, args)
return out

# TODO: save args_descs/out_descs to the produced FX graph
Contributor Author

This is done in the next PR!

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 23, 2025
ghstack-source-id: c7096d0
Pull-Request: #158624
correspond to the actual FX graph you get back, to say nothing about the extra
arguments/outputs for tangents, gradients, etc. Descriptors describe the meaning
of arguments.
Member

Do we plan to enforce these descriptors at compile time? It seems a bit too easy for them to diverge from the implementation: there are no assertions that they are correct and no tests to cover them. They seem like annotated types, but without a linter to enforce them.

Contributor Author

There are tests in the next PR. However, you are right that we don't cross-check these against the metadata that is actually used at runtime to do unpacking. But now we are in a bit of a pickle: the runtime code is the "source of truth", but the only time we can actually exercise it is at runtime, and we absolutely don't want to be adding a bunch of safety asserts at runtime! There might be some simple things we can replicate asserts for (essentially, a consistency test between the metadata and the descriptors), but to be honest, I would rather just beef up the tests in the next PR.
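
For what it's worth, a hypothetical example of the kind of lightweight consistency check being floated here (this is not something the PR adds; it only checks that descriptor counts line up with the FX graph's placeholder and output nodes):

```python
# Hypothetical consistency check, not part of this PR: descriptors should
# correspond one-to-one with the graph's placeholder and output nodes.
def check_descriptor_counts(gm, args_descs, outs_descs):
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    assert len(placeholders) == len(args_descs), (
        f"{len(placeholders)} placeholders vs {len(args_descs)} input descriptors"
    )
    (output_node,) = (n for n in gm.graph.nodes if n.op == "output")
    num_outputs = len(output_node.args[0])
    assert num_outputs == len(outs_descs), (
        f"{num_outputs} outputs vs {len(outs_descs)} output descriptors"
    )
```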

@linux-foundation-easycla

linux-foundation-easycla bot commented Jul 24, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

ezyang added a commit that referenced this pull request Jul 24, 2025
ghstack-source-id: fae07ca
Pull-Request: #158624
[ghstack-poisoned]
@ezyang ezyang force-pushed the gh/ezyang/3109/head branch from 45b3f46 to 6b21fde Compare July 24, 2025 22:06
@ezyang ezyang force-pushed the gh/ezyang/3109/base branch from 916b3a7 to d5c236f Compare July 24, 2025 22:06
@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #158734

@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #158708

pytorchmergebot pushed a commit that referenced this pull request Jul 25, 2025
…sts (#158708)

----

- First, we add a new expanded_def to FX, which will expand the
  definitions of variables into multiple lines, one per variable
  definition.  This makes extremely long args/return lists much
  more readable.

- Next, we extend this mechanism to also print out descriptors on
  placeholders and return values, as comments, if available.  This
  is how we will test descriptors.

- We update tlparse for AOTAutograd to use this format.

- We update expect tests to use this format and update their formats,
  so you can inspect what it can look like.  There may be other tests
  I should update, open to suggestions.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #158708
Approved by: https://github.com/wconstab
ghstack dependencies: #158624
pytorchmergebot pushed a commit that referenced this pull request Jul 25, 2025
Wrapping is load-bearing for things that introspect argument signatures,
but using `functools.wraps` to do this is undesirable, as it overrides
the name/module of the wrapping function, which is bad for tracking down
exactly what code is actually being run at runtime.  `simple_wraps` is
like `wraps` but it doesn't override the name information, so you still
get an appropriate printout.  To see the stack of all functions wrapping
each other, there is now a helper `fn_stack`.

I also make some assertions tighter in the descriptor PR.  These didn't
catch any bugs but I figure might as well.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #158734
Approved by: https://github.com/wconstab
ghstack dependencies: #158624, #158708
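
For intuition, a minimal sketch of what a `simple_wraps`-style helper could look like (an assumption for illustration, not the actual #158734 implementation): keep `__wrapped__` and the docstring so signature introspection still works, but don't copy `__name__`/`__qualname__`/`__module__` onto the wrapper.

```python
import functools

# Sketch only: like functools.wraps, but the wrapper keeps its own identity.
def simple_wraps(wrapped):
    return functools.wraps(
        wrapped,
        assigned=("__doc__",),   # deliberately omit __name__/__qualname__/__module__
        updated=("__dict__",),
    )

def traced(fn):
    @simple_wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

@traced
def add(x, y):
    return x + y

print(add.__name__)              # "wrapper": the wrapping frame stays visible
print(add.__wrapped__.__name__)  # "add": introspection can still find the original
```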
pytorchmergebot pushed a commit that referenced this pull request Jul 25, 2025
…riptors (#158715)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #158715
Approved by: https://github.com/fmassa, https://github.com/wconstab, https://github.com/xmfan
ghstack dependencies: #158624, #158708, #158734
yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
…158624)

Pull Request resolved: #158624
Approved by: https://github.com/xmfan
yangw-dev pushed a commit that referenced this pull request Aug 1, 2025 (#158708)
yangw-dev pushed a commit that referenced this pull request Aug 1, 2025 (#158734)
yangw-dev pushed a commit that referenced this pull request Aug 1, 2025 (#158715)
@github-actions github-actions bot deleted the gh/ezyang/3109/head branch August 25, 2025 02:18