Description
🐛 Describe the bug
This is another neoverse-v1 / GCC optimization-related issue. It is very similar to #137597, but not the same issue.
I have verified that the error goes away when the "-fno-tree-vectorize" compile flag is added.
Command:
g++ /tmp/tmpuqk7lj9j/zx/czx2eyturb6j6m727xhvknkjbdu3y5nqqk66wgxcjkwnxuzvpm5r.cpp -D TORCH_INDUCTOR_CPP_WRAPPER -D C10_USING_CUSTOM_GENERATED_MACROS -D CPU_CAPABILITY_NEON -shared -fPIC -O3 -DNDEBUG -ffast-math -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -march=native -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -fopenmp -I/opt/_internal/cpython-3.10.14/include/python3.10 -I/opt/_internal/cpython-3.10.14/include/python3.10 -I/aarch64_env_test/lib/python3.10/site-packages/torch/include -I/aarch64_env_test/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/aarch64_env_test/lib/python3.10/site-packages/torch/include/TH -I/aarch64_env_test/lib/python3.10/site-packages/torch/include/THC -I/aarch64_env_test/lib/python3.10/site-packages/torch/include -I/aarch64_env_test/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/aarch64_env_test/lib/python3.10/site-packages/torch/include/TH -I/aarch64_env_test/lib/python3.10/site-packages/torch/include/THC -D_GLIBCXX_USE_CXX11_ABI=0 -ltorch -ltorch_cpu -ltorch_python -lc10 -lgomp -L/opt/_internal/cpython-3.10.14/lib -L/aarch64_env_test/lib/python3.10/site-packages/torch/lib -L/aarch64_env_test/lib/python3.10/site-packages/torch/lib -L/aarch64_env_test/lib/python3.10/site-packages/torch/lib -o /tmp/tmpuqk7lj9j/zx/czx2eyturb6j6m727xhvknkjbdu3y5nqqk66wgxcjkwnxuzvpm5r.so
Output:
during GIMPLE pass: slp
/tmp/tmpuqk7lj9j/zx/czx2eyturb6j6m727xhvknkjbdu3y5nqqk66wgxcjkwnxuzvpm5r.cpp: In function ‘void kernel(const float*, const float*, float*, int64_t*, int64_t, int64_t)’:
/tmp/tmpuqk7lj9j/zx/czx2eyturb6j6m727xhvknkjbdu3y5nqqk66wgxcjkwnxuzvpm5r.cpp:3:18: internal compiler error: in vect_get_vector_types_for_stmt, at tree-vect-stmts.c:12252
3 | extern "C" void kernel(const float* in_ptr0,
| ^~~~~~
Please submit a full bug report,
with preprocessed source if appropriate.
To reproduce
On neoverse-v1 with GCC 10.2.
To execute this test, run the following from the base repo dir:
python test/inductor/test_torchinductor_codegen_dynamic_shapes.py DynamicShapesCodegenCpuTests.test_fractional_max_pool2d3_dynamic_shapes_cpu
There are a number of other tests that fail with the same error, but too many to include here.
Versions
Collecting environment information...
PyTorch version: 2.6.0.dev20241010+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: CentOS Linux 7 (AltArch) (aarch64)
GCC version: (GCC) 10.2.1 20210130 (Red Hat 10.2.1-11)
Clang version: Could not collect
CMake version: version 3.30.4
Libc version: glibc-2.17
Python version: 3.10.15 (main, Sep 28 2024, 23:26:50) [GCC 10.2.1 20210130 (Red Hat 10.2.1-11)] (64-bit runtime)
Python platform: Linux-6.8.0-1015-aws-aarch64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: aarch64
Byte Order: Little Endian
CPU(s): 48
On-line CPU(s) list: 0-47
Thread(s) per core: 1
Core(s) per socket: 48
Socket(s): 1
NUMA node(s): 1
Model: 1
BogoMIPS: 2100.00
L1d cache: 64K
L1i cache: 64K
L2 cache: 1024K
L3 cache: 32768K
NUMA node0 CPU(s): 0-47
Flags: fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm ssbs paca pacg dcpodp svei8mm svebf16 i8mm bf16 dgh rng
Versions of relevant libraries:
[pip3] mypy==1.11.2
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.22.4
[pip3] onnx==1.16.1
[pip3] onnxscript==0.1.0.dev20240817
[pip3] optree==0.13.0
[pip3] torch==2.6.0.dev20241011+cpu
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @malfet @snadampal @milpuz01