-
Notifications
You must be signed in to change notification settings - Fork 25.7k
use all_weights instead of _parameters in _flat_weights in rnn #15766
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
failing tests |
|
Build and asan failures seem unrelated (cannot install moreutils, temporary failure in name resolution), cuda9_cudnn7_py2_test looks real, I'll take a look next week. |
|
Remaining MacOS failures are unrelated. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well it really looks like a hot patch for a much deeper issue, which is that weight norm and similar don't play well with our module design... Not being able to reliably access parameters in any way other than by name is just bad.
Summary: **WIP** Attempt 2 at #14831 This adds `nn.LSTM` to the jit standard library. Necessary changes to the module itself are detailed in comments. The main limitation is the lack of a true `PackedSequence`, instead this PR uses an ordinary `tuple` to stand in for `PackedSequence`. Most of the new code in `rnn.py` is copied to `nn.LSTM` from `nn.RNNBase` to specialize it for LSTM since `hx` is a `Tuple[Tensor, Tensor]` (rather than just a `Tensor` as in the other RNN modules) for LSTM. As a hack it adds an internal annotation `@_parameter_list` to mark that a function returns all the parameters of a module. The weights for `RNN` modules are passed to the corresponding op as a `List[Tensor]`. In Python this has to be gathered dynamically since Parameters could be moved from CPU to GPU or be deleted and replaced (i.e. if someone calls `weight_norm` on their module, #15766), but in the JIT parameter lists are immutable, hence a builtin to handle this differently in Python/JIT. Pull Request resolved: #15744 Differential Revision: D14173198 Pulled By: driazati fbshipit-source-id: 4ee8113159b3a8f29a9f56fe661cfbb6b30dffcd
Fixes #15749