🐛 Describe the bug
`run_decompositions` fails when a shape depends on the data.
```python
import torch


def dummy_loop(padded: torch.Tensor, pos: torch.Tensor):
    copy = torch.zeros(padded.shape)
    for i in range(pos.shape[0]):
        p = pos[i]
        copy[i, :p] = padded[i, :p]
    return copy


def dummy_loop_with_scan(padded: torch.Tensor, pos: torch.Tensor):
    def pad_row(padded, p):
        row = torch.zeros((padded.shape[0],))
        torch._check(p.item() > 0)
        torch._check(p.item() < padded.shape[0])
        # This check is not always true; we add it anyway to make this
        # dimension >= 2 and avoid raising an exception about a dynamic
        # dimension in {0, 1}.
        if torch.compiler.is_exporting():
            torch._check(p.item() > 1)
        row[: p.item()] = padded[: p.item()]
        return (row,)

    return torch.ops.higher_order.scan(
        pad_row,
        [],
        [padded, pos],
        [],
    )


def select_when_exporting(f, f_scan):
    return f_scan if torch.compiler.is_exporting() else f


class Model(torch.nn.Module):
    def forward(self, images, position):
        return select_when_exporting(dummy_loop, dummy_loop_with_scan)(
            images, position
        )


model = Model()
x = torch.randn((5, 6))
y = torch.arange(5, dtype=torch.int64) + 1
expected = model(x, y)

DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(
    model,
    (x, y),
    dynamic_shapes={"images": {0: DYN, 1: DYN}, "position": {0: DYN}},
    strict=False,
)
ep.run_decompositions()
```

Error:

```
  File "torch/fx/experimental/symbolic_shapes.py", line 7061, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression u1 < 0 (unhinted: u1 < 0). (Size-like symbols: none)
```
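For comparison, the usual way to discharge a data-dependent guard like `u1 < 0` outside of `scan` is to bound the unbacked symbol with `torch._check` before it is used as a slice length. Here is a minimal sketch of that pattern (the `Clamp` module is hypothetical, for illustration only, and should export and decompose cleanly on its own):

```python
import torch


class Clamp(torch.nn.Module):  # hypothetical module, not part of the repro
    def forward(self, x, n):
        p = n.item()                   # p becomes an unbacked SymInt (e.g. u0)
        torch._check(p >= 0)           # lets the compiler resolve guards like "u0 < 0"
        torch._check(p <= x.shape[0])  # upper bound so x[:p] has a valid size
        return x[:p]


ep = torch.export.export(Clamp(), (torch.randn(8), torch.tensor(3)))
ep.run_decompositions()  # no data-dependent guard error expected here
```

In the repro above, `pad_row` already issues equivalent `torch._check` calls, yet the `u1 < 0` guard still fires during decomposition, which suggests the checks made inside the `scan` body are not reaching the re-trace performed by `run_decompositions`.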
Versions
```
Collecting environment information...
PyTorch version: 2.8.0.dev20250416+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.6
Libc version: glibc-2.35
Python version: 3.12.9 (main, Feb 5 2025, 08:49:00) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.6.68
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4060 Laptop GPU
Nvidia driver version: 538.92
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.3.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: 13th Gen Intel(R) Core(TM) i7-13800H
CPU family: 6
Model: 186
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
Stepping: 2
BogoMIPS: 5836.79
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 480 KiB (10 instances)
L1i cache: 320 KiB (10 instances)
L2 cache: 12.5 MiB (10 instances)
L3 cache: 24 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] model-explorer-onnx==0.3.4
[pip3] mypy==1.15.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.1.3
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] onnx==1.19.0
[pip3] onnx-array-api==0.3.0
[pip3] onnx-extended==0.4.0
[pip3] onnxruntime_extensions==0.13.0
[pip3] onnxruntime-genai-cuda==0.6.0
[pip3] onnxruntime-training==1.22.0+cu126
[pip3] onnxscript==0.3.0.dev20250301
[pip3] optree==0.14.1
[pip3] pytorch-triton==3.3.0+git96316ce5
[pip3] torch==2.8.0.dev20250416+cu126
[pip3] torch_geometric==2.4.0
[pip3] torchaudio==2.6.0.dev20250416+cu126
[pip3] torchmetrics==1.6.2
[pip3] torchvision==0.22.0.dev20250416+cu126
[pip3] triton==3.2.0
[conda] Could not collect
```
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4