Add support for tracing vmap in pre-dispatch export by tugsbayasgalan · Pull Request #154650 · pytorch/pytorch · GitHub

@tugsbayasgalan
Contributor

@tugsbayasgalan tugsbayasgalan commented May 29, 2025

Summary: The ONNX team and a recent transformer upgrade ran into this error, and we also hit it during our export benchmarking. This diff makes it possible to trace through the vmap implementation in pre-dispatch IR. Note that we don't support serializing functorch ops in pre-dispatch IR; in the future, we should desugar them to post-grad ops.

The implementation strategy is:

  1. We add Python wrappers around the vmap APIs so that we can attach a custom torch function handler that is active only during non-strict export. We don't add this to the default torch_function handler because that would break BC.
  2. Some Dynamo changes make sure it picks up the new Python wrapper APIs. This is needed because in strict export we have to re-materialize these APIs in pre-dispatch IR from torch IR. We could avoid this by special-casing Dynamo for export so that it proxies different API calls, but I feel that is too much chaos because it would have to proxy two different variants of the same vmap API.

Test Plan: CI

Differential Revision: D75623875

cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela

@pytorch-bot

pytorch-bot bot commented May 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154650

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8459179 with merge base 5ee464d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D75623875

Contributor

@zou3519 zou3519 left a comment

Our conclusion from the meeting on Tuesday was:

  1. Yes, we're going to put all the API calls into the graph
  2. We should only interpose on these API calls when non-strict export is on. These shouldn't go through regular torch_function, because they are private APIs.
  3. It might be easier to do this by creating a Python function wrapper around, e.g., add_batch_dim, and then putting the torch_function handler and export checks into that function. This may require adding these new functions to the Dynamo skiplists so as not to break the Dynamo side of things.

pytorch-bot bot pushed a commit that referenced this pull request Jul 23, 2025
Summary:

The ONNX team ran into this error, and we also hit it during our export benchmarking. This diff makes it possible to trace through the vmap implementation in pre-dispatch IR. Note that we don't support serializing functorch ops in pre-dispatch IR; in the future, we should desugar them to post-grad ops.

Test Plan: CI

Differential Revision: D75623875

tugsbayasgalan added a commit to tugsbayasgalan/pytorch that referenced this pull request Jul 23, 2025
@tugsbayasgalan tugsbayasgalan requested a review from zou3519 August 19, 2025 19:04
tugsbayasgalan added a commit to tugsbayasgalan/pytorch that referenced this pull request Aug 20, 2025

Contributor

@zou3519 zou3519 left a comment

Please add some docs explaining what is going on in proxy_tensor.py.

Summary:

The ONNX team ran into this error, and we also hit it during our export benchmarking. This diff makes it possible to trace through the vmap implementation in pre-dispatch IR. Note that we don't support serializing functorch ops in pre-dispatch IR; in the future, we should desugar them to post-grad ops.

Test Plan: CI

Reviewed By: zou3519

Differential Revision: D75623875

@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Pull Request resolved: pytorch#154650
Approved by: https://github.com/ezyang, https://github.com/zou3519