[FX] Cannot trace calls with python Enum values · Issue #82135 · pytorch/pytorch · GitHub

@jamesr66a

Description

🐛 Describe the bug

```python
from enum import Enum

import torch.fx


class Foo(Enum):
    A = 1
    B = 2


def leaf_fn(proxy, enum_val):
    return proxy + enum_val


def foo(x):
    # Foo.A ends up as an argument of a traced call, which FX cannot record.
    return leaf_fn(x, Foo.A)


traced = torch.fx.symbolic_trace(foo)
print(traced)
```

Tracing fails with:

```
[... 10 frames collapsed ...]
/usr/local/lib/python3.7/dist-packages/torch/fx/proxy.py in create_arg(self, a)
    149             return a
    150 
--> 151         raise NotImplementedError(f"argument of type: {type(a)}")
    152 
    153     @compatibility(is_backward_compatible=True)

NotImplementedError: argument of type: <enum 'Foo'>
```
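The error is raised by `create_arg` in `torch/fx/proxy.py`, which normalizes each call argument into something the graph can store and raises `NotImplementedError` for any type it does not recognize. The toy model below is not FX's actual code. `toy_create_arg`, `Placeholder`, `BASE_TYPES`, and the `allow_enum` switch are all hypothetical, chosen only to reproduce the failure mode and to show how one extra `Enum` branch would let such values through as constants.

```python
from enum import Enum
from typing import Any


class Foo(Enum):
    A = 1
    B = 2


class Placeholder:
    """Hypothetical stand-in for a graph node that an argument may refer to."""

    def __init__(self, name: str) -> None:
        self.name = name


# Assumed whitelist of literal types that may be stored in the graph unchanged.
BASE_TYPES = (int, float, bool, str, type(None))


def toy_create_arg(a: Any, allow_enum: bool = False) -> Any:
    """Toy argument normalizer in the spirit of FX's create_arg (not its real code)."""
    if isinstance(a, Placeholder):
        return a
    if isinstance(a, tuple):
        # Containers are normalized element by element.
        return tuple(toy_create_arg(x, allow_enum) for x in a)
    if isinstance(a, list):
        return [toy_create_arg(x, allow_enum) for x in a]
    if isinstance(a, dict):
        return {k: toy_create_arg(v, allow_enum) for k, v in a.items()}
    if isinstance(a, BASE_TYPES):
        return a
    if allow_enum and isinstance(a, Enum):
        # Enum members are singletons with fixed values, so keeping them as constants is safe here.
        return a
    # Anything unrecognized falls through, mirroring the error in the traceback.
    raise NotImplementedError(f"argument of type: {type(a)}")


x = Placeholder("x")
try:
    toy_create_arg((x, Foo.A))
except NotImplementedError as err:
    print(err)  # argument of type: <enum 'Foo'>
print(toy_create_arg((x, Foo.A), allow_enum=True)[1])  # Foo.A
```

Keeping the member itself, rather than lowering it to `Foo.A.value`, preserves what the callee actually receives; the catch is that codegen must then be able to print the member back as valid Python.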

Code generation also breaks when the Enum value is added to a graph by hand:

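```python
import torch.fx

# Reuses Foo and leaf_fn from the snippet above.
g = torch.fx.Graph()

x = g.placeholder('x')
fn = g.call_function(leaf_fn, (x, Foo.A))
g.output(fn)

gm = torch.fx.GraphModule(torch.nn.Module(), g)
```

Constructing the `GraphModule` fails when it compiles the generated code:

```
  File "<eval_with_key>.0", line 5
    leaf_fn = __main___leaf_fn(x, <Foo.A: 1>);  x = None
                                  ^
SyntaxError: invalid syntax
```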
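The `SyntaxError` arises because the generated source embeds the argument's `repr()`, and for a plain `Enum` member `repr(Foo.A)` is `<Foo.A: 1>`, which is not a Python expression. A code generator would instead need to emit a name such as `Foo.A` and make the enum class visible in the generated module's globals. The sketch below is not FX's codegen; `render_arg`, `opaque_fn`, and the `globals_ns` dictionary are hypothetical, used only to show that approach round-tripping through `exec`.

```python
from enum import Enum
from typing import Any, Dict


class Foo(Enum):
    A = 1
    B = 2


def render_arg(value: Any, globals_ns: Dict[str, Any]) -> str:
    """Hypothetical helper: return source text that evaluates to `value`."""
    if isinstance(value, Enum):
        cls = type(value)
        # Make the enum class reachable by name from the generated code.
        globals_ns[cls.__name__] = cls
        return f"{cls.__name__}.{value.name}"
    # Assumption: all other arguments have an eval-able repr (ints, strs, ...).
    return repr(value)


def opaque_fn(x, enum_val):
    # Hypothetical stand-in for leaf_fn that just echoes its arguments.
    return (x, enum_val)


globals_ns: Dict[str, Any] = {"opaque_fn": opaque_fn}
src = f"def forward(x):\n    return opaque_fn(x, {render_arg(Foo.A, globals_ns)})\n"

print(repr(Foo.A))  # <Foo.A: 1>, which is not valid Python source
print(src)
exec(src, globals_ns)
print(globals_ns["forward"](3))  # (3, <Foo.A: 1>)
```

Registering the class under its bare `__name__` can clash with other globals or miss nested classes. The mangled name `__main___leaf_fn` in the error suggests FX already qualifies function names by module, and a real fix would presumably need similar care for enum classes.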

Versions

Collecting environment information...
PyTorch version: 1.12.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.22.5
Libc version: glibc-2.26

Python version: 3.7.13 (default, Apr 24 2022, 01:04:09) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.188+-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: False
CUDA runtime version: 11.1.105
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.21.6
[pip3] torch==1.12.0+cu113
[pip3] torchaudio==0.12.0+cu113
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.13.0
[pip3] torchvision==0.13.0+cu113
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @msaroufim @bdhirsh @anijain2305 @chauhang @wconstab

Metadata

Labels: fx, high priority, module: fx, module: pt2-dispatcher, oncall: pt2, triaged
