Track descriptors for all inputs/outputs of AOTAutograd traced graph #158624
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158624
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure) As of commit 6b21fde with merge base 85ee2fb. UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
One of the recurring challenges of working with FX graphs produced by AOTAutograd is that there is a very intricate input/output calling convention that is essentially impossible to understand without actually reverse engineering the AOTAutograd code. It is so bad that there is a bit of logic for stashing indices of relevant arguments/outputs in TracingContext so Inductor can figure out what the correct arguments are.

This PR introduces the necessary scaffolding to keep track of "descriptors" of every input/output to a (joint) FX graph produced by AOTAutograd. First read through descriptors.py to get a sense for what is available: for inputs, you can figure out if you have a plain input, tangent, parameter, or something more exotic like one of the fields of a subclass or view base. For outputs, you can determine if you have a plain output or grad, or something more exotic like the contents of a mutated input or an intermediate base of several views that were returned.

There are two distinct parts of this patch: AOTInput tracking and AOTOutput tracking.

**AOTInput tracking.** The way this works is that AOTAutograd starts off with some Tensor `flat_args` that are the inputs to the graph being traced, and then updates these arguments as it modifies the input calling convention. Anywhere these `args` are passed around, we now add a new argument `args_descs`, which is updated in sync with `args`. Add a new arg? Add a new AOTInput to `args_descs`.

**AOTOutput tracking.** Originally, I wanted to also add an `outs_descs` analogous to `args_descs` tracking output metadata. However, it is often difficult to compute what the output will be until you're actually tracing the function for real (and are able to peek at the real outputs). So we only compute `outs_descs` when we actually trace. To do this, we change the calling convention of the function we trace to return not just outputs, but a tuple of `outs` and `outs_descs`. Before we bottom out at the `make_fx` invocation, we save `outs_descs` to a nonlocal and bottom out.

TODO: Obviously, we need to store these results somewhere and then test them. My general strategy will be to put it on the meta of the input/output nodes and do some expect tests for them in the AOTAutograd tests. However, this current PR is a good litmus test for whether or not I've nailed all the plumbing correctly, since I made the changes in a maximally BC-breaking way (new input arguments / new output return types), which means that if I did anything wrong I should fail quickly.

We potentially should introduce some abstractions so that, for example, we don't have to keep passing both `args` and `args_descs` around, but this would have been a more involved refactor; better to do this incrementally later.

Signed-off-by: Edward Z. Yang <ezyang@meta.com> ghstack-source-id: dc941a2 Pull-Request: #158624
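As a purely illustrative aside, here is a minimal sketch of the lockstep invariant described under **AOTInput tracking**; this is not the actual AOTAutograd code, and the descriptor classes shown (`PlainAOTInput`, `ParamAOTInput`) are stand-ins for whatever the real classes in descriptors.py provide:

```python
from dataclasses import dataclass

@dataclass
class PlainAOTInput:
    idx: int  # position in the user's original flat_args

@dataclass
class ParamAOTInput:
    target: str  # fully qualified parameter name

def lift_parameters(params, flat_args, args_descs):
    # Whenever we add new graph inputs (here: lifted parameters, prepended to
    # the front), we add matching descriptors in the same positions, so that
    # args[i] and args_descs[i] always describe the same graph input.
    new_args = list(params.values()) + list(flat_args)
    new_descs = [ParamAOTInput(name) for name in params] + list(args_descs)
    assert len(new_args) == len(new_descs)
    return new_args, new_descs

# Usage: two user inputs plus one lifted parameter -> three graph inputs,
# three descriptors, kept in sync.
args, descs = lift_parameters(
    {"model.weight": 1.0},
    flat_args=[2.0, 3.0],
    args_descs=[PlainAOTInput(0), PlainAOTInput(1)],
)
```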
There are still some bugs to work out, but many tests are passing.
```python
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2]"):
    sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1)
    gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
```
I am not sure why this happened lol
Substantially fewer bugs now!
Intuitively, suppose we have:

```python
def wrapped_graph(*args):
```
to make sure i follow
if i started with tracing module 'foo' that takes (tensor_x, tensor_y) as inputs and returns one tensor_z output
`graph()` would potentially take other inputs too, like lifted parameters (param_0, param_1, tensor_x, tensor_y), and it would also return other stuff like save-for-backward
in_transform(args) would take care to insert parameters and move user inputs accordingly
fin_0(args) would produce input_x?
and there would be no 'fin' that produced param_0 bc that does not correspond to a user input and we don't care about non-user-things?
i think this makes sense!
Careful with the indexing! The index `i` refers to the FX graph arguments (of which there are four). `fin_0(args)` ignores `args` and produces `param_0`. `fin_2(args) == args[0]`. But yes, you've got the right idea!
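To make the indexing concrete, here is a hypothetical sketch of the example under discussion; `wrapped_graph`, `in_transform`, and the `fin_i` functions are names from this comment thread, not actual AOTAutograd API:

```python
# Hypothetical sketch of the four-argument example above; none of these names
# are real AOTAutograd functions.
params = {"param_0": 0.0, "param_1": 1.0}

def in_transform(user_args):
    # user inputs:     (tensor_x, tensor_y)
    # FX graph inputs: (param_0, param_1, tensor_x, tensor_y)
    return (params["param_0"], params["param_1"], *user_args)

# One "source" function per FX graph argument, indexed by graph-argument position.
fin_0 = lambda user_args: params["param_0"]  # ignores user_args entirely
fin_1 = lambda user_args: params["param_1"]
fin_2 = lambda user_args: user_args[0]       # tensor_x
fin_3 = lambda user_args: user_args[1]       # tensor_y
```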
Note for reviewers: there are no tests in this PR; they're all in #158708. I have them split for easier review, but I don't mind merging them if people prefer it.
```python
out, out_descs = call_and_expect_output_descs(f, args)
return out
```

```python
# TODO: save args_descs/out_descs to the produced FX graph
```
This is done in the next PR!
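For context, here is a minimal sketch (with hypothetical helper names, not the actual AOTAutograd code) of the `outs`/`outs_descs` plumbing described in the PR body: the traced function returns both, and the descriptors are stashed in a nonlocal right before tracing so the graph itself only returns `outs`:

```python
def trace_with_output_descs(fn_with_descs, args):
    saved_outs_descs = None

    def traced_fn(*a):
        nonlocal saved_outs_descs
        outs, outs_descs = fn_with_descs(*a)  # new convention: return both
        saved_outs_descs = outs_descs         # stash descriptors on the side
        return outs                           # the traced graph returns only outs

    # In the real code, this is roughly where make_fx(traced_fn)(*args) would
    # run; here we just call traced_fn directly to show the plumbing.
    outs = traced_fn(*args)
    return outs, saved_outs_descs

# Usage with a toy function that returns its outputs plus descriptor labels.
outs, outs_descs = trace_with_output_descs(
    lambda x: ((x + 1,), ["PlainAOTOutput(idx=0)"]),
    (41,),
)
```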
correspond to the actual FX graph you get back, to say nothing about the extra
arguments/outputs for tangents, gradients, etc. Descriptors describe the meaning
of arguments.
Do we plan to enforce these descriptors at compile time? It seems a bit too easy for them to diverge from the implementation: there are no assertions that they are correct and no tests to cover them. They seem like annotated types, but without a linter to enforce them.
There are tests in the next PR. However, you are right that we don't cross-check these against the metadata that is actually used at runtime to do unpacking. But now we are in a little bit of a pickle: the runtime code is the "source of truth", but the only time we can actually exercise it is at runtime, and we absolutely don't want to be adding a bunch of safety asserts at runtime! There might be some simple things we can replicate asserts for (essentially, do a consistency test between the metadata and the descriptors), but to be honest, I would rather just beef up the tests in the next PR.
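Purely as a hypothetical illustration of the kind of "consistency test between the metadata and the descriptors" mentioned here (this is not code from the PR; all names are made up for the sketch):

```python
# Hypothetical debug-only check: compare output positions that the descriptors
# claim are mutated-input outputs against positions recorded by the runtime
# metadata. All class and argument names are invented for this sketch.
from dataclasses import dataclass

@dataclass
class InputMutationAOTOutput:
    mutated_input_idx: int

@dataclass
class PlainAOTOutput:
    idx: int

def check_descs_against_metadata(out_descs, mutated_inp_indices):
    mutated_from_descs = {
        d.mutated_input_idx
        for d in out_descs
        if isinstance(d, InputMutationAOTOutput)
    }
    assert mutated_from_descs == set(mutated_inp_indices), (
        "descriptors disagree with the runtime mutation metadata"
    )

# Example: outputs are (mutated input 0, plain output 0); metadata agrees.
check_descs_against_metadata(
    [InputMutationAOTOutput(0), PlainAOTOutput(0)],
    mutated_inp_indices=[0],
)
```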
Starting merge as part of PR stack under #158734
Starting merge as part of PR stack under #158708
Track descriptors for all inputs/outputs of AOTAutograd traced graph (#158624)

To actually make use of this information in a useful way, see the next PR. Potentially the two PRs could be combined together, but I think it's actually clearer for them to be separate.

Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #158624 Approved by: https://github.com/xmfan
…sts (#158708)

- First, we add a new expanded_def to FX, which will expand the definitions of variables into multiple lines, one per variable definition. This makes extremely long args/return lists much more readable.
- Next, we extend this mechanism to also print out descriptors on placeholders and return values, as comments, if available. This is how we will test descriptors.
- We update tlparse for AOTAutograd to use this format.
- We update expect tests to use this format and update their formats, so you can inspect what it can look like. There may be other tests I should update; open to suggestions.

Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #158708 Approved by: https://github.com/wconstab ghstack dependencies: #158624
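As a rough, hypothetical illustration of what an expanded definition with descriptor comments might look like (the exact rendering and descriptor names come from #158708 and descriptors.py, not from this excerpt):

```python
def forward(
    self,
    primals_1: "f32[2, 2]",   # e.g. a parameter or plain-input descriptor
    primals_2: "f32[2]",      # e.g. PlainAOTInput(idx=1) -- illustrative only
    tangents_1: "f32[2, 2]",  # e.g. a tangent descriptor for output 0
):
    ...
```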
Wrapping is load bearing for things that introspect argument signatures, but use of `functools.wraps` to do this is undesirable, as it overrides the name/module of the wrapping function, which is bad for tracking down exactly what code is actually being run at runtime. `simple_wraps` is like `wraps` but it doesn't override the name information, so you still get an appropriate printout. To see the stack of all functions wrapping each other, there is now a helper `fn_stack`. I also make some assertions tighter in the descriptor PR. These didn't catch any bugs, but I figure might as well. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #158734 Approved by: https://github.com/wconstab ghstack dependencies: #158624, #158708
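A minimal sketch, assuming a hand-rolled implementation rather than the actual torch helper, of what a `simple_wraps`/`fn_stack` pair like this could look like:

```python
# Sketch of a functools.wraps variant: keep the wrapper's own __name__/__module__
# so runtime printouts show which wrapper is executing, but still record
# __wrapped__ so signature introspection (e.g. inspect.signature) can follow
# the chain.
def simple_wraps(wrapped):
    def decorator(wrapper):
        wrapper.__wrapped__ = wrapped  # preserves signature introspection
        if wrapper.__doc__ is None:
            wrapper.__doc__ = wrapped.__doc__
        # Deliberately do NOT copy __name__ or __module__, unlike functools.wraps.
        return wrapper
    return decorator

def fn_stack(fn):
    # Walk the __wrapped__ chain to list every wrapper around the original fn.
    stack = [fn]
    while hasattr(stack[-1], "__wrapped__"):
        stack.append(stack[-1].__wrapped__)
    return stack

@simple_wraps(print)
def noisy_print(*args, **kwargs):
    return print("[wrapped]", *args, **kwargs)

assert noisy_print.__name__ == "noisy_print"  # name is not overridden
assert fn_stack(noisy_print)[-1] is print     # chain bottoms out at print
```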
…riptors (#158715) Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #158715 Approved by: https://github.com/fmassa, https://github.com/wconstab, https://github.com/xmfan ghstack dependencies: #158624, #158708, #158734