[jit] Add LSTM to standard library by driazati · Pull Request #14831 · pytorch/pytorch · GitHub

Conversation

@driazati
Contributor

@driazati driazati commented Dec 6, 2018

[WIP]

Adds support for torch.nn.LSTM in Script

  • De-sugaring for self._parameters.values() to a Tensor[]
    • add aten::List for int, float, and Tensor
  • LSTM accepts either a PackedSequence (which is represented as a Tuple[Tensor, Optional[Tensor]]) or a Tensor for its input. While the Python version is unaffected, the TorchScript LSTM only supports the Tuple[Tensor, Optional[Tensor]] form for its input (see the sketch after this list).
    • aten::_wrap_tuple and aten::_unwrap_tuple are used to provide the correct types to the Script compiler, but cannot actually be run
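
As a rough sketch of the calling convention described above (the shapes and lengths below are illustrative, and a plain (data, batch_sizes) tuple stands in for PackedSequence on the TorchScript side):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

lstm = torch.nn.LSTM(input_size=8, hidden_size=16)

# Eager Python accepts either a padded Tensor or a PackedSequence ...
x = torch.randn(5, 3, 8)                        # (seq_len, batch, input_size)
packed = pack_padded_sequence(x, lengths=[5, 4, 2])
out, (h_n, c_n) = lstm(packed)

# ... while the TorchScript LSTM described here only takes the tuple form,
# i.e. (data, batch_sizes) as a Tuple[Tensor, Optional[Tensor]].
tuple_input = (packed.data, packed.batch_sizes)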

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Dec 6, 2018
Contributor


What is the reason to use FunctionSchema instead of a string?

Contributor


What in the LSTM cell is the reason for this?

// self._parameters.values()
std::shared_ptr<SugaredValue> call(SourceRange loc, Method& caller, at::ArrayRef<NamedValue> inputs, at::ArrayRef<NamedValue> attributes, size_t n_binders) override {
  std::vector<Value*> params;
  const auto& param_list = module_->get_parameters();
Collaborator


You might also need to check whether the parameter is a buffer or not.
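
For context, a quick eager-mode illustration of the parameter/buffer distinction this comment refers to (the module and attribute names are made up for the example):

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # Registered as a Parameter: shows up in self._parameters
        self.weight = nn.Parameter(torch.randn(3))
        # Registered as a buffer: shows up in self._buffers, not self._parameters
        self.register_buffer('running_mean', torch.zeros(3))

m = M()
print(list(m._parameters.keys()))  # ['weight']
print(list(m._buffers.keys()))     # ['running_mean']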

return x


def _unwrap_tuple(x):
Collaborator


Hmm, do we need to think more about this? I think we ultimately want to remove these unwrap functions if possible; adding more and more might make them hard to revert in the future. At the least we should think about how to provide an API (or something) so users don't need to care about this.

Contributor Author


If we want to support it at the language level instead of having these hacks, I would imagine it would look something like:

@torch.jit.script
def fn(x):
    # type: (Union[int, float]) -> Union[int, float]
    if isinstance(x, int):
        return 3
    else:
        return 3.5

And, similarly to #14533, we would only emit the branches for the types seen at compile time.

Collaborator


Hmm, it's hard to control those metaprogramming conditions since the value might only be known at runtime. It might be OK for us, as a temporary hack, to add wrap/unwrap tuples to support the LSTM module; we just need to be careful not to stack too much on top of them.

}),
Operator(
// "aten::_get_packed_sequence(Tensor a, Tensor b) -> (Tensor, Tensor?)", // TODO: using this causes a segfault
FunctionSchema(
Collaborator


I have the same doubt: why are you using FunctionSchema instead of a string?

Contributor


He said in person: there's no support for Tuples yet when registering operators.

Contributor

@facebook-github-bot facebook-github-bot left a comment


@driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor

@zdevito zdevito left a comment


There needs to be more explanation of what this is trying to accomplish. There seems to be major copy-pasting and adding of specific features to get this to work. There is likely a better way.

};
}),
Operator(
// "aten::_get_packed_sequence(Tensor a, Tensor b) -> (Tensor, Tensor?)", // TODO: using this causes a segfault
Contributor


You should fix the segfault :)

}
};


Contributor


I doubt it is a good idea to add more SugaredValue types simply to support this one module. This needs a more thorough explanation of what is going on, and why these changes are necessary.

def __init__(self, *args, **kwargs):
    super(LSTM, self).__init__('LSTM', *args, **kwargs)

@weak_script_method
Contributor


This seems like a massive copy-paste from somewhere. What is going on?

Contributor

@eellison eellison left a comment


I understand that the issue is that the function can be called in Python with either a Tensor or a PackedSequence. TorchScript will only be able to handle the PackedSequence case.

I think another approach would be:

@torch.jit.script
def fn(x):
    # type: (Tuple[Tensor, Optional[Tensor]])
    if isinstance(x, torch.Tensor):
        ...  # x is typed as Tensor here, and the if branch will be constant prop'd away
    else:
        ...  # x is a tuple here

This would allow the TorchScript code to compile without changing the Python code.

}

template <typename T>
Operation listList(const Node* node) {
Contributor


There is already a Noop defined on line 37


#define CREATE_LIST_OPS(decl_type, c_type)                                                  \
  Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>),                 \
  Operator("aten::list(" decl_type "[] a) -> " decl_type "[]", listList<Shared<c_type>>),   \
Contributor


I'm not sure we want to support list yet... If we do add it, you need to add aliasing information


@zou3519
Contributor

zou3519 commented Dec 18, 2018

Btw this will conflict with #15225

@driazati driazati closed this Jan 3, 2019
facebook-github-bot pushed a commit that referenced this pull request Feb 22, 2019
Summary:
**WIP**

Attempt 2 at #14831

This adds `nn.LSTM` to the jit standard library. Necessary changes to the module itself are detailed in comments. The main limitation is the lack of a true `PackedSequence`; instead, this PR uses an ordinary `tuple` to stand in for `PackedSequence`.

Most of the new code in `rnn.py` is copied from `nn.RNNBase` into `nn.LSTM` to specialize it for LSTM, since for LSTM `hx` is a `Tuple[Tensor, Tensor]` rather than just a `Tensor` as in the other RNN modules.
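
For reference, a short eager-mode sketch of the `hx` difference described above (sizes are illustrative):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=1)
rnn = nn.RNN(input_size=4, hidden_size=8, num_layers=1)

x = torch.randn(5, 2, 4)             # (seq_len, batch, input_size)
h0 = torch.zeros(1, 2, 8)            # (num_layers, batch, hidden_size)
c0 = torch.zeros(1, 2, 8)

out, (h_n, c_n) = lstm(x, (h0, c0))  # LSTM: hx is a Tuple[Tensor, Tensor]
out2, h_n2 = rnn(x, h0)              # other RNN modules: hx is a single Tensor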

As a hack it adds an internal annotation `@_parameter_list` to mark that a function returns all the parameters of a module. The weights for `RNN` modules are passed to the corresponding op as a `List[Tensor]`. In Python this list has to be gathered dynamically, since Parameters could be moved from CPU to GPU or be deleted and replaced (e.g. if someone calls `weight_norm` on their module, #15766), but in the JIT parameter lists are immutable, hence a builtin that handles this differently in Python and in the JIT.
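
As a rough eager-mode illustration of what "gathered dynamically" means here (the helper below is made up for this sketch, not the actual internal API):

import torch
import torch.nn as nn
from typing import List

def gather_weights(rnn):
    # type: (nn.Module) -> List[torch.Tensor]
    # Re-read the parameters on every call so that changes such as .cuda()/.cpu()
    # moves or weight_norm re-registering parameters are picked up; a TorchScript
    # parameter list, by contrast, is fixed at compile time.
    return [p for p in rnn.parameters()]

lstm = nn.LSTM(input_size=4, hidden_size=8)
weights = gather_weights(lstm)  # the List[Tensor] handed to the RNN op in eager mode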
Pull Request resolved: #15744

Differential Revision: D14173198

Pulled By: driazati

fbshipit-source-id: 4ee8113159b3a8f29a9f56fe661cfbb6b30dffcd
