[JIT] LSTM and LSTMCell · Issue #15240 · pytorch/pytorch · GitHub

@bmccann

🐛 Bug

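torch.jit.trace successfully traces a wrapper module that unrolls an LSTMCell over time, and the traced module saves to a .pt file without complaint. Loading that file with torch.jit.load, however, fails with the error below.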

To Reproduce

import torch
import torch.nn.functional as F
from torch import nn

class LSTM(nn.Module):

    def __init__(self, d_in, d_hid, num_layers=1):
        super().__init__()
        self.rnn = nn.LSTMCell(d_in, d_hid)

    def forward(self, input, hidden):
        # Unroll the cell over time: one LSTMCell step per slice along dim=1
        for emb_t in input.split(1, dim=1):
            hidden = self.rnn(emb_t.squeeze(1), hidden)
        return hidden

B, T, C = 1, 5, 3 # batch_size, seq_len, channels
f = LSTM(C, C)
z = torch.randn((B, T, C))
h = torch.randn((B, C))
c = torch.randn((B, C))

inputs = (z, (h, c))
y = f(*inputs)

print('tracing')
tf = torch.jit.trace(f, inputs)
ty = tf(*inputs)

print('saving')
tf.save('tf.pt')

print('loading')
tfl = torch.jit.load('tf.pt')
tyl = tfl(*inputs)

Error Message

tracing
saving
loading
Traceback (most recent call last):
  File "/decaNLP/trace_lstm_decoder.py", line 33, in <module>
    tfl = torch.jit.load('/decaNLP/tf.pt')
  File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 131, in load
    torch._C.import_ir_module(module_lookup, f, map_location)
RuntimeError:
attribute lookup is not defined on builtin:
op_version_set = 0
def forward(self,
    self: Tensor,
    _0: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
  hx_1, hx_2, = _0
  emb_t_1, emb_t_2, emb_t_3, emb_t_4, emb_t, = torch.split(self, 1, 1)
  input_1 = torch.squeeze(emb_t_1, 1)
  _1 = torch.addmm(self.rnn.bias_hh, hx_1, torch.t(self.rnn.weight_hh), beta=1, alpha=1)
                   ~~~~~~~~~~~~~~~~ <--- HERE
  _2 = torch.addmm(self.rnn.bias_ih, input_1, torch.t(self.rnn.weight_ih), beta=1, alpha=1)
  _3 = torch.chunk(torch.add(_2, _1, alpha=1), 4, 1)
  _4, _5, _6, _7, = _3
  _8 = torch.sigmoid(_4)
  _9 = torch.sigmoid(_5)
  _10 = torch.tanh(_6)
  _11 = torch.sigmoid(_7)
  _12 = torch.mul(_8, _10)
  hx_4 = torch.add(torch.mul(_9, hx_2), _12, alpha=1)
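
For reading the serialized code in the error: it is the traced body of LSTMCell decomposed into primitive ops, and the lookup the loader rejects is self.rnn.bias_hh in the first addmm. The sketch below is a minimal, hedged reconstruction of one cell step using those same ops (addmm, chunk, sigmoid, tanh), checked against nn.LSTMCell. The helper `lstm_cell_step` is hypothetical, and the gate order (input, forget, cell, output) is the one documented for `nn.LSTMCell`; the final `o * tanh(c')` step is cut off in the traceback and is taken from those docs.

```python
import torch
from torch import nn


def lstm_cell_step(cell: nn.LSTMCell, x, hx, cx):
    """Hypothetical helper: one LSTMCell step spelled out with the same ops
    that appear in the serialized traced code (addmm, chunk, sigmoid, tanh)."""
    # Hidden-to-hidden and input-to-hidden affine maps: bias + h @ W^T
    hh = torch.addmm(cell.bias_hh, hx, cell.weight_hh.t())
    ih = torch.addmm(cell.bias_ih, x, cell.weight_ih.t())
    # Split the 4 * hidden_size pre-activations into (i, f, g, o) gates
    i, f, g, o = (ih + hh).chunk(4, dim=1)
    i, f, g, o = torch.sigmoid(i), torch.sigmoid(f), torch.tanh(g), torch.sigmoid(o)
    cy = f * cx + i * g        # new cell state (hx_4 in the traced code)
    hy = o * torch.tanh(cy)    # new hidden state
    return hy, cy


torch.manual_seed(0)
cell = nn.LSTMCell(3, 3)
x, h, c = torch.randn(1, 3), torch.randn(1, 3), torch.randn(1, 3)

ref_h, ref_c = cell(x, (h, c))          # built-in cell
my_h, my_c = lstm_cell_step(cell, x, h, c)  # manual decomposition
print(torch.allclose(ref_h, my_h, atol=1e-5), torch.allclose(ref_c, my_c, atol=1e-5))
```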

Expected behavior

I would expect this to work. If there is a reason it cannot, I would expect an error or warning at tracing or saving time. If that is not possible either, I would expect a more informative error on loading.
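
The request for an error at save time can be approximated from user code by reloading the archive right after writing it. A minimal sketch, assuming a PyTorch build where `torch.jit.save` and `torch.jit.load` accept file-like objects; `save_and_verify` is a hypothetical helper, not a PyTorch API. On an affected build, the `torch.jit.load` call inside it should raise the same RuntimeError as the traceback, only earlier.

```python
import io

import torch
from torch import nn


def _flatten(out):
    """Flatten a tensor or (possibly nested) tuple/list of tensors into a list."""
    if isinstance(out, torch.Tensor):
        return [out]
    return [t for item in out for t in _flatten(item)]


def save_and_verify(module, example_inputs, atol=1e-6):
    """Hypothetical helper: serialize a TorchScript module, reload it at once,
    and compare outputs, so a broken archive fails at save time, not later."""
    buffer = io.BytesIO()
    torch.jit.save(module, buffer)
    buffer.seek(0)
    # On an affected build, this load is where the RuntimeError would surface.
    loaded = torch.jit.load(buffer)
    expected = _flatten(module(*example_inputs))
    actual = _flatten(loaded(*example_inputs))
    if len(expected) != len(actual) or not all(
        torch.allclose(e, a, atol=atol) for e, a in zip(expected, actual)
    ):
        raise RuntimeError("reloaded module does not reproduce the original outputs")
    return buffer.getvalue(), loaded


# Same unrolled LSTMCell wrapper as in the reproduction
class LSTM(nn.Module):
    def __init__(self, d_in, d_hid):
        super().__init__()
        self.rnn = nn.LSTMCell(d_in, d_hid)

    def forward(self, input, hidden):
        for emb_t in input.split(1, dim=1):
            hidden = self.rnn(emb_t.squeeze(1), hidden)
        return hidden


B, T, C = 1, 5, 3
f = LSTM(C, C)
inputs = (torch.randn(B, T, C), (torch.randn(B, C), torch.randn(B, C)))
tf = torch.jit.trace(f, inputs)
data, tfl = save_and_verify(tf, inputs)
# `data` holds the verified archive bytes; write them to 'tf.pt' if desired.
```

Returning the bytes instead of writing tf.pt directly means a file only reaches disk after the round trip has succeeded.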

Environment

  • PyTorch Version: 1.0
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): N/A
  • Python version: 3.7
  • CUDA/cuDNN version: >9
  • GPU models and configuration: N/A
  • Any other relevant information: N/A

Labels: oncall: jit