KEMBAR78
[C++ Frontend] Make call operator on module holder call forward by goldsborough · Pull Request #15831 · pytorch/pytorch · GitHub
Skip to content

Conversation

@goldsborough
Copy link
Contributor

In Python, you can use the call operator to invoke the forward() method of a module. In C++ this was currently not possible, because I couldn't figure out how to deduce the return type of a module's forward() method under the constraint that forward() may not exist at all (since the base module class in C++ does not mandate a forward() method). I now figured it out, so the call operator can be used.

@ezyang @ebetica

@goldsborough goldsborough requested a review from ebetica as a code owner January 8, 2019 18:10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@smessmer would you mind reviewing these templates?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

has_forward has a quite complex implementation that maybe can be simplified with void_t (it's a c++17 feature but we have c10::guts::void_t for C++11), but that's actually not in this diff but already landed. The templates added here look good.

@goldsborough goldsborough force-pushed the cpp-forward-call-operator branch from e9cc93b to 50bf528 Compare January 9, 2019 17:07
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

has_forward has a quite complex implementation that maybe can be simplified with void_t (it's a c++17 feature but we have c10::guts::void_t for C++11), but that's actually not in this diff but already landed. The templates added here look good.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if this handles functions with non-const lvalue references correctly since with declval, you always get rvalue references (you can convert rvalue references to const lvalue references, but not to non-const ones).

A better approach might be to use std::result_of, there I'd have a higher confidence that they implemented it correctly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, please add some static_assert test cases for this functionality. And make sure you test different kinds of references.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@goldsborough goldsborough force-pushed the cpp-forward-call-operator branch from 50bf528 to 1cb8217 Compare January 11, 2019 16:36
@goldsborough
Copy link
Contributor Author

I've added tests. @ezyang would appreciate a stamp as it's blocking the C++ frontend tutorial

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@goldsborough is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@soumith soumith added this to the 1.0.1 milestone Jan 15, 2019
@soumith soumith added the cherry-picked This PR was cherry-picked onto a release branch from master label Jan 18, 2019
soumith pushed a commit that referenced this pull request Jan 18, 2019
Summary:
In Python, you can use the call operator to invoke the `forward()` method of a module. In C++ this was currently not possible, because I couldn't figure out how to deduce the return type of a module's `forward()` method under the constraint that `forward()` may not exist at all (since the base module class in C++ does not mandate a `forward()` method). I now figured it out, so the call operator can be used.

ezyang ebetica
Pull Request resolved: #15831

Differential Revision: D13652676

Pulled By: goldsborough

fbshipit-source-id: ccab45a15215dda56460e560f0038781b539135f
soumith pushed a commit that referenced this pull request Jan 29, 2019
Summary:
In Python, you can use the call operator to invoke the `forward()` method of a module. In C++ this was currently not possible, because I couldn't figure out how to deduce the return type of a module's `forward()` method under the constraint that `forward()` may not exist at all (since the base module class in C++ does not mandate a `forward()` method). I now figured it out, so the call operator can be used.

ezyang ebetica
Pull Request resolved: #15831

Differential Revision: D13652676

Pulled By: goldsborough

fbshipit-source-id: ccab45a15215dda56460e560f0038781b539135f
@ezyang ezyang added the merged label Jun 25, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cherry-picked This PR was cherry-picked onto a release branch from master

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants