[jit] Add LSTM to standard library by driazati · Pull Request #15744 · pytorch/pytorch · GitHub

Conversation

driazati (Contributor) commented Jan 4, 2019

WIP

Attempt 2 at #14831

This adds nn.LSTM to the JIT standard library. Necessary changes to the module itself are detailed in comments. The main limitation is the lack of a true PackedSequence; instead, this PR uses an ordinary tuple to stand in for PackedSequence.
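For orientation, here is a rough sketch of what the stand-in looks like from the caller's side. It assumes the tuple carries the same (data, batch_sizes) fields that PackedSequence holds; this is an illustration, not the PR's literal interface:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Eager mode: pack_padded_sequence returns a real PackedSequence.
seqs = torch.randn(5, 3, 10)                      # (seq_len, batch, input_size)
packed = pack_padded_sequence(seqs, [5, 4, 2])    # lengths sorted in decreasing order

# Under the JIT, a plain tuple plays the same role (field layout assumed here):
packed_as_tuple = (packed.data, packed.batch_sizes)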

Most of the new code in rnn.py is copied from nn.RNNBase into nn.LSTM to specialize it, since for LSTM hx is a Tuple[Tensor, Tensor] rather than a single Tensor as in the other RNN modules.
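To illustrate the type difference driving the copy, here is a hedged sketch of the signatures involved; the function names are illustrative, not the exact annotations in this PR:

from typing import Optional, Tuple
from torch import Tensor

# GRU/RNN: the hidden state is a single tensor.
def rnn_forward(input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ...

# LSTM: the hidden state is an (h_0, c_0) pair, so hx is a tuple of two tensors.
def lstm_forward(input: Tensor,
                 hx: Optional[Tuple[Tensor, Tensor]] = None
                 ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: ...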

As a hack, it adds an internal annotation @_parameter_list to mark that a function returns all the parameters of a module. The weights for RNN modules are passed to the corresponding op as a List[Tensor]. In Python this list has to be gathered dynamically, since Parameters could be moved from CPU to GPU or be deleted and replaced (e.g. if someone calls weight_norm on their module, #15766), but in the JIT parameter lists are immutable, hence a builtin to handle this differently in Python and in the JIT.
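A minimal sketch of the dynamic gathering that the annotation papers over in eager mode; the method and attribute names below are illustrative rather than the exact ones this PR adds:

from typing import List
from torch import Tensor

# This would live on the RNN module; `self._all_weights` is assumed to be a
# nested list of parameter names (e.g. weight_ih_l0, weight_hh_l0, ...).
def _get_flat_weights(self) -> List[Tensor]:
    # Re-fetch the parameters on every call so the list stays valid even if
    # weights were replaced (e.g. by weight_norm) or moved to another device.
    # In the JIT, this list is instead treated as an immutable parameter list.
    return [getattr(self, name) for names in self._all_weights for name in names]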

facebook-github-bot added the oncall: jit label Jan 4, 2019
driazati changed the title from Lstm ov to [jit] Add LSTM to standard library Jan 4, 2019
driazati closed this Feb 5, 2019
driazati reopened this Feb 5, 2019
driazati requested a review from zdevito February 15, 2019 22:34
zdevito (Contributor) left a comment

Looks good -- the hack is a bit messy, but we discussed this in person and think it is best to get this merged and then work on the underlying functionality that would let the parameter list be presented as a first-class object. That requires having non-tensor, non-constant model attributes accessible from TorchScript.

facebook-github-bot (Contributor) left a comment

@driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

mcarilli (Collaborator) commented Feb 24, 2019

Edit: Disregard all this, we figured out a workaround on our end.

@driazati @zdevito As far as I can tell, you removed 'LSTM' from _rnn_impls https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py#L15-L19 because that layer of indirection isn't necessary anymore for LSTM (now that LSTM has its own full-fledged forward and doesn't need to lean on the forward inherited from RNNBase).

Unfortunately for us, the _rnn_impls dict is what Amp interposes on to insert its arg-casting wrapper (we can't patch _VF.* because that comes directly from C++; see the sketch after this comment). This means that Amp is broken for LSTMs on current master. Is it ok if LSTM continues to use the layer of indirection, i.e., restore

_rnn_impls = {
    'LSTM': _VF.lstm,
    'GRU': _VF.gru,
    'RNN_TANH': _VF.rnn_tanh,
    'RNN_RELU': _VF.rnn_relu,
}

and change the _VF.lstm() calls to _rnn_impls['LSTM']()? I realize this is not necessary for you and is purely in support of our own use case, but it is helpful to us and doesn't seem to do any harm. If you don't object, I can PR it.
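For reference, a rough sketch of the kind of interposition described above; this is purely illustrative and not Apex's actual wrapper code:

import torch.nn.modules.rnn as rnn_module

def _wrap_with_casts(fn):
    def wrapper(*args, **kwargs):
        # Apex would cast floating-point tensor arguments to the desired
        # precision here before dispatching to the underlying _VF op.
        return fn(*args, **kwargs)
    return wrapper

# Patching the dict entry is possible from Python; a direct _VF.lstm call is not.
for name, impl in list(rnn_module._rnn_impls.items()):
    rnn_module._rnn_impls[name] = _wrap_with_casts(impl)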

mcarilli pushed a commit to NVIDIA/apex that referenced this pull request Feb 25, 2019
driazati (Contributor, Author) commented

@mcarilli The change was due to some technical limitations in the JIT, since it uses the same module code. It's good that you found a fix, since we likely wouldn't be able to address the limitation for a few weeks.
