Add aot_export_joint_with_descriptors and aot_compile_joint_with_descriptors by ezyang · Pull Request #158715 · pytorch/pytorch · GitHub

Conversation

@ezyang
Contributor

@ezyang ezyang commented Jul 20, 2025

[ghstack-poisoned]
@ezyang ezyang requested a review from bdhirsh as a code owner July 20, 2025 02:39
@pytorch-bot

pytorch-bot bot commented Jul 20, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158715

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 345d04e with merge base 85ee2fb:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 21, 2025
Add aot_export_joint_with_descriptors and aot_compile_joint_with_descriptors

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: 47dab91
Pull-Request: #158715
@ezyang ezyang requested review from jamesjwu and wconstab July 21, 2025 04:29
@albanD albanD removed their request for review July 21, 2025 20:15
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 21, 2025
Add aot_export_joint_with_descriptors and aot_compile_joint_with_descriptors

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: 6dea18a
Pull-Request: #158715
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 21, 2025
Add aot_export_joint_with_descriptors and aot_compile_joint_with_descriptors

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: b8777d0
Pull-Request: #158715
[ghstack-poisoned]
@ezyang ezyang mentioned this pull request Jul 22, 2025
Some descriptors can be quite exotic, so we recommend thinking carefully
about whether there is a safe fallback you can apply to descriptors you don't
understand. For example, you should have some way to handle not finding a particular
Contributor

this would be for cases like desugaring a tensor subclass input into the user fn, and potentially flattening it into one or more tensor and POD inputs into the final graph?

Contributor Author

Actually, tensor subclasses are extra super special, and probably will not work without more use-case understanding! See #159005
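The docstring excerpt under review recommends a safe fallback for descriptors you don't recognize. A minimal sketch of what that dispatch could look like, using hypothetical stand-ins for the descriptor classes this PR introduces (the real ParamAOTInput/BufferAOTInput/PlainAOTInput live in PyTorch; the classify_input helper is illustrative, not part of the API):

```python
from dataclasses import dataclass

# Hypothetical stand-ins for the descriptor classes discussed in this PR;
# the real definitions ship with aot_export_joint_with_descriptors.
@dataclass
class ParamAOTInput:
    target: str  # fully qualified parameter name

@dataclass
class BufferAOTInput:
    target: str  # fully qualified buffer name

@dataclass
class PlainAOTInput:
    idx: int  # position among the (flattened) user inputs

def classify_input(desc):
    """Dispatch on descriptor type, with a safe fallback for unknown kinds."""
    if isinstance(desc, ParamAOTInput):
        return ("param", desc.target)
    if isinstance(desc, BufferAOTInput):
        return ("buffer", desc.target)
    if isinstance(desc, PlainAOTInput):
        return ("user_input", desc.idx)
    # Safe fallback, per the docs above: don't guess about exotic
    # descriptors; treat them as opaque instead of specializing on them.
    return ("unknown", None)
```

The point of the fallback branch is that new descriptor kinds can appear over time, so a consumer that hard-errors on anything unfamiliar will be brittle.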

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 25, 2025
Add aot_export_joint_with_descriptors and aot_compile_joint_with_descriptors

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: dad5ffc
Pull-Request: #158715
@ezyang ezyang added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 25, 2025
Member

@fmassa fmassa left a comment

Approving to unblock; the overall structure LGTM.

It would be good to address @wconstab's comment about the missing comment before merging, though.

@ezyang ezyang added the topic: new features topic category label Jul 25, 2025
@ezyang
Contributor Author

ezyang commented Jul 25, 2025

Comment problem was addressed!

@ezyang
Contributor Author

ezyang commented Jul 25, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


Contributor

@wconstab wconstab left a comment

This is looking good. I had one question about the tree_spec issue on PlainAOTInput, but I'm not sure it's worth delaying the PR over.

args,
fw_compiler: AOTDispatchCompiler,
bw_compiler: AOTDispatchCompiler,
kwargs,
Contributor

seems like it would be better to name this 'flatten_kwargs' and put it next to the flatten bool?

Contributor Author

It's not a "flat" kwargs though, it's a dict. IMO, it should go with args because you pass the args along with the kwargs.

Contributor

Sorry, I totally misinterpreted this as being kwargs that somehow influence the tree_flatten; I didn't realize they were user kwargs. It makes sense that you only support user kwargs when you are flattening, and otherwise assert that the inputs are args-only.
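The convention settled on in this exchange (user kwargs are only accepted on the flattening path; the non-flattening path takes positional args only) can be sketched in pure Python. All names here are illustrative stand-ins, not the real API; the real code flattens with torch.utils._pytree:

```python
# Sketch of the args/kwargs convention discussed above. export_entry is a
# hypothetical name; the spec tuple stands in for a pytree TreeSpec.
def export_entry(fn, args, kwargs=None, *, flatten=False):
    kwargs = kwargs or {}
    if not flatten:
        # Non-flattening path: caller must pass positional args only.
        assert not kwargs, "kwargs are only supported with flatten=True"
        return fn(*args)
    # Flattening path: collapse (args, kwargs) into one flat list plus a
    # spec describing how to rebuild the original call.
    flat = list(args) + [kwargs[k] for k in sorted(kwargs)]
    spec = (len(args), sorted(kwargs))

    def flat_fn(*flat_inputs):
        n, keys = spec
        a = flat_inputs[:n]
        kw = dict(zip(keys, flat_inputs[n:]))
        return fn(*a, **kw)

    return flat_fn(*flat)
```

For example, export_entry(lambda x, y=0: x + y, (1,), {"y": 2}, flatten=True) routes the kwarg through the flat list, while the same call with flatten=False would trip the assertion.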

full_args = [*params_flat, *buffers_flat, *args]
in_spec, out_spec = None, None
if flatten:
    functional_call, out_spec = create_tree_flattened_fn(functional_call, full_args, kwargs)
Contributor

hmm. does the new functional_call expect the flattened full_args? if so, i wonder if 'create_tree_flattened_fn' should return full_args, in_spec also.

Contributor Author

Yes. I think that would not be unreasonable, but this is a preexisting function, and the convention seems to be to handle the flattening manually outside.

full_args_descs.extend(ParamAOTInput(fqn) for fqn in params_spec)
full_args_descs.extend(BufferAOTInput(fqn) for fqn in buffers_spec)
# TODO: it would be better to put pytree information in here
full_args_descs.extend(PlainAOTInput(i) for i in range(len(full_args) - len(full_args_descs)))
Contributor

Hmm, do I understand this?

1. len(full_args) - len(full_args_descs): full_args includes params, bufs, and flattened_args, while full_args_descs only includes params and bufs, so this is just len(flattened_args), which you couldn't compute directly because we flattened the whole full_args.
2. We make one PlainAOTInput for each flattened arg.

Help me understand what the point of returning these descriptors is for the flattened case. If I passed a weird input object to my fwd and you pytree-flattened it and returned N descriptors, I can't figure out which one corresponds to which of my inputs, right? So is this TODO load-bearing for the flatten case? Maybe it is fine to fix later; just wanted to be clear.

Contributor Author

Yes and Yes.

In an ideal world, we would also report pytree paths (the TODO) on the PlainAOTInput as well. This would make it easier to tell how the flattened arguments correspond to the original (non-flattened) arguments. But you can also figure this out manually by flattening your arguments yourself and seeing where they turn up.
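The "flatten your arguments yourself" workaround can be sketched as follows. This is a pure-Python stand-in for what torch.utils._pytree provides; flatten_with_paths is a hypothetical helper, and the idea is that the i-th leaf it yields corresponds to PlainAOTInput(i):

```python
# Stand-in for pytree flattening that also records the path to each leaf,
# so flat position i (i.e. PlainAOTInput(i)) can be traced back to the
# original nested input. Handles nested lists/tuples only, for illustration.
def flatten_with_paths(tree, path=()):
    if isinstance(tree, (list, tuple)):
        leaves = []
        for i, item in enumerate(tree):
            leaves.extend(flatten_with_paths(item, path + (i,)))
        return leaves
    return [(path, tree)]
```

Flattening [10, [20, 30]] this way yields leaves at paths (0,), (1, 0), and (1, 1), so the second PlainAOTInput descriptor would correspond to the first element of the nested inner list.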

not the intermediate export result.
TODO: talk carefully about how parameters/buffers work here
NB: If the passed nn.Module has parameters and buffers on it, we will
Contributor
The reason will be displayed to describe this comment to others. Learn more.

So, IIUC, this just means that once we start using the torch.export frontend with this, we'll have to fix this gap, but it will be straightforward to do so.

Contributor Author

torch.export puts the parameters/buffers on the module, actually, so it's fine.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud


Failing merge rule: Core Maintainers

[ghstack-poisoned]
@ezyang
Contributor Author

ezyang commented Jul 25, 2025

With some Claude Code assistance I added some unit tests too!

ezyang added a commit that referenced this pull request Jul 25, 2025
Add aot_export_joint_with_descriptors and aot_compile_joint_with_descriptors

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: 88525ec
Pull-Request: #158715
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jul 25, 2025
Add aot_export_joint_with_descriptors and aot_compile_joint_with_descriptors

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: f20a22b
Pull-Request: #158715
@ezyang
Contributor Author

ezyang commented Jul 25, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #159005

yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
Add aot_export_joint_with_descriptors and aot_compile_joint_with_descriptors (#158715)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #158715
Approved by: https://github.com/fmassa, https://github.com/wconstab, https://github.com/xmfan
ghstack dependencies: #158624, #158708, #158734
@github-actions github-actions bot deleted the gh/ezyang/3111/head branch August 25, 2025 02:18