[ONNX] Insert contiguous node between transpose and view before calling run_decompositions #137340
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137340
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 65057d6 with merge base a063a82.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Is this intended to be merged in? If we can address the root cause, that would be best imo, since we have had issues with too many FX passes where we don't know which one to look at when an error is introduced. I can run a git bisect today to try to pin down the commit. If we do decide to merge this as a temporary fix, I recommend moving the logic to _fx_passes, where all fx passes live.
The fix is just one pass on the fx graph. It is done in place. It should be very fast. I can move the function to `_fx_passes.py` if that's what you mean.
I am more concerned about complexity and robustness than speed. But since this should be a temporary fix, it should be useful to unblock us for now. I would just make it very clear that it should be removed when the issue is resolved, and it should be removed before the 2.6 release.
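For reference, a minimal sketch of what such a single, in-place FX graph pass could look like. The helper name, the exact op overloads matched, and the choice of `flatten` (as worded in the PR description) are assumptions for illustration, not the PR's actual code:

```python
import torch
from torch.fx import GraphModule, Node


def _insert_flatten_between_transpose_and_view(gm: GraphModule) -> None:
    """Hypothetical one-pass, in-place workaround: for every aten.view whose
    input comes from an aten.transpose, route that input through aten.flatten
    first so the view no longer operates on transposed strides."""
    graph = gm.graph
    for node in list(graph.nodes):
        if node.op != "call_function" or node.target != torch.ops.aten.view.default:
            continue
        src = node.args[0]
        if not (
            isinstance(src, Node)
            and src.op == "call_function"
            and src.target == torch.ops.aten.transpose.int
        ):
            continue
        # Insert the extra node right after the transpose and rewire only the view.
        with graph.inserting_after(src):
            flat = graph.call_function(torch.ops.aten.flatten.using_ints, args=(src,))
        node.update_arg(0, flat)
    graph.lint()
    gm.recompile()
```

Such a helper would presumably be applied to the exported program's graph module just before `run_decompositions` is called.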
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
I tested that changing to contiguous works.

```python
with graph.inserting_after(node):
    new_node = graph.call_function(torch.ops.aten.contiguous.default, args=(node,))
    node.replace_all_uses_with(new_node)
    # new_node is replaced as well so we manually revert the replacement
    new_node.update_arg(0, node)
    node.users = {new_node: None}
```
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Works around #136543.
This fix solves the issue only in the context of the ONNX exporter, but the issue also happens in other contexts.
The bug happens when method `run_decompositions` is called. The failing pattern is assumed to be `view(transpose(x, ...))`. This pattern is replaced by `view(flatten(transpose(x, ...)))`. By changing the dimensions, the strides are updated as well, and `run_decompositions` does not fail anymore. It would be inefficient on a 1-D tensor, but then transpose would not be used. The extra node appears in the final ONNX graph but is removed after optimization. The final ONNX graph should not be impacted, and no performance loss should be observed for the ONNX model.
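For illustration, a small standalone snippet showing the stride effect the description relies on (shapes are arbitrary and this is not code from the PR):

```python
import torch

x = torch.randn(2, 3, 4)
t = x.transpose(1, 2)         # shape (2, 4, 3), strides (12, 1, 4): non-default
print(t.is_contiguous())      # False: this is the non-contiguous input the failing view receives

f = t.flatten()               # reshape cannot alias this layout, so flatten copies
print(f.is_contiguous())      # True: strides are standard again
print(f.view(2, 4, 3).shape)  # a subsequent view now operates on default strides
```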