[ONNX] Add complex constant support by titaiwangms · Pull Request #138279 · pytorch/pytorch · GitHub

Conversation

@titaiwangms
Collaborator

@titaiwangms titaiwangms commented Oct 17, 2024

Transform complex Python constants to a float representation as well, like what we already do with complex tensors.

PS: I don't think it's reasonable to add the "complex->float" conversion on the IR side, so I put it here.
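
A minimal sketch of the idea, assuming a hypothetical helper name (the exporter's actual internals may differ): a complex Python scalar is stored as a real-valued [real, imag] pair, mirroring how complex tensors are already lowered via torch.view_as_real.

import torch

# Hypothetical helper for illustration only: convert a complex Python scalar
# into the same float representation used for complex tensors.
def complex_constant_to_float(value: complex) -> torch.Tensor:
    # view_as_real maps a 0-d complex64 tensor to a shape-(2,) float tensor
    # holding [real, imag].
    return torch.view_as_real(torch.tensor(value, dtype=torch.complex64))

print(complex_constant_to_float(2 + 3j))  # tensor([2., 3.])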

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Oct 17, 2024
@pytorch-bot

pytorch-bot bot commented Oct 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138279

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 48bcd3d with merge base 8231180:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@drisspg drisspg added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Oct 18, 2024
@titaiwangms titaiwangms added module: onnx Related to torch.onnx topic: new features topic category labels Oct 21, 2024
@titaiwangms titaiwangms linked an issue Oct 21, 2024 that may be closed by this pull request
@titaiwangms
Collaborator Author

If I try to verify the fix with the following repro:

import torch

class MulModule(torch.nn.Module):
    def forward(self, x, y):
        return torch.ops.aten.mul(x, y)

# Example usage with complex inputs
x = torch.tensor([[1.0 + 2.0j, 3.0 + 4.0j],
                  [5.0 + 6.0j, 7.0 + 8.0j]], dtype=torch.complex64)

# Example 1: Non-tensor input (scalar)
y = 2 + 3j

onnx_program = torch.onnx.export(MulModule(), (x, y,), "slice.onnx", dynamo=True, report=True)

I get

[torch.onnx] Obtain model graph for `MulModule()` with `torch.export.export`...
[torch.onnx] Obtain model graph for `MulModule()` with `torch.export.export`... ❌
[torch.onnx] Obtain model graph for `MulModule()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `MulModule()` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `MulModule()` with Torch Script...
[torch.onnx] Obtain model graph for `MulModule()` with Torch Script... ❌
[torch.onnx] Obtain model graph for `MulModule()` with internal Dynamo apis...
[torch.onnx] Obtain model graph for `MulModule()` with internal Dynamo apis... ❌
Traceback (most recent call last):
  File "/home/titaiwang/pytorch/torch/onnx/_internal/exporter/_capture_strategies.py", line 110, in __call__
    exported_program = self._capture(model, args, kwargs, dynamic_shapes)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/onnx/_internal/exporter/_capture_strategies.py", line 145, in _capture
    return torch.export.export(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/export/__init__.py", line 368, in export
    return _export(
           ^^^^^^^^
  File "/home/titaiwang/pytorch/torch/export/_trace.py", line 1018, in wrapper
    raise e
  File "/home/titaiwang/pytorch/torch/export/_trace.py", line 991, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/export/exported_program.py", line 122, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/export/_trace.py", line 1974, in _export
    export_artifact = export_func(  # type: ignore[operator]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/export/_trace.py", line 1238, in _strict_export
    return _strict_export_lower_to_aten_ir(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/export/_trace.py", line 1347, in _strict_export_lower_to_aten_ir
    aten_export_artifact = lower_to_aten_callback(
                           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/export/_trace.py", line 647, in _export_to_aten_ir
    gm, graph_signature = transform(aot_export_module)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/_functorch/aot_autograd.py", line 1262, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
                                        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/_functorch/aot_autograd.py", line 1497, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/_functorch/aot_autograd.py", line 524, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/_functorch/aot_autograd.py", line 625, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/titaiwang/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 171, in inner
    assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args)
AssertionError

@justinchuby
Collaborator

I see. Just move it to inside the forward function?

@titaiwangms
Collaborator Author

I see. Just move it to inside the forward function?

Thanks! A test can be added!
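
A sketch of what the added test/repro might look like with the complex constant moved inside forward, so that torch.export only sees tensor inputs (the output filename is illustrative):

import torch

class MulModule(torch.nn.Module):
    def forward(self, x):
        # The complex Python constant now lives inside the graph instead of
        # being passed as a model input, which torch.export rejects.
        return torch.ops.aten.mul(x, 2 + 3j)

x = torch.tensor([[1.0 + 2.0j, 3.0 + 4.0j],
                  [5.0 + 6.0j, 7.0 + 8.0j]], dtype=torch.complex64)

onnx_program = torch.onnx.export(MulModule(), (x,), "mul_complex.onnx", dynamo=True, report=True)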

@titaiwangms
Collaborator Author

@justinchuby PTAL

titaiwangms and others added 2 commits October 22, 2024 09:52
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
@titaiwangms
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 22, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

SamGinzburg pushed a commit that referenced this pull request Oct 28, 2024
Transform complex Python constants to a float representation as well, like what we already do with complex tensors.

PS: I don't think it's reasonable to add the "complex->float" conversion on the IR side, so I put it here.
Pull Request resolved: #138279
Approved by: https://github.com/justinchuby

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>

Labels

ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · module: onnx (Related to torch.onnx) · open source · release notes: onnx (torch.onnx related changes that should show up in the release notes) · topic: new features (topic category) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[ONNX] Handle complex python constants

5 participants