Description
🐛 Bug
`torch.jit.trace` successfully traces a wrapper module that unrolls an `LSTMCell` over time, and the traced module saves to a `.pt` file without complaint. Loading that file, however, fails with an error.
To Reproduce
import torch
import torch.nn.functional as F
from torch import nn


class LSTM(nn.Module):
    def __init__(self, d_in, d_hid, num_layers=1):
        super().__init__()
        self.rnn = nn.LSTMCell(d_in, d_hid)

    def forward(self, input, hidden):
        for emb_t in input.split(1, dim=1):
            hidden = self.rnn(emb_t.squeeze(1), hidden)
        return hidden


B, T, C = 1, 5, 3  # batch_size, seq_len, channels
f = LSTM(C, C)
z = torch.randn((B, T, C))
h = torch.randn((B, C))
c = torch.randn((B, C))
inputs = (z, (h, c))
y = f(*inputs)

print('tracing')
tf = torch.jit.trace(f, inputs)
ty = tf(*inputs)

print('saving')
tf.save('tf.pt')

print('loading')
tfl = torch.jit.load('tf.pt')
tyl = tfl(*inputs)
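Note that the traced module contains no loop: `torch.jit.trace` records the operations that actually run for the example inputs, so the `for` loop over `T = 5` time steps is unrolled. This is why the serialized code in the error message splits the input into `emb_t_1` through `emb_t` and repeats the `addmm` calls. The pure-Python sketch below illustrates that recording behavior only; `ToyTracer`, `toy_cell_step` and `toy_forward` are hypothetical names, not PyTorch APIs, and only the two matrix multiplies per step are modeled.

```python
from typing import List


class ToyTracer:
    """Hypothetical stand-in for a tracer: it only logs ops as they run."""

    def __init__(self) -> None:
        self.ops: List[str] = []

    def record(self, op: str) -> None:
        self.ops.append(op)


def toy_cell_step(tracer: ToyTracer, t: int) -> None:
    # Mirrors the two addmm calls visible per step in the serialized code.
    tracer.record(f"addmm(bias_hh, hx, weight_hh.t()) (step {t})")
    tracer.record(f"addmm(bias_ih, emb_t_{t}, weight_ih.t()) (step {t})")


def toy_forward(tracer: ToyTracer, seq_len: int) -> None:
    tracer.record(f"split(input, 1, dim=1) -> {seq_len} chunks")
    for t in range(seq_len):  # the Python loop itself is never recorded
        toy_cell_step(tracer, t)


tracer = ToyTracer()
toy_forward(tracer, seq_len=5)  # T = 5 in the repro
print(len(tracer.ops))  # 11: one split plus 2 ops for each of 5 steps
```

Changing `seq_len` changes the recorded op list, which is the sketch's analogue of a traced module being specialized to the sequence length of its example input.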
Error Message
tracing
saving
loading
Traceback (most recent call last):
  File "/decaNLP/trace_lstm_decoder.py", line 33, in <module>
    tfl = torch.jit.load('/decaNLP/tf.pt')
  File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 131, in load
    torch._C.import_ir_module(module_lookup, f, map_location)
RuntimeError:
attribute lookup is not defined on builtin:
op_version_set = 0
def forward(self,
    self: Tensor,
    _0: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
  hx_1, hx_2, = _0
  emb_t_1, emb_t_2, emb_t_3, emb_t_4, emb_t, = torch.split(self, 1, 1)
  input_1 = torch.squeeze(emb_t_1, 1)
  _1 = torch.addmm(self.rnn.bias_hh, hx_1, torch.t(self.rnn.weight_hh), beta=1, alpha=1)
                   ~~~~~~~~~~~~~~~~ <--- HERE
  _2 = torch.addmm(self.rnn.bias_ih, input_1, torch.t(self.rnn.weight_ih), beta=1, alpha=1)
  _3 = torch.chunk(torch.add(_2, _1, alpha=1), 4, 1)
  _4, _5, _6, _7, = _3
  _8 = torch.sigmoid(_4)
  _9 = torch.sigmoid(_5)
  _10 = torch.tanh(_6)
  _11 = torch.sigmoid(_7)
  _12 = torch.mul(_8, _10)
  hx_4 = torch.add(torch.mul(_9, hx_2), _12, alpha=1)
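One possible reading of this message (a hypothesis, not confirmed in this report): the serialized `forward` declares two parameters named `self`, the module and the traced input tensor, so `self.rnn` in the body may resolve to the `Tensor` argument, which would be consistent with "attribute lookup is not defined on builtin". The pure-Python sketch below checks the header copied from the error for duplicate parameter names. `duplicate_params` and `split_top_level` are hypothetical helpers, and CPython's `compile` is used only to show that the same signature shape is not valid Python either; TorchScript has its own parser.

```python
import re
from collections import Counter
from typing import List, Optional

# Header copied from the error message (body omitted).
SERIALIZED_HEADER = """def forward(self,
    self: Tensor,
    _0: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:"""


def split_top_level(params: str) -> List[str]:
    """Split a parameter list on commas that are not inside [...]."""
    parts: List[str] = []
    current: List[str] = []
    depth = 0
    for ch in params:
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
        if ch == "," and depth == 0:
            parts.append("".join(current))
            current = []
        else:
            current.append(ch)
    parts.append("".join(current))
    return parts


def duplicate_params(header: str) -> List[str]:
    """Hypothetical helper: names declared more than once in a def header."""
    match = re.search(r"def \w+\((.*?)\)\s*->", header, re.DOTALL)
    if match is None:
        return []
    # Keep only the name part of each "name: Type" parameter.
    names = [p.split(":")[0].strip() for p in split_top_level(match.group(1))]
    return [name for name, count in Counter(names).items() if count > 1]


print(duplicate_params(SERIALIZED_HEADER))  # ['self']

# CPython rejects the same shape of signature at compile time.
compile_error: Optional[str] = None
try:
    compile("def forward(self, self, _0):\n    pass\n", "<serialized>", "exec")
except SyntaxError as err:
    compile_error = err.msg
print(compile_error)  # duplicate argument 'self' in function definition
```

A check along these lines, run when the module is saved, is one way the "error or warn on saving" behavior requested below could look; it is a sketch, not how PyTorch's serializer works.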
Expected behavior
I would expect this to work. If there is some reason why it cannot, I would expect an error or warning at tracing or saving time. If that is not possible either, I would expect a more informative error message on loading.
Environment
- PyTorch Version: 1.0
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, source): conda
- Build command you used (if compiling from source): N/A
- Python version: 3.7
- CUDA/cuDNN version: >9
- GPU models and configuration: N/A
- Any other relevant information: N/A