[ONNX] Remove type promotion rule for pow by justinchuby · Pull Request #139527 · pytorch/pytorch · GitHub

Conversation

@justinchuby
Collaborator

@justinchuby justinchuby commented Nov 1, 2024

ONNX supports different input types in Pow, so type promotion is not needed.
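
As a quick illustration of that flexibility, here is a minimal sketch that builds a standalone ONNX model whose Pow node takes a float16 base and a float32 exponent and validates it with the onnx checker (the graph name, tensor names, and shapes are arbitrary and not taken from this PR):

```py
import onnx
from onnx import TensorProto, helper

# Pow (opset >= 12) allows the exponent dtype to differ from the base dtype,
# so a float16 base can be raised to a float32 exponent without a Cast.
node = helper.make_node("Pow", inputs=["x", "exponent"], outputs=["y"])
graph = helper.make_graph(
    [node],
    "pow_mixed_types",
    inputs=[
        helper.make_tensor_value_info("x", TensorProto.FLOAT16, [3]),
        helper.make_tensor_value_info("exponent", TensorProto.FLOAT, []),
    ],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT16, [3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
onnx.checker.check_model(model, full_check=True)  # passes: no type promotion required
```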

The resulting graph is the following:

```py
ONNXProgram(
    model=
        <
            ir_version=9,
            opset_imports={'': 18, 'pkg.onnxscript.torch_lib.common': 1},
            producer_name='pytorch',
            producer_version='2.6.0a0+git59a1af5',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"x"<FLOAT16,[3]>
            ),
            outputs=(
                %"pow_1"<FLOAT16,[3]>
            ),
        ) {
            0 |  # node_Constant_0
                 %"val_0"<?,?> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(2., dtype=float32), name=None)}
            1 |  # node_Pow_1
                 %"pow_1"<FLOAT16,[3]> ⬅️ ::Pow(%"x", %"val_0")
            return %"pow_1"<FLOAT16,[3]>
        }
...
    ,
    exported_program=
        ExportedProgram:
            class GraphModule(torch.nn.Module):
                def forward(self, x: "f16[3]"):
                     # File: /workspace/pytorch/test/onnx/exporter/test_small_models_e2e.py:53 in forward, code: return x**2.0
                    pow_1: "f16[3]" = torch.ops.aten.pow.Tensor_Scalar(x, 2.0);  x = None
                    return (pow_1,)
            
        Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='pow_1'), target=None)])
        Range constraints: {}

)
```
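
For reference, a minimal sketch that should reproduce an export along these lines (assuming PyTorch 2.6+ with the dynamo-based exporter; the module below is illustrative, not the test case from this PR):

```py
import torch

class PowModel(torch.nn.Module):
    def forward(self, x):
        return x**2.0

# float16 input; the scalar exponent becomes a float32 Constant, and Pow
# consumes both operands directly instead of promoting the input first.
onnx_program = torch.onnx.export(
    PowModel(), (torch.ones(3, dtype=torch.float16),), dynamo=True
)
print(onnx_program)
```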

@pytorch-bot

pytorch-bot bot commented Nov 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139527

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c4679a1 with merge base d338499:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: onnx label Nov 1, 2024
@justinchuby justinchuby added the module: onnx and topic: improvements labels and removed the release notes: onnx and fx labels Nov 1, 2024
@pytorch-bot pytorch-bot bot added the release notes: onnx label Nov 1, 2024
@justinchuby justinchuby requested review from titaiwangms and xadupre and removed request for shubhambhokare1, titaiwangms and wschin November 1, 2024 22:46
@justinchuby
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Nov 1, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@justinchuby justinchuby deleted the justinchu/type-promotion-pow branch November 2, 2024 02:38
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024

Pull Request resolved: pytorch#139527
Approved by: https://github.com/titaiwangms

Labels

ciflow/trunk, Merged, module: onnx, open source, release notes: onnx, topic: improvements
