run_decompositions fails with torch.scan · Issue #151564 · pytorch/pytorch · GitHub

@xadupre

🐛 Describe the bug

`run_decompositions` fails when the exported program contains a `scan` whose body uses a data-dependent shape (here, slices bounded by `p.item()`).

import torch

def dummy_loop(padded: torch.Tensor, pos: torch.Tensor):
    copy = torch.zeros(padded.shape)
    for i in range(pos.shape[0]):
        p = pos[i]
        copy[i, :p] = padded[i, :p]
    return copy

def dummy_loop_with_scan(padded: torch.Tensor, pos: torch.Tensor):
    def pad_row(padded, p):
        row = torch.zeros((padded.shape[0],))
        torch._check(p.item() > 0)
        torch._check(p.item() < padded.shape[0])
        # this check is not always true, we add it anyway to make this dimension >= 2
        # and avoid raising an exception about dynamic dimension in {0, 1}
        if torch.compiler.is_exporting():
            torch._check(p.item() > 1)
        row[: p.item()] = padded[: p.item()]
        return (row,)

    return torch.ops.higher_order.scan(
        pad_row,
        [],
        [padded, pos],
        [],
    )

def select_when_exporting(f, f_scan):
    return f_scan if torch.compiler.is_exporting() else f

class Model(torch.nn.Module):
    def forward(self, images, position):
        return select_when_exporting(dummy_loop, dummy_loop_with_scan)(
            images, position
        )

model = Model()
x = torch.randn((5, 6))
y = torch.arange(5, dtype=torch.int64) + 1
expected = model(x, y)

DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(
    model,
    (x, y),
    dynamic_shapes={"images": {0: DYN, 1: DYN}, "position": {0: DYN}},
    strict=False,
)
ep.run_decompositions()
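
For readers unfamiliar with `scan`, here is a minimal pure-Python sketch of what the repro asks it to compute. It assumes the argument order `(combine_fn, init, xs, additional_inputs)` used in the call above and models tensors as nested lists. `reference_scan` is a hypothetical helper written for illustration, not PyTorch's implementation.

```python
from typing import Callable, List, Sequence


def reference_scan(combine_fn: Callable, init: Sequence, xs: Sequence) -> List:
    """Simplified model of a scan over the leading dimension.

    Assumption: combine_fn receives (*carry, *slices) and returns
    (*new_carry, *step_outputs); step outputs are stacked along dim 0.
    """
    carry = list(init)
    n_carry = len(carry)
    stacked = None
    for i in range(len(xs[0])):
        # Take the i-th slice of every scanned input.
        slices = [x[i] for x in xs]
        result = combine_fn(*carry, *slices)
        carry = list(result[:n_carry])
        step_outputs = result[n_carry:]
        if stacked is None:
            stacked = [[] for _ in step_outputs]
        for acc, value in zip(stacked, step_outputs):
            acc.append(value)
    return carry + (stacked or [])


def pad_row(padded_row: List[float], p: int):
    # Same logic as pad_row in the repro, on plain lists:
    # keep the first p entries and zero out the rest.
    row = [0.0] * len(padded_row)
    row[:p] = padded_row[:p]
    return (row,)


padded = [[float(10 * i + j + 1) for j in range(6)] for i in range(5)]
pos = [1, 2, 3, 4, 5]
(result,) = reference_scan(pad_row, [], [padded, pos])
assert result[0] == [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

With an empty `init` there is no carry, so the only output is the stack of padded rows, which matches what `dummy_loop` builds with its explicit `for` loop.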

Error:

  File "torch/fx/experimental/symbolic_shapes.py", line 7061, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression u1 < 0 (unhinted: u1 < 0).  (Size-like symbols: none)
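
The message does not say which operation produced `u1 < 0`. One plausible source, and this is an assumption rather than something the traceback confirms, is the slice `row[: p.item()]`: normalizing a slice bound requires knowing whether it is negative. The sketch below shows that branch in plain Python; `normalize_slice_end` is a hypothetical helper for illustration, not a PyTorch function.

```python
def normalize_slice_end(end: int, size: int) -> int:
    """Clamp a slice end the way Python sequence slicing does for step 1."""
    if end < 0:
        # This branch is why the sign of `end` must be known:
        # a negative end counts from the back of the sequence.
        end += size
        if end < 0:
            end = 0
    elif end > size:
        end = size
    return end


row = list(range(6))
# A positive bound keeps the first 2 items; a negative one drops the last 2.
assert row[:normalize_slice_end(2, 6)] == row[:2] == [0, 1]
assert row[:normalize_slice_end(-2, 6)] == row[:-2] == [0, 1, 2, 3]
```

If that reading is right, the `torch._check(p.item() > 0)` calls in `pad_row` are meant to rule the negative branch out, and the error suggests that information is not available when `run_decompositions` processes the scan body.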

Versions

Collecting environment information...
PyTorch version: 2.8.0.dev20250416+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.6
Libc version: glibc-2.35

Python version: 3.12.9 (main, Feb  5 2025, 08:49:00) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.6.68
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4060 Laptop GPU
Nvidia driver version: 538.92
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.3.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               20
On-line CPU(s) list:                  0-19
Vendor ID:                            GenuineIntel
Model name:                           13th Gen Intel(R) Core(TM) i7-13800H
CPU family:                           6
Model:                                186
Thread(s) per core:                   2
Core(s) per socket:                   10
Socket(s):                            1
Stepping:                             2
BogoMIPS:                             5836.79
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization:                       VT-x
Hypervisor vendor:                    Microsoft
Virtualization type:                  full
L1d cache:                            480 KiB (10 instances)
L1i cache:                            320 KiB (10 instances)
L2 cache:                             12.5 MiB (10 instances)
L3 cache:                             24 MiB (1 instance)
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed:               Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] model-explorer-onnx==0.3.4
[pip3] mypy==1.15.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.1.3
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] onnx==1.19.0
[pip3] onnx-array-api==0.3.0
[pip3] onnx-extended==0.4.0
[pip3] onnxruntime_extensions==0.13.0
[pip3] onnxruntime-genai-cuda==0.6.0
[pip3] onnxruntime-training==1.22.0+cu126
[pip3] onnxscript==0.3.0.dev20250301
[pip3] optree==0.14.1
[pip3] pytorch-triton==3.3.0+git96316ce5
[pip3] torch==2.8.0.dev20250416+cu126
[pip3] torch_geometric==2.4.0
[pip3] torchaudio==2.6.0.dev20250416+cu126
[pip3] torchmetrics==1.6.2
[pip3] torchvision==0.22.0.dev20250416+cu126
[pip3] triton==3.2.0
[conda] Could not collect

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
