[jit] Add LSTM to standard library #15744
Conversation
Looks good -- the hack is a bit messy, but we discussed this in person and think it is best to get this merged and then work on the underlying functionality that would enable the parameter list to be presented as a first-class object. That requires having non-tensor, non-constant model attributes accessible from TorchScript.
@driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Edit: Disregard all this, we figured out a workaround on our end.

@driazati @zdevito As far as I can tell, you removed 'LSTM' from […]. Unfortunately for us, […], and change the `_VF.lstm()` calls to […].
@mcarilli The change was due to some technical limitations in the JIT, since it uses the same module code. Good that you found a fix, since we likely wouldn't be able to address the limitation for a few weeks.
WIP
Attempt 2 at #14831
This adds `nn.LSTM` to the JIT standard library. Necessary changes to the module itself are detailed in comments. The main limitation is the lack of a true `PackedSequence`; instead, this PR uses an ordinary `tuple` to stand in for `PackedSequence`.
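For context (this is an illustration, not code from the diff): `PackedSequence` is essentially a namedtuple of `data` and `batch_sizes`, which is why a plain tuple of tensors can mimic it where the real class is not scriptable:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

x = torch.randn(5, 3, 10)           # (seq_len, batch, input_size)
lengths = torch.tensor([5, 3, 2])   # must be sorted in decreasing order

packed = pack_padded_sequence(x, lengths)

# A PackedSequence carries (data, batch_sizes) under the hood, so an
# ordinary Tuple[Tensor, Tensor] can stand in for it in TorchScript.
as_tuple = (packed.data, packed.batch_sizes)
```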
Most of the new code in `rnn.py` is copied to `nn.LSTM` from `nn.RNNBase` to specialize it for LSTM, since for LSTM `hx` is a `Tuple[Tensor, Tensor]` rather than just a `Tensor` as in the other RNN modules.
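A minimal reminder of the `hx` difference in the existing eager API (standard PyTorch usage, not part of the diff):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
gru = nn.GRU(input_size=10, hidden_size=20, num_layers=2)

x = torch.randn(5, 3, 10)   # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)  # (num_layers, batch, hidden_size)
c0 = torch.randn(2, 3, 20)

# LSTM: hx is a Tuple[Tensor, Tensor] of (hidden state, cell state)
out, (hn, cn) = lstm(x, (h0, c0))

# GRU (and RNN): hx is a single Tensor
out, hn = gru(x, h0)
```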
As a hack, it adds an internal annotation `@_parameter_list` to mark that a function returns all the parameters of a module. The weights for `RNN` modules are passed to the corresponding op as a `List[Tensor]`. In Python this list has to be gathered dynamically, since `Parameter`s could be moved from CPU to GPU or be deleted and replaced (e.g. if someone calls `weight_norm` on their module, #15766), but in the JIT parameter lists are immutable, hence a builtin that handles this differently in Python and in the JIT.
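The `weight_norm` case (#15766) shows why a statically captured parameter list goes stale; this is only a sketch of the failure mode, not the PR's implementation:

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

rnn = nn.RNN(10, 20)

# Capturing the parameter list once at construction time...
stale = list(rnn.parameters())

# ...goes stale once weight_norm deletes weight_ih_l0 and registers
# weight_ih_l0_g / weight_ih_l0_v in its place.
weight_norm(rnn, name="weight_ih_l0")

fresh = list(rnn.parameters())
assert len(fresh) != len(stale)  # the old list no longer matches the module
```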