-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Add aot_export_joint_with_descriptors and aot_compile_joint_with_descriptors #158715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158715
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 345d04e with merge base 85ee2fb ( UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
| Some descriptors can be quite exotic, so we recommend thinking carefully | ||
| if there is a safe fallback you can apply to descriptors you don't understand. | ||
| For example, you should have some way to handle not finding a particular |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this would be for cases like desugaring a tensor subclass input into the user fn, and potentially flattening it into one or more tensor and POD inputs into the final graph?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, tensor subclasses are extra super special, and probably will not work without more user case understanding! See #159005
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving to unblock, but the overall structure LGTM.
It would be good to address @wconstab comment on the missing comment though before merging
|
Comment problem was addressed! |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is looking good. I had one question about the tree_spec issue on PlainAOTInput, but i'm not sure if its worth delaying the PR on.
| args, | ||
| fw_compiler: AOTDispatchCompiler, | ||
| bw_compiler: AOTDispatchCompiler, | ||
| kwargs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seems like it would be better to name this 'flatten_kwargs' and put it next to the flatten bool?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not a "flat" kwargs though, it's a dict. IMO, it should go with args because you pass the args along with the kwargs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry, i totally misinterpreted this as being kwargs that somehow influence the tree_flatten. I didn't realize they were user kwargs. It makes sense that you would only support user kwargs if you are flattening and otherwise you are asserting they are args-only.
torch/_functorch/aot_autograd.py
Outdated
| full_args = [*params_flat, *buffers_flat, *args] | ||
| in_spec, out_spec = None, None | ||
| if flatten: | ||
| functional_call, out_spec = create_tree_flattened_fn(functional_call, full_args, kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm. does the new functional_call expect the flattened full_args? if so, i wonder if 'create_tree_flattened_fn' should return full_args, in_spec also.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I think that would not be unreasonable, but this is a preexisting function and the convention seems to be to just handle the flattening manually outside.
torch/_functorch/aot_autograd.py
Outdated
| full_args_descs.extend(ParamAOTInput(fqn) for fqn in params_spec) | ||
| full_args_descs.extend(BufferAOTInput(fqn) for fqn in buffers_spec) | ||
| # TODO: it would be better to put pytree information in here | ||
| full_args_descs.extend(PlainAOTInput(i) for i in range(len(full_args) - len(full_args_descs))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm, do i understand this..
1.len(full_args) - len(full_args_descs)
- full_args includes params,bufs,flattened_args. full_args_descs only includes params, bufs. so this is just 'len(flattened_args)' which you couldn't directly compute bc we flattened the whole 'full_args'
- We make one PlainAOTInput for each 'flattened arg'.
Help me understand what the point of returning these descriptors are for the flattened case. If I passed a weird input object to my fwd and you pytree flattened it and returned N descriptors, I can't figure out which one corresponds to which of my inputs right? So is this TODO load bearing for the flatten case? Maybe it is fine to fix later just wanted to be clear.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes and Yes.
In an ideal world, we would also report pytree paths (the TODO) on the PlainInput as well. This would make it easier to tell how the flattened arguments corresponded to the original (non-flattened) arguments. But you can also figure this out manually by flattening your arguments yourself and seeing where they turn up.
| not the intermediate export result. | ||
| TODO: talk carefully about how parameters/buffers work here | ||
| NB: If the passed nn.Module has parameters and buffers on it, we will |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so, iiuc this just means that once we start using torch.export frontend with this, we'll have to fix this gap, but it will be straightforward to do so.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.export puts the parameters/buffers on the module, actually, so it's fine.
Merge failedReason: 1 mandatory check(s) failed. The first few are: Dig deeper by viewing the failures on hud |
|
With some claude code assistance I added some UTs too! |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
Starting merge as part of PR stack under #159005 |
…riptors (#158715) Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #158715 Approved by: https://github.com/fmassa, https://github.com/wconstab, https://github.com/xmfan ghstack dependencies: #158624, #158708, #158734
Stack from ghstack (oldest at bottom):
Signed-off-by: Edward Z. Yang ezyang@meta.com