Description
🐛 Describe the bug
The op `aten._weight_norm_interface` is not a core ATen op, and it does not have a decomposition, so `run_decompositions()` leaves it in the exported graph.
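For reference, `_weight_norm_interface(v, g, dim)` returns `(w, norm)`, where `norm` is the L2 norm of `v` over every dimension except `dim` (reduced dimensions kept with size 1, so it has the same shape as `g`) and `w = v * (g / norm)`. The sketch below expresses that with `torch.linalg.vector_norm`, a division and a multiplication. It is a minimal illustration, not the upstream implementation: the helper name `weight_norm_interface_decomp` is hypothetical, and the sketch assumes a non-negative `dim`.

```python
import torch


def weight_norm_interface_decomp(v: torch.Tensor, g: torch.Tensor, dim: int = 0):
    """Hypothetical decomposition sketch for aten._weight_norm_interface.

    Returns (w, norm): `norm` is the L2 norm of `v` taken over every
    dimension except `dim` (reduced dims kept with size 1, so it has the
    same shape as `g`), and w = v * (g / norm).
    """
    # Assumption: this sketch only handles non-negative dims.
    if not 0 <= dim < v.dim():
        raise ValueError(f"sketch only handles 0 <= dim < v.dim(), got {dim}")
    # Reduce over all dimensions except `dim`.
    reduce_dims = tuple(d for d in range(v.dim()) if d != dim)
    norm = torch.linalg.vector_norm(v, ord=2, dim=reduce_dims, keepdim=True)
    return v * (g / norm), norm


# Usage with the shapes from the repro (dim=2 is the last dimension).
v, g = torch.randn(768, 48, 128), torch.randn(1, 1, 128)
w, norm = weight_norm_interface_decomp(v, g, 2)
```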
Repro:
```python
import torch

args = (torch.randn(768, 48, 128), torch.randn(1, 1, 128))

def func(x, y):
    return torch.ops.aten._weight_norm_interface(x, y, 2)

func(args[0], args[1])
exp = torch.export.export(func, args)
exp = exp.run_decompositions()
print(exp.graph_module.code)
```
Output:
```python
def forward(self, arg0_1, arg1_1):
    _weight_norm_interface = torch.ops.aten._weight_norm_interface.default(arg0_1, arg1_1, 2); arg0_1 = arg1_1 = None
    getitem = _weight_norm_interface[0]
    getitem_1 = _weight_norm_interface[1]; _weight_norm_interface = None
    return (getitem, getitem_1)
```
Expected output:
I expect `aten._weight_norm_interface` to be decomposed into other ops that are in the core ATen set.
Versions
NUMA node1 CPU(s): 28-55,84-111
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.4.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.25.0
[pip3] onnx==1.13.1
[pip3] onnxruntime==1.15.1
[pip3] torch==2.2.0a0+gite268139
[pip3] torch-xla==2.2.0
[pip3] torchaudio==2.0.2
[pip3] torchdata==0.7.0
[pip3] torchtext==0.16.0
[pip3] torchvision==0.16.0a0+463cdea
[pip3] triton==2.1.0
[conda] blas 1.0 mkl
[conda] mkl 2023.1.0 h6d00ec8_46342
[conda] mkl-include 2023.2.0 pypi_0 pypi
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.6 py310h1128e8f_1
[conda] mkl_random 1.2.2 py310h1128e8f_1
[conda] numpy 1.25.2 pypi_0 pypi
[conda] numpy-base 1.25.0 py310hb5e798b_0
[conda] torch 2.2.0a0+gite268139 dev_0
[conda] torch-xla 2.2.0 dev_0
[conda] torchaudio 2.0.2 pypi_0 pypi
[conda] torchdata 0.7.0 pypi_0 pypi
[conda] torchtext 0.16.0 pypi_0 pypi
[conda] torchvision 0.16.0a0+463cdea pypi_0 pypi
[conda] triton 2.1.0 pypi_0 pypi
cc @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo