[ONNX] Support from dynamic_shapes to dynamic_axes when torch.onnx.export(fallback=True) is triggered by titaiwangms · Pull Request #139532 · pytorch/pytorch · GitHub

Conversation

@titaiwangms
Collaborator

@titaiwangms titaiwangms commented Nov 1, 2024

Fixes #139320

Summary:

(1) Add _rename_dynamic_shapes_with_model_inputs so that dynamic_shapes plays along with input_names

  • Use the model forward signature to rename dynamic_shapes when dynamic_shapes is not nested and directly uses the customized names. This solves the issue that torch.export.export expects dynamic_shapes to use only the model input names (see the sketch below).
  • If dynamic_shapes is nested, we do nothing.
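
A minimal sketch of the renaming idea, assuming a standalone helper (the name rename_with_model_inputs and the positional zip against the forward signature are illustrative, not the PR's exact code):

import inspect
from typing import Any

import torch

def rename_with_model_inputs(
    model: torch.nn.Module,
    dynamic_shapes: dict[str, Any],
    input_names: list[str] | None,
) -> dict[str, Any]:
    """Rename customized dynamic_shapes keys back to model.forward parameter names."""
    if input_names is None:
        return dynamic_shapes
    params = list(inspect.signature(model.forward).parameters)
    # input_names are treated as positional aliases for forward()'s parameters.
    alias_to_param = dict(zip(input_names, params))
    return {alias_to_param.get(k, k): v for k, v in dynamic_shapes.items()}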

(2) Add _from_dynamic_shapes_to_dynamic_axes for the fallback

  • We flatten dynamic_shapes with _pytree.tree_leaves(), supplying a custom leaf definition (see the sketch after this list).
    * If dynamic_shapes is not nested and is defined as a dict, we can use its keys as the input_names, since they should already have been renamed by _rename_dynamic_shapes_with_model_inputs.
  • If dynamic_shapes is provided, input_names is required to assign the names, because dynamic_axes needs them.
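
Putting those bullets together, a minimal sketch of the conversion under the stated assumptions (the helper name and the leaf test are illustrative, not the code that landed):

from typing import Any

import torch.utils._pytree as pytree

def to_dynamic_axes(
    dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any],
    input_names: list[str],
) -> dict[str, dict[int, str]]:
    """Flatten dynamic_shapes and pair each leaf with an input name."""
    # A leaf is a per-input axes spec such as {0: "batch", 1: "seq"} (or None).
    is_axes = lambda x: isinstance(x, dict) and all(isinstance(k, int) for k in x)
    leaves = pytree.tree_leaves(dynamic_shapes, is_leaf=is_axes)
    if len(leaves) != len(input_names):
        raise ValueError(
            "input_names must name every flattened input so dynamic_axes can be built"
        )
    dynamic_axes: dict[str, dict[int, str]] = {}
    for name, leaf in zip(input_names, leaves):
        if leaf is None:
            continue  # static input: no dynamic axes to record
        dynamic_axes[name] = {axis: str(dim) for axis, dim in leaf.items()}
    return dynamic_axes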

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Nov 1, 2024
@pytorch-bot

pytorch-bot bot commented Nov 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139532

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit bdfe02d with merge base e429a3b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@titaiwangms titaiwangms added the topic: bug fixes topic category label Nov 1, 2024
@titaiwangms titaiwangms requested a review from xadupre November 1, 2024 23:13
@justinchuby justinchuby self-assigned this Nov 1, 2024
# It does not specify input names if it's a tuple
return dynamic_shapes

sig = _signature(model)
Collaborator

Will this call ever raise?

Collaborator Author

Yes, when dynamic_shapes is a dict.
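
For context, a guarded lookup (a minimal sketch using inspect, not the PR's _signature itself) shows how the call can raise and be handled:

import inspect

import torch

def _safe_signature(model):
    # inspect.signature raises ValueError when no signature can be retrieved
    # (e.g. some C-implemented callables) and TypeError for unsupported objects.
    target = model.forward if isinstance(model, torch.nn.Module) else model
    try:
        return inspect.signature(target)
    except (ValueError, TypeError):
        return None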

Collaborator

Technically, dynamic_shapes and args, kwargs have exactly the same structure (every tensor is replaced by a dictionary of dynamic dimensions), but only a subset of it is supported by torchscript. So I would expect the function to fail if the inputs are not flat, unless the model is wrapped into a module doing the flattening. But I don't think we should do that for the users before calling the fallback. However, an error message suggesting to do so would be great.
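
A minimal sketch of the wrapping idea mentioned above (a hypothetical FlattenInputs wrapper that users could apply themselves before export):

import torch
import torch.utils._pytree as pytree

class FlattenInputs(torch.nn.Module):
    """Accept flat tensors and rebuild the nested structure the real model expects."""

    def __init__(self, model: torch.nn.Module, in_spec: pytree.TreeSpec):
        super().__init__()
        self.model = model
        self.in_spec = in_spec  # captured from the original nested args

    def forward(self, *flat_args):
        args = pytree.tree_unflatten(list(flat_args), self.in_spec)
        return self.model(*args)

# Usage: flatten once, remember the spec, export the wrapper with flat inputs.
# flat_args, in_spec = pytree.tree_flatten(nested_args)
# torch.onnx.export(FlattenInputs(model, in_spec), tuple(flat_args), "model.onnx")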

Collaborator Author

Failing makes sense, and it's easier to implement. I have tried the following example:

import torch
import onnx

class Foo(torch.nn.Module):
    def forward(self, x):
        (a0, a1), (b0, b1), (c0, c1, c2) = x
        return a0 + a1 + b0 + b1 + c0 + c1 + c2

f = Foo()
inputs = (
    (1, 2),
    (
        torch.randn(4, 4),
        torch.randn(4, 4),
    ),
    (
        torch.randn(4, 4),
        torch.randn(4, 4),
        torch.randn(4, 4),
    ),
)

input_names = ["a", "b", "c", "d", "e", "f", "g"]
dynamic_axes = {
    "c": {0: "c_dim_0", 1: "c_dim_1"},
    "e": {0: "e_dim_0", 1: "e_dim_1"},
    "f": {0: "f_dim_0", 1: "f_dim_1"},
}

torch.onnx.export(f, (inputs,), "nested.onnx", dynamic_axes=dynamic_axes, input_names=input_names, verbose=True)
onnx_model = onnx.load("nested.onnx")
print(onnx_model.graph.input)

I think this is the only way that nested dynamic_axes works, but it's also kind of awkward that the model forward inputs are not the same as input_names. Users need to flatten them ahead of time in dynamic_axes.

Collaborator Author

The function has been changed to only work when the input is not nested (len(ONNX inputs) == len(torch inputs)). When it's nested, users are expected to provide dynamic_shapes with the correct names (matching model.forward), or as a tuple (honestly, I think users would use a tuple when the inputs are nested).

Collaborator

Could you add this to the function docstring? Just something that makes it clear when this function works and when it doesn't.
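
For illustration, a docstring along these lines (hypothetical wording and signature, not the text that actually landed) would capture the constraint:

def _rename_dynamic_shapes_with_model_inputs(dynamic_shapes, input_names, model):
    """Rename dynamic_shapes keys to the parameter names of model.forward.

    Only applies when the inputs are not nested, i.e. the entries of
    dynamic_shapes map one-to-one onto the top-level forward parameters.
    For nested inputs, callers must key dynamic_shapes by the
    model.forward parameter names themselves, or pass it as a tuple.
    """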

Collaborator Author

Done

@zou3519 zou3519 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Nov 5, 2024
@titaiwangms
Collaborator Author

titaiwangms commented Nov 6, 2024

Overall, I think there are a few questions that we need to clarify before adding flattened dynamic_shapes support.

(1) When args is nested, what is our expectation for input_names? For example:

def forward(self, input: tuple[torch.Tensor, torch.Tensor]): ...  # there are two elements in the tuple

# one to one mapping to forward
input_names = ["input"]
# or flattened to map ONNX nodes
input_names = ["input_a", "input_b"]

(2) dynamic_axes requires users to flatten the inputs ahead of time, as in the nested-input example from my earlier comment above.

Do we want to do this for users? I guess users could hint at it by specifying input_names that match the flattened inputs.

(3) I suggest that we only support nested tuple/list inputs and raise an error when we see a dict in the inputs, because (1) the torchscript exporter does not support dictionaries, if I recall correctly, and (2) that would make the conversion even more complicated, given this is just a fallback mechanism.
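
A minimal sketch of the validation suggested in (3) (a hypothetical helper; note the reply below that torchscript may actually support dicts):

import torch.utils._pytree as pytree

def _reject_dict_inputs(args):
    """Raise if any (possibly nested) input is a dict."""
    leaves = pytree.tree_leaves(args, is_leaf=lambda x: isinstance(x, dict))
    for leaf in leaves:
        if isinstance(leaf, dict):
            raise TypeError(
                "dict inputs are not supported by the dynamic_shapes fallback; "
                "use nested tuples/lists instead"
            )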

@justinchuby
Collaborator

torchscript does actually support dict inputs, but I think it just flattens the dictionary (may recall incorrectly)

@titaiwangms
Collaborator Author

torchscript does actually support dict inputs, but I think it just flattens the dictionary (may recall incorrectly)

Oh, I was thinking ONNX does not support dicts. I will make a change.

Collaborator

@justinchuby justinchuby left a comment

Is it true that when a user renames an input, they can provide either the old name (from the forward function) or the new name in dynamic_shapes? What if they conflict?
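
One way to surface such a conflict (a hypothetical check, not something this PR adds) is to canonicalize the keys first and fail on collisions:

from typing import Any

def canonicalize_keys(
    dynamic_shapes: dict[str, Any], alias_to_param: dict[str, str]
) -> dict[str, Any]:
    """Map every key to the forward-parameter name; raise on collisions."""
    out: dict[str, Any] = {}
    for key, axes in dynamic_shapes.items():
        param = alias_to_param.get(key, key)
        if param in out:
            raise ValueError(
                f"dynamic_shapes refers to input {param!r} twice "
                "(via both its original and renamed key)"
            )
        out[param] = axes
    return out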

@titaiwangms
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 15, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Raised by workflow job

@titaiwangms
Collaborator Author

I found a bug. Hold off on merging for now.

Collaborator

@justinchuby justinchuby left a comment

Thanks!

*,
dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any],
input_names: Sequence[str],
) -> dict[str, Any] | tuple[Any] | list[Any]:
Collaborator

Could you add a docstring to clarify when this function will or will not work?

Collaborator Author

Done

@titaiwangms
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…port(fallback=True) is triggered (pytorch#139532)

Fixes pytorch#139320


Pull Request resolved: pytorch#139532
Approved by: https://github.com/justinchuby

Labels

ciflow/trunk Trigger trunk jobs on your pull request · Merged · open source · release notes: onnx torch.onnx related changes that should show up in the release notes · topic: bug fixes topic category · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[ONNX] Convert dynamic_shapes to dynamic_axes when fallback is triggered in torch.onnx.export

6 participants