Closed
Labels
oncall: jit
Description
🐛 Describe the bug
Annotating a function compiled with torch.jit.script using the Python 3.10 union syntax (int | float), as defined in https://peps.python.org/pep-0604/, raises a RuntimeError.
For the given minimal example:
from typing import Union
import torch


@torch.jit.script
def jit_test(
    ok: Union[int, float], bad: int | float
) -> None:
    pass

The traceback is:
(env) ➜ torch_test python test.py
Traceback (most recent call last):
  File "/Users/awb/torch_test/test.py", line 6, in <module>
    def jit_test(
  File "/Users/awb/torch_test/env/lib/python3.10/site-packages/torch/jit/_script.py", line 1341, in script
    fn = torch._C._jit_script_compile(
RuntimeError:
Expression of type | cannot be used in a type expression:
  File "/Users/awb/torch_test/test.py", line 7
@torch.jit.script
def jit_test(
    ok: Union[int, float], bad: int | float
                                ~~~~~~~~~~~ <--- HERE
) -> None:
    pass
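The error message points at the int | float expression in the source. One plausible reading (an assumption, not verified against TorchScript's internals) is that the compiler inspects annotations as parsed source expressions, where the two spellings are different syntax nodes even though they denote the same type. The standard ast module shows the difference:

```python
import ast

# The signature from the report, parsed as source text only
# (nothing is executed, so `Union` does not need to be imported).
SOURCE = """
def jit_test(ok: Union[int, float], bad: int | float) -> None:
    pass
"""

func = ast.parse(SOURCE).body[0]
kinds = {arg.arg: type(arg.annotation).__name__ for arg in func.args.args}
print(kinds)  # {'ok': 'Subscript', 'bad': 'BinOp'}

# The `bad` annotation is a binary operation using the bitwise-or operator.
bad = func.args.args[1].annotation
print(type(bad.op).__name__)  # BitOr
```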
We would expect the bad argument to be treated identically to the ok argument, since the two annotations are equivalent types.
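To make the expected behaviour concrete, a source-level frontend could treat both spellings alike by rewriting A | B annotations into Union[A, B] before resolving types. The sketch below does this with the standard ast module; PEP604ToUnion and normalize_annotations are hypothetical names, and the code is not taken from PyTorch and makes no claim about how TorchScript handles annotations or how this issue was resolved.

```python
import ast


def _flatten_bitor(node: ast.expr) -> list:
    """Collect the operands of a chain such as `A | B | C`."""
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        return _flatten_bitor(node.left) + _flatten_bitor(node.right)
    return [node]


class PEP604ToUnion(ast.NodeTransformer):
    """Hypothetical rewrite of `A | B` annotation expressions into `Union[A, B]`."""

    def visit_BinOp(self, node: ast.BinOp) -> ast.expr:
        if not isinstance(node.op, ast.BitOr):
            return self.generic_visit(node)
        # Also rewrite unions nested inside each operand, e.g. List[int | None].
        members = [self.visit(m) for m in _flatten_bitor(node)]
        union = ast.Subscript(
            value=ast.Name(id="Union", ctx=ast.Load()),
            slice=ast.Tuple(elts=members, ctx=ast.Load()),
            ctx=ast.Load(),
        )
        return ast.copy_location(union, node)


def normalize_annotations(func: ast.FunctionDef) -> None:
    """Apply the rewrite to parameter and return annotations only, never the body."""
    rewriter = PEP604ToUnion()
    args = func.args
    for arg in args.posonlyargs + args.args + args.kwonlyargs:
        if arg.annotation is not None:
            arg.annotation = rewriter.visit(arg.annotation)
    if func.returns is not None:
        func.returns = rewriter.visit(func.returns)


SOURCE = """
def jit_test(ok: Union[int, float], bad: int | float) -> None:
    pass
"""
func = ast.parse(SOURCE).body[0]
normalize_annotations(func)
out = ast.unparse(func)  # ast.unparse requires Python 3.9+
print(out)
```

After the rewrite, both parameters carry the same Union[int, float] annotation. Restricting the transform to annotations keeps ordinary uses of | (bitwise or) in the function body untouched.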
Versions
Collecting environment information...
PyTorch version: 2.0.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 13.2.1 (arm64)
GCC version: Could not collect
Clang version: 14.0.3 (clang-1403.0.22.14.1)
CMake version: Could not collect
Libc version: N/A
Python version: 3.10.11 (main, May 17 2023, 17:29:08) [Clang 14.0.3 (clang-1403.0.22.14.1)] (64-bit runtime)
Python platform: macOS-13.2.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M2 Max
Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] torch==2.0.1
[pip3] torchaudio==2.0.2
[pip3] torchvision==0.15.2
[conda] Could not collect