Description
🐛 Describe the bug
Hello,
`torch.export.export` seems to generate invalid code for `Tensor.split` when the module and its inputs are created under the `meta` device context.
I am not sure if this is something that needs to be fixed, but here is the report anyway.
Please see the snippet below:
```python
import torch


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        X.split(2, 1)


# This one works without any problems:
my_module = MyModule()
inp = torch.rand(10, 10)
exported = torch.export.export(my_module, (inp,))
# The generated code is:
# def forward(self, L_X_: torch.Tensor):
#     l_x_ = L_X_
#     split = l_x_.split(2, 1);  l_x_ = split = None
#     return ()

# This one crashes with `AttributeError: module 'torch._tensor' has no attribute 'split'`:
with torch.device("meta"):
    my_module = MyModule()
    inp = torch.rand(10, 10)
    exported = torch.export.export(my_module, (inp,))
# The generated code is:
# def forward(self, L_X_: torch.Tensor):
#     l_x_ = L_X_
#     split = torch._tensor.split(l_x_, 2, 1);  l_x_ = split = None
#     return ()
```

If you get an error about circular imports, try adding `import torch._dynamo` as well.
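A plausible explanation, offered as an assumption rather than a confirmed diagnosis: `Tensor.split` is a plain Python function defined in `torch/_tensor.py`, so its `__module__` is `torch._tensor` while its `__qualname__` is `Tensor.split`. If, under the `meta` device context, the call is recorded as a direct function call and the code generator names it from `__module__` plus the bare `__name__`, the result is exactly the `torch._tensor.split` seen in the failing case. The sketch below does not call `torch.export`; it only reproduces that naming step with two hypothetical helpers, `naive_name` and `resolve`, which are illustrations and not PyTorch APIs.

```python
import importlib

import torch


def naive_name(func):
    # Hypothetical helper: dotted path from the defining module and the bare
    # function name, ignoring the class that owns the function.
    return f"{func.__module__}.{func.__name__}"


def resolve(dotted):
    # Hypothetical helper: import the longest importable module prefix,
    # then walk the remaining parts as attributes.
    parts = dotted.split(".")
    for i in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:i]))
        except ImportError:
            continue
        for attr in parts[i:]:
            obj = getattr(obj, attr)
        return obj
    raise ImportError(dotted)


split_fn = torch.Tensor.split   # plain Python function defined in torch/_tensor.py
print(split_fn.__module__)      # torch._tensor
print(split_fn.__qualname__)    # Tensor.split
print(naive_name(split_fn))     # torch._tensor.split

try:
    resolve(naive_name(split_fn))
except AttributeError as err:
    # Same kind of failure as the exported graph code
    print("naive path fails:", err)

# Including the owning class via __qualname__ finds the same function object.
qualified = f"{split_fn.__module__}.{split_fn.__qualname__}"
print(resolve(qualified) is split_fn)  # True
```

Resolving the naive path raises an `AttributeError` on `torch._tensor`, matching the reported message, while a path built from `__qualname__` resolves correctly.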
Thank you.
Versions
PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.4 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.3)
CMake version: version 4.0.1
Libc version: N/A
Python version: 3.13.3 (main, Apr 9 2025, 03:47:57) [Clang 20.1.0 ] (64-bit runtime)
Python platform: macOS-15.4-arm64-arm-64bit-Mach-O
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M2 Max
Versions of relevant libraries:
[pip3] Could not collect
[conda] numpy 2.2.0 pypi_0 pypi
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4