Description
🐛 Describe the bug
Hi,
In torch 2.6, einops was supported by torch.compile: Dynamo could capture einops calls into the graph. In torch 2.7 this appears to be broken.
import torch
import einops

print(f"PyTorch version: {torch.__version__}")

@torch.compile(fullgraph=True)
def fn(x):
    return einops.rearrange(x, 'b c -> c b')

x = torch.randn(2, 3)
fn(x)
With torch 2.6
$ python rearr.py
PyTorch version: 2.6.0+cu124
With torch 2.7
$ python rearr.py
PyTorch version: 2.7.0+cu126
Traceback (most recent call last):
  File "/tmp/rearr.py", line 11, in <module>
    fn(x)
  File "/home/vadim/venv_torchpure2.7/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 659, in _fn
    raise e.with_traceback(None) from None
torch._dynamo.exc.Unsupported: Unsupported method call
  Explanation: Dynamo does not know how to trace method `symmetric_difference` of class `type`
  Hint: Avoid calling `type.symmetric_difference` in your code.
  Hint: Please report an issue to PyTorch.
  Developer debug context: call_method BuiltinVariable(set) symmetric_difference [SetVariable(), SetVariable()] {}

from user code:
  File "/tmp/rearr.py", line 8, in fn
    return einops.rearrange(x, 'b c -> c b')
  File "/home/vadim/venv_torchpure2.7/lib/python3.10/site-packages/einops/einops.py", line 600, in rearrange
    return reduce(tensor, pattern, reduction="rearrange", **axes_lengths)
  File "/home/vadim/venv_torchpure2.7/lib/python3.10/site-packages/einops/einops.py", line 531, in reduce
    recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=len(shape))
  File "/home/vadim/venv_torchpure2.7/lib/python3.10/site-packages/torch/_dynamo/polyfills/__init__.py", line 140, in getattr_and_trace
    return fn(*args[2:], **kwargs)
  File "/home/vadim/venv_torchpure2.7/lib/python3.10/site-packages/einops/einops.py", line 311, in _prepare_transformation_recipe
    difference = set.symmetric_difference(left.identifiers, rght.identifiers)
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Dynamo traces into einops.rearrange and hits the unsupported set.symmetric_difference call.
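For context on the failing line (einops.py line 311 in the traceback): set.symmetric_difference(a, b) calls the method through the class set rather than through an instance. That is presumably why Dynamo's message names class `type` and the debug context shows BuiltinVariable(set). The plain-Python sketch below uses made-up axis sets and the hypothetical helper name axes_on_one_side. It shows that the unbound form computes the same result as a.symmetric_difference(b) and a ^ b. This report does not establish whether Dynamo in torch 2.7 can trace those other spellings.

```python
def axes_on_one_side(left_ids: set, right_ids: set) -> set:
    """Axis names that appear on only one side of a pattern (hypothetical helper).

    Uses the same unbound-method spelling as einops.py line 311 in the traceback.
    """
    return set.symmetric_difference(left_ids, right_ids)


# Made-up axis sets for the patterns 'b c -> c b' and 'b c -> c d'.
same = axes_on_one_side({"b", "c"}, {"c", "b"})
mismatch = axes_on_one_side({"b", "c"}, {"c", "d"})

# The unbound form agrees with the bound-method and operator spellings.
left, right = {"b", "c"}, {"c", "d"}
assert set.symmetric_difference(left, right) == left.symmetric_difference(right) == left ^ right

print(same)              # set()
print(sorted(mismatch))  # ['b', 'd']
```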
The cause appears to be commit 270ad51.
That commit restricted import einops._torch_specific to einops versions < 0.7.0, with the comment "# version > 0.7.0 does allow_in_graph out of tree". However, the out-of-tree registration does not seem to work in all cases.
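The version gate described above can be sketched in plain Python. This is a hypothetical illustration, not the code from commit 270ad51. The names parse_version and should_import_torch_specific are made up, and the hand-rolled parser stands in for a real version library such as packaging. The sketch shows that for any einops release at or above 0.7.0 the extra import is skipped. Registration with allow_in_graph is then left to einops itself, which is the path that fails in this report.

```python
# Hypothetical sketch of the version gate described above. This is NOT the
# code from commit 270ad51; the function names are illustrative only.

def parse_version(version: str) -> tuple:
    """Turn '0.8.1' or '0.7.0rc1' into a 3-tuple of ints such as (0, 8, 1)."""
    parts = []
    for piece in version.split(".")[:3]:
        # Keep only the leading digits of each component ('0rc1' -> '0').
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    # Pad short versions like '0.6' so tuple comparison is well defined.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def should_import_torch_specific(einops_version: str) -> bool:
    """True only for einops < 0.7.0, the condition described in the issue."""
    return parse_version(einops_version) < (0, 7, 0)


for v in ["0.6.1", "0.7.0", "0.8.1"]:
    print(v, should_import_torch_specific(v))
# 0.6.1 True
# 0.7.0 False
# 0.8.1 False
```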
PS: for anyone who runs into this issue and needs a quick workaround, add
import einops._torch_specific right after import einops. It fixed the issue.
Versions
PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.6
Libc version: glibc-2.35
Python version: 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-130-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version: 550.127.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8468
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 8
BogoMIPS: 4200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq dtes64 ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 512 KiB (16 instances)
L1i cache: 512 KiB (16 instances)
L2 cache: 32 MiB (8 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] torch==2.7.0
[pip3] triton==3.3.0
[conda] Could not collect
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @chauhang @amjames