[ONNX] Support symbolic arguments in onnx exporter by titaiwangms · Pull Request #157734 · pytorch/pytorch · GitHub

Conversation

@titaiwangms
Collaborator

@titaiwangms titaiwangms commented Jul 7, 2025

Prior to this PR, torch.onnx.export(..., dynamo=True, verify=True, report=True) does not support symbolic arguments. An example is the following:

class M(torch.nn.Module):
    def forward(self, a, x):
        return a + torch.tensor(1) + x

op = torch.onnx.export(M(), (1, torch.ones(2)), 
                       dynamic_shapes=(torch.export.Dim.DYNAMIC, {0: torch.export.Dim.DYNAMIC}), 
                       dynamo=True, report=True)

Symbolic arguments are like constant arguments in that they don't have tensor_meta either. Besides, torch.export.export supports model inputs that contain constants, which is different from the legacy issue #99534, where we tried to get the FX graph directly from dynamo export. Thus, _remove_non_tensor is deleted from args processing.

NOTE: If a ConstantArgument shows up in the exported_program, it is kept to align the length of inputs to the nn.Module, but it's irrelevant to the model graph, which is why the input is omitted in the ONNX model.

The test test_constant_argument_user_input_is_omitted_in_onnx_graph needs #157719
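The filtering described above can be sketched in plain Python. The classes and function below are hypothetical stand-ins for the torch.export.graph_signature input specs, used only to illustrate why a constant user input stays in the exported program's signature but never becomes an ONNX graph input:

```python
from dataclasses import dataclass

# Hypothetical stand-ins for torch.export.graph_signature argument specs;
# the real classes live in torch.export and carry more fields.
@dataclass
class TensorArgument:
    name: str

@dataclass
class ConstantArgument:
    name: str
    value: object

def onnx_graph_input_names(input_specs):
    # Constant/symbolic arguments carry no tensor_meta, so they never
    # become ONNX graph inputs; only tensor arguments survive.
    return [spec.name for spec in input_specs if isinstance(spec, TensorArgument)]

# 'a' is the scalar user input from the example above; 'x' is the tensor input.
specs = [ConstantArgument("a", 1), TensorArgument("x")]
print(onnx_graph_input_names(specs))  # ['x']
```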

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Jul 7, 2025
@pytorch-bot

pytorch-bot bot commented Jul 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157734

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ce23365 with merge base 2e14069:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@titaiwangms titaiwangms added the topic: bug fixes topic category label Jul 7, 2025
@titaiwangms titaiwangms changed the title [ONNX] Support symbolic arguments in onnx api [ONNX] Support symbolic arguments in onnx exporter Jul 7, 2025
@soulitzer soulitzer added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jul 7, 2025
@titaiwangms
Collaborator Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Tried to rebase and push PR #157734, but it was already up to date. Try rebasing against main by issuing:
@pytorchbot rebase -b main

if isinstance(spec.arg, graph_signature.ConstantArgument):
# If dynamic is set to a constant input, it becomes a
# symbolic argument, which is not a tensor.
if isinstance(
Collaborator

Should we include them?

Collaborator Author

They don't have tensor_meta though. Is there other info we might want from them?

Collaborator

Maybe just say they are symbolic?

Collaborator Author

As a string? Currently, io_spec supports only tensor_meta.

Collaborator Author

PTAL

Collaborator

I think this is only for printing iirc? If so it's fine as long as we are able to display them in the report.
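The fallback discussed in this thread can be sketched as follows. The function name and signature are hypothetical, not the actual io_spec API; the sketch only shows the idea of printing "symbolic" for arguments that have no tensor_meta:

```python
# Hedged sketch: a report line falls back to a "symbolic" tag when an
# argument has no tensor_meta. The real printing lives in the exporter's
# io_spec internals; names here are illustrative.
def format_input_spec(name, tensor_meta):
    """Return a human-readable report line for one model input."""
    if tensor_meta is None:
        return f"{name}: symbolic"
    return f"{name}: {tensor_meta}"

print(format_input_spec("a", None))      # a: symbolic
print(format_input_spec("x", "f32[2]"))  # x: f32[2]
```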

Collaborator

@justinchuby justinchuby left a comment

Thanks!

@titaiwangms
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 9, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@atalman
Contributor

atalman commented Jul 25, 2025

Hi @titaiwangms, it looks like this broke the amazon linux 2023 test: https://github.com/pytorch/test-infra/actions/runs/16525830777/job/46738788918 when numpy is not installed:

 python3 .ci/pytorch/smoke_test/smoke_test.py --package torchonly
/usr/local/lib64/python3.9/site-packages/torch/_subclasses/functional_tensor.py:279: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:82.)
  cpu = _conversion_method_template(device=torch.device("cpu"))
torch: 2.9.0.dev20250725+cu128
ATen/Parallel:
	at::get_num_threads() : 8
	at::get_num_interop_threads() : 16
OpenMP 201511 (a.k.a. OpenMP 4.5)
	omp_get_max_threads() : 8
Intel(R) oneAPI Math Kernel Library Version 2024.2-Product Build 20240605 for Intel(R) 64 architecture applications
	mkl_get_max_threads() : 8
Intel(R) MKL-DNN v3.7.1 (Git Hash 8d263e693366ef8db40acc569cc7d8edf644556d)
std::thread::hardware_concurrency() : 16
Environment variables:
	OMP_NUM_THREADS : [not set]
	MKL_NUM_THREADS : [not set]
ATen parallel backend: OpenMP

Traceback (most recent call last):
Skip version check for channel None as stable version is None
  File "/home/ec2-user/actions-runner/_work/test-infra/test-infra/test-infra/.github/scripts/run_with_env_secrets.py", line 102, in <module>
Testing smoke_test_conv2d
Testing smoke_test_linalg on cpu
Numpy check skipped. Numpy is not installed.
Testing smoke_test_compile for cuda and torch.float16
Traceback (most recent call last):
  File "/pytorch/pytorch/.ci/pytorch/smoke_test/smoke_test.py", line 509, in <module>
    main()
  File "/pytorch/pytorch/.ci/pytorch/smoke_test/smoke_test.py", line 498, in main
    smoke_test_cuda(
  File "/pytorch/pytorch/.ci/pytorch/smoke_test/smoke_test.py", line 256, in smoke_test_cuda
    smoke_test_compile("cuda" if torch.cuda.is_available() else "cpu")
  File "/pytorch/pytorch/.ci/pytorch/smoke_test/smoke_test.py", line 366, in smoke_test_compile
    x_pt2 = torch.compile(foo)(x)
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 804, in compile_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1595, in __call__
    result = self._torchdynamo_orig_backend(
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1353, in __call__
    result = self._inner_convert(
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 682, in __call__
    result = _compile(
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1233, in _compile
    raise InternalTorchDynamoError(
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1172, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib64/python3.9/site-packages/torch/_utils_internal.py", line 92, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 858, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 897, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1461, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 300, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 818, in transform
    tracer.run()
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 3528, in run
    super().run()
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1372, in run
    while self.step():
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1276, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2189, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/variables/torch.py", line 1436, in call_function
    special_handler = self._get_handlers().get(self.value)
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/variables/torch.py", line 476, in _get_handlers
    @register(*tracing_state_functions())
  File "/usr/local/lib64/python3.9/site-packages/torch/_dynamo/variables/torch.py", line 193, in tracing_state_functions
    torch.onnx.is_in_onnx_export: False,
  File "/usr/local/lib64/python3.9/site-packages/torch/__init__.py", line 2734, in __getattr__
    return importlib.import_module(f".{name}", __name__)
  File "/usr/lib64/python3.9/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1030, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
  File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 850, in exec_module
  File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
  File "/usr/local/lib64/python3.9/site-packages/torch/onnx/__init__.py", line 51, in <module>
    from ._internal.exporter._onnx_program import ONNXProgram
  File "/usr/local/lib64/python3.9/site-packages/torch/onnx/_internal/exporter/_onnx_program.py", line 18, in <module>
    import numpy as np
torch._dynamo.exc.InternalTorchDynamoError: ModuleNotFoundError: No module named 'numpy'
    main()
  File "/home/ec2-user/actions-runner/_work/test-infra/test-infra/test-infra/.github/scripts/run_with_env_secrets.py", line 98, in main
    run_cmd_or_die(f"docker exec -t {container_name} /exec")
  File "/home/ec2-user/actions-runner/_work/test-infra/test-infra/test-infra/.github/scripts/run_with_env_secrets.py", line 39, in run_cmd_or_die
    raise RuntimeError(f"Command {cmd} failed with exit code {exit_code}")
RuntimeError: Command docker exec -t 9a2177c2bb81d362fa38fb0be862506bc86801f552d91ff0d7380aa9ca4b11a7 /exec failed with exit code 1
Error: Process completed with exit code 1.

@atalman
Contributor

atalman commented Jul 25, 2025

@pytorchmergebot revert -c nosignal -m "Broke test on amazon linux 2023 when numpy not installed"

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

Reverting PR 157734 failed

Reason: Command git -C /home/runner/work/pytorch/pytorch revert --no-edit 08e9dd280f497fc820e35c458c843dc44f0282c6 returned non-zero exit code 1

Auto-merging test/onnx/exporter/test_api.py
Auto-merging torch/onnx/__init__.py
CONFLICT (content): Merge conflict in torch/onnx/__init__.py
error: could not revert 08e9dd280f4... [ONNX] Support symbolic arguments in onnx exporter (#157734)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git revert --continue".
hint: You can instead skip this commit with "git revert --skip".
hint: To abort and get back to the state before "git revert",
hint: run "git revert --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team: raised by workflow job



- def _to_ort_value(tensor: torch.Tensor) -> ort.OrtValue:
+ def _to_ort_value(input: torch.Tensor | int | float | str | bool) -> ort.OrtValue:
Collaborator

@titaiwangms we can hide the numpy import here

Collaborator Author

Thanks for the help.
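One way to hide the import is the deferred-import pattern, sketched below. The function name and fallback behavior are illustrative assumptions, not the exact code from the fix; the point is only that numpy is imported inside the function that needs it rather than at module import time:

```python
# Deferred-import pattern: numpy is imported inside the function that
# needs it, so `import torch.onnx` itself never requires numpy to be
# installed. Names below are illustrative, not the actual PR change.
def to_ort_compatible(value):
    """Convert a model input to something an ORT session can accept."""
    if isinstance(value, (bool, int, float, str)):
        return value  # scalar inputs pass through without touching numpy
    import numpy as np  # deferred: only reached for array-like inputs
    return np.asarray(value)
```

With this shape, the scalar path (and therefore the module import) works even on machines without numpy.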

pytorchmergebot pushed a commit that referenced this pull request Jul 27, 2025
One should not expect numpy to be there during onnx import
Forward fix for : #157734
Added regression test to `test_without_numpy` function

Test plan: Run `python -c "import sys;sys.path.insert(0, 'fake_numpy');import torch; import torch.onnx"` with/without this fix
Pull Request resolved: #159177
Approved by: https://github.com/atalman, https://github.com/justinchuby, https://github.com/titaiwangms, https://github.com/cyyever, https://github.com/Skylion007, https://github.com/andrewboldi
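The fake_numpy trick in the test plan above can be reconstructed roughly as follows. The directory layout is an assumption (the actual test fixture may differ), and the real regression test additionally imports torch and torch.onnx under this setup, which is omitted here:

```shell
# Shadow the real numpy with a package that fails on import, then verify
# the shadowing works: any eager numpy import now raises ImportError.
mkdir -p fake_numpy/numpy
echo 'raise ImportError("numpy must not be imported")' > fake_numpy/numpy/__init__.py
python3 -c '
import sys
sys.path.insert(0, "fake_numpy")
try:
    import numpy
    print("unexpected: numpy imported")
except ImportError:
    print("numpy import blocked")
'
```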
yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
