[jit] Add LSTM to standard library #14831
Conversation
What is the reason to use FunctionSchema instead of a string?
torch/csrc/jit/script/compiler.cpp (Outdated)
What in the LSTM cell is the reason for this?
torch/csrc/jit/script/init.cpp (Outdated)
```cpp
// self._parameters.values()
std::shared_ptr<SugaredValue> call(SourceRange loc, Method & caller, at::ArrayRef<NamedValue> inputs, at::ArrayRef<NamedValue> attributes, size_t n_binders) override {
  std::vector<Value*> params;
  const auto& param_list = module_->get_parameters();
```
You might also need to check whether the parameter is a buffer or not.
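For context on the parameter/buffer distinction being raised here, the sketch below uses only the public `nn.Module` API (not this PR's C++ code) to show that parameters and buffers are tracked separately, so a walk over a module's parameters needs an explicit check to either skip or deliberately include buffers.

```python
import torch.nn as nn

m = nn.BatchNorm1d(4)
print([name for name, _ in m.named_parameters()])  # ['weight', 'bias']
print([name for name, _ in m.named_buffers()])     # ['running_mean', 'running_var', 'num_batches_tracked']
```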
```python
    return x


def _unwrap_tuple(x):
```
Hmm, do we need to think more about this? I think we ultimately want to remove these unwrap functions if possible; adding more and more might make them hard to revert in the future. At the very least we should think about how to provide users an API so that they don't need to care about this.
If we want to support it at the language level instead of having these hacks, I would imagine it would look something like:

```python
@torch.jit.script
def fn(x):
    # type: (Union[int, float]) -> Union[int, float]
    if isinstance(x, int):
        return 3
    else:
        return 3.5
```

Similarly to #14533, we would only emit the branches for the types seen at compile time.
Hmmm, it's hard to control those metaprogramming conditions since the value might only be known at runtime. It might be OK for us, as a temporary hack, to add unwrap/wrap tuples to support the LSTM modules; I guess we just need to be careful not to stack too much on top of them.
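For readers following along, here is a rough Python sketch of what the wrap/unwrap helpers discussed above amount to. The bodies are guesses based on the names; the PR notes that the actual `aten::_wrap_tuple` / `aten::_unwrap_tuple` ops exist only to give the script compiler the right types and cannot actually be run.

```python
def _wrap_tuple(x):
    # present a plain value to the script compiler as a tuple (guessed semantics)
    if isinstance(x, tuple):
        return x
    return (x,)


def _unwrap_tuple(x):
    # collapse a one-element tuple back to the underlying value (guessed semantics)
    if isinstance(x, tuple) and len(x) == 1:
        return x[0]
    return x
```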
```cpp
}),
Operator(
    // "aten::_get_packed_sequence(Tensor a, Tensor b) -> (Tensor, Tensor?)", // TODO: using this causes a segfault
    FunctionSchema(
```
I have the same doubt: why are you using FunctionSchema instead of a string?
He said in person: no support for Tuples yet in registering operators.
@driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
There needs to be more explanation of what this is trying to accomplish. There seems to be major copy-pasting and adding of specific features to get this to work. There is likely a better way.
```cpp
  };
}),
Operator(
    // "aten::_get_packed_sequence(Tensor a, Tensor b) -> (Tensor, Tensor?)", // TODO: using this causes a segfault
```
You should fix the segfault :)
```cpp
}
};
```
I doubt it is a good idea to add more SugaredValue types simply to support this one module. This needs a more thorough explanation of what is going on, and why these changes are necessary.
```python
def __init__(self, *args, **kwargs):
    super(LSTM, self).__init__('LSTM', *args, **kwargs)

@weak_script_method
```
This seems like a massive copy-paste from somewhere; what is going on?
I understand that the issue is that the function can be called in Python with either a Tensor or a PackedSequence, and TorchScript will only be able to handle the PackedSequence case.
I think another approach would be:

```python
@torch.jit.script
def fn(x):
    # type: (Tuple[Tensor, Optional[Tensor]])
    if isinstance(x, torch.Tensor):
        ...  # x is typed as Tensor here, and the if branch will be constant prop'd away
    else:
        ...  # x is a tuple here
```

This would allow the TorchScript code to compile without changing the Python code.
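A minimal, runnable sketch of the eager-mode dispatch being described above (not this PR's actual code; `forward_tensor` and `forward_packed` are hypothetical helper names):

```python
import torch

def forward_tensor(x):
    # plain Tensor input path
    return x

def forward_packed(x):
    # PackedSequence stand-in: a (data, batch_sizes) tuple; batch_sizes may be None
    data, batch_sizes = x
    return data

def forward(x):
    if isinstance(x, torch.Tensor):
        return forward_tensor(x)
    else:
        return forward_packed(x)
```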
```cpp
}

template <typename T>
Operation listList(const Node* node) {
```
There is already a Noop defined on line 37
```cpp
#define CREATE_LIST_OPS(decl_type, c_type) \
  Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
  Operator("aten::list(" decl_type "[] a) -> " decl_type "[]", listList<Shared<c_type>>), \
```
I'm not sure we want to support list yet... If we do add it, you need to add aliasing information
```cpp
}),
Operator(
    // "aten::_get_packed_sequence(Tensor a, Tensor b) -> (Tensor, Tensor?)", // TODO: using this causes a segfault
    FunctionSchema(
```
He said in person: no support for Tuples yet in registering operators.
Btw this will conflict with #15225
Summary: **WIP** Attempt 2 at #14831

This adds `nn.LSTM` to the JIT standard library. Necessary changes to the module itself are detailed in comments. The main limitation is the lack of a true `PackedSequence`; instead this PR uses an ordinary `tuple` to stand in for `PackedSequence`. Most of the new code in `rnn.py` is copied to `nn.LSTM` from `nn.RNNBase` to specialize it for LSTM, since for LSTM `hx` is a `Tuple[Tensor, Tensor]` (rather than just a `Tensor` as in the other RNN modules).

As a hack it adds an internal annotation `@_parameter_list` to mark that a function returns all the parameters of a module. The weights for `RNN` modules are passed to the corresponding op as a `List[Tensor]`. In Python this list has to be gathered dynamically, since Parameters could be moved from CPU to GPU or be deleted and replaced (e.g. if someone calls `weight_norm` on their module, #15766), but in the JIT parameter lists are immutable, hence a builtin to handle this differently in Python and the JIT.

Pull Request resolved: #15744
Differential Revision: D14173198
Pulled By: driazati
fbshipit-source-id: 4ee8113159b3a8f29a9f56fe661cfbb6b30dffcd
[WIP]

Adds support for:

- `torch.nn.LSTM` in Script
- `self._parameters.values()` to a `Tensor[]` (see the sketch after this list)
- `aten::List` for `int`, `float`, and `Tensor`
- `LSTM` accepts both `PackedSequence` (which is a `Tuple[Tensor, Optional[Tensor]]`) or `Tensor` for its input. While the Python version is unaffected, the TorchScript `LSTM` only supports `Tuple[Tensor, Optional[Tensor]]` for its input.
- `aten::_wrap_tuple` and `aten::_unwrap_tuple` are used to provide the correct types to the Script compiler, but cannot actually be run
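To make the `@_parameter_list` motivation from the summary concrete, here is a minimal sketch (illustrative names, not this PR's code) of why the weights have to be gathered dynamically in eager mode: the list is re-read from `self._parameters` on every call, so parameters that have been moved or replaced (e.g. by `weight_norm`) are picked up, whereas in the JIT the list is treated as immutable.

```python
import torch
import torch.nn as nn


class TinyRNN(nn.Module):
    def __init__(self):
        super(TinyRNN, self).__init__()
        self.weight_ih = nn.Parameter(torch.randn(4, 4))
        self.weight_hh = nn.Parameter(torch.randn(4, 4))

    def all_weights(self):
        # gathered fresh on every call in Python; baked in under the JIT
        return list(self._parameters.values())


m = TinyRNN()
flat_weights = m.all_weights()  # the List[Tensor] handed to the RNN op
```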
torch.nn.LSTMin Scriptself._parameters.values()to aTensor[]aten::Listforint,float, andTensorLSTMaccepts bothPackedSequence(which is aTuple[Tensor, Optional[Tensor]]) orTensorfor[WI{{ its input. While the Python version is unaffected, the TorchScriptLSTMonly supportsTuple[Tensor, Optional[Tensor]]for its input.aten::_wrap_tupleandaten::_unwrap_tupleare used to provide the correct types to the Script compiler, but cannot actually be run