[C++ API] Add named submodule support to nn::Sequential by yf225 · Pull Request #17552 · pytorch/pytorch · GitHub

@yf225 yf225 commented Feb 27, 2019

Previously, we were not able to assign names to nn::Sequential's submodules. This PR adds this feature to match the Python API. Example use:

Sequential sequential(named_submodule({
      {"linear", Linear(10, 3)},
      {"conv2d", Conv2d(1, 2, 3)},
      {"dropout", Dropout(0.5)},
      {"batchnorm", BatchNorm(5)},
      {"embedding", Embedding(4, 10)},
      {"lstm", LSTM(4, 5)}
}));

It also enables loading the parameters of a Python nn.Sequential module with custom submodule names into the C++ frontend, unblocking pytorch/vision#728 (comment).
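A minimal sketch of why submodule names matter when loading parameters, assuming the usual dotted key convention in which a parameter is addressed as `<submodule name>.<parameter name>` and an unnamed submodule is addressed by its index. The helpers `parameter_key` and `default_name` are hypothetical and not part of the PyTorch API.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical helper: joins a submodule name and a parameter name
// into a dotted key such as "linear.weight".
inline std::string parameter_key(const std::string& submodule,
                                 const std::string& parameter) {
  return submodule + "." + parameter;
}

// Hypothetical helper: an unnamed submodule is addressed by its position.
inline std::string default_name(std::size_t index) {
  return std::to_string(index);
}
```

Under this convention a checkpoint saved with the key "linear.weight" would not line up with a container that only knows the key "0.weight", which is the mismatch that named submodules are meant to remove.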

@facebook-github-bot facebook-github-bot left a comment

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@yf225 yf225 requested review from ezyang and gchanan February 27, 2019 23:56

/// Adds a new named module to the `Sequential` container, with name of `std::string` type.
template <typename M>
void push_back(std::string name, M module) {

I don't understand why the public string parameter versions of push_back don't match the public non-string versions. I would expect you would just have string overloads for the 3 existing ones?


/// Adds a new named module to the `Sequential` container, with name of `const char*` type.
template <typename M>
void push_back(const char* name, M module) {

why do you have both string and const char* versions of push_back?


Sequential sequential(
std::make_shared<M>(1),
"m2", std::make_shared<M>(2),

this behavior is kind of crazy and doesn't match python.

template <typename M>
void push_back(std::string name, M module) {
auto index = add_to_modules(module);
register_module(name, modules_[index].ptr());

nit: can std::move(name) here.


/// Adds a new named module to the `Sequential` container, with name of `const char*` type.
template <typename M>
void push_back(const char* name, M module) {

I believe this does an extra copy of the module unnecessarily, and doesn't really follow the convention set above. I'm not sure how expensive the copy constructor is for Module.

It may be better to follow the construct given above, having a separate function for each of the possible types, i.e.

  • std::shared_ptr
  • const ModuleHolder&
  • M&& with M being a module

Alternatively if you think the construct doesn't matter and the extra copy is whatever, then we can simplify the above code, but I think it's important to be consistent.
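The copy concern can be sketched as follows, using a hypothetical `Tracked` type that counts copy constructions in place of a real module. `store_by_value` and `store_forwarded` are illustrative names, not functions from this PR.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for a module; counts copy constructions only.
struct Tracked {
  static int copies;
  int value;
  explicit Tracked(int v) : value(v) {}
  Tracked(const Tracked& other) : value(other.value) { ++copies; }
  Tracked(Tracked&& other) noexcept : value(other.value) {}
};
int Tracked::copies = 0;

// By-value parameter: an lvalue argument is copied into `module`,
// then copied again into the shared_ptr.
template <typename M>
std::shared_ptr<M> store_by_value(M module) {
  return std::make_shared<M>(module);
}

// Forwarding reference: an lvalue is copied once, an rvalue is moved.
template <typename M>
std::shared_ptr<typename std::decay<M>::type> store_forwarded(M&& module) {
  using Type = typename std::decay<M>::type;
  return std::make_shared<Type>(std::forward<M>(module));
}

// Number of copy constructions when an lvalue is stored by value.
inline int copies_by_value() {
  Tracked m(1);
  Tracked::copies = 0;
  store_by_value(m);
  return Tracked::copies;
}

// Number of copy constructions when an lvalue is forwarded.
inline int copies_forwarded_lvalue() {
  Tracked m(1);
  Tracked::copies = 0;
  store_forwarded(m);
  return Tracked::copies;
}

// Number of copy constructions when an rvalue is forwarded.
inline int copies_forwarded_rvalue() {
  Tracked m(1);
  Tracked::copies = 0;
  store_forwarded(std::move(m));
  return Tracked::copies;
}
```

In this sketch the by-value path costs one more copy than the forwarding path for lvalue arguments. How expensive that is for a real module depends on its copy constructor, which the comment above leaves open.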

template <typename M, typename = torch::detail::enable_if_module_t<M>>
size_t add_to_modules(M&& module) {
// Need to get rid of any reference components for make_unique.
using Type = typename std::remove_reference<M>::type;

I also don't quite get the point of add_to_modules (are you always properly moving/forwarding parameters around)? Why don't the "final" non-string versions of push_back just call the string versions?

yf225 commented Feb 28, 2019

Note to self: make sure we are not making unnecessary copies (by adding tests for it)

@goldsborough goldsborough left a comment

I think @gchanan captures my opinion here in that the changes to Sequential are more complicated than they need to be. Essentially you should just need one overload for every existing push_back that has a string as the first parameter. The existing methods can just call into the new methods, using modules_.size() (the next index) as the name. I would do it like this:

  1. Starting with the previous code, turn every push_back into a named version. All the logic should be in the named functions.
  2. Thread the name through to the final insertion into the map
  3. For every named version, add an unnamed overload (the previous signatures), but all they do is call push_back(std::to_string(modules_.size()), <module>).

Then you effectively just have to add three methods and that's it
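The three steps above can be sketched with a toy container. `ToyModule` and `ToySequential` are hypothetical stand-ins and only the `std::shared_ptr` overload is shown; the real `Sequential` has more overloads and registers modules differently.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a module.
struct ToyModule {
  int id;
  explicit ToyModule(int i) : id(i) {}
};

class ToySequential {
 public:
  // Named version: all insertion logic lives here, and the name is
  // moved into its final storage.
  void push_back(std::string name, std::shared_ptr<ToyModule> module) {
    entries_.emplace_back(std::move(name), std::move(module));
  }

  // Unnamed version: only delegates, using the next index as the name.
  void push_back(std::shared_ptr<ToyModule> module) {
    push_back(std::to_string(entries_.size()), std::move(module));
  }

  const std::string& name_at(std::size_t index) const {
    return entries_[index].first;
  }

  std::size_t size() const { return entries_.size(); }

 private:
  std::vector<std::pair<std::string, std::shared_ptr<ToyModule>>> entries_;
};
```

Because the unnamed overload derives its name from the current size, mixing named and unnamed insertions gives unnamed entries their positional index, for example "0", "fc", "2".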


/// Adds a new named module to the `Sequential` container, with name of `const char*` type.
template <typename M>
void push_back(const char* name, M module) {

I don't think we need const char* in the interface, since we eventually store std::string. It's ok to just have std::string as the argument type and move that properly into the eventual storage


/// Matches `Sequential("m1", Module(1), ...)` case
template <typename Module, typename... Rest>
void push_back(const char* name, Module&& module, Rest&&... rest) {

You can remove this overload and the whole const char* business, it's not worth it

gchanan commented Mar 6, 2019

@goldsborough a few opinion questions for you:

  1. what do you think about mixing named and "unnamed" parameters. Should we allow it or not?

  2. Having a constructor that takes an OrderedDict of modules would match the python interface. But it seems like constructing an OrderedDict of modules by hand is annoying because of the different module types. Should we have a helper that does it? I guess you'd basically have to reproduce the module bits of Sequential here to do it though, i.e. three overloads: shared_ptr<ModuleType>, ModuleHolder<M>, AnyModule.

  3. Is there documentation that explains those three overloads anywhere? I think I found some documentation around why there isn't a single module type (which makes sense), but I couldn't find docs about why those 3 are the "magic 3".

yf225 commented Mar 7, 2019

Another issue I found is that if I remove the copy constructor (by adding M(const M&) = delete;) for the concrete module type test: https://github.com/pytorch/pytorch/blob/master/test/cpp/api/sequential.cpp#L38-L44, the test will fail to compile, because

template <typename M, typename = torch::detail::enable_if_module_t<M>>
void push_back(M&& module) {
    // Need to get rid of any reference components for make_unique.
    using Type = typename std::remove_reference<M>::type;
    // Here we move (or copy) the module into a new shared_ptr.
    push_back(std::make_shared<Type>(std::forward<M>(module))); // NOTE: This line copies the module
}

(https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/modules/sequential.h#L202) actually copies the module and expects the copy constructor to exist. Ideally this copy should be avoided, and I am investigating how to achieve that.
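One language rule is relevant here: a class with a user-declared copy constructor, including a deleted one, does not get an implicitly declared move constructor, so constructing it from an rvalue falls back to the deleted copy constructor. The sketch below shows that rule with hypothetical types `CopyDeleted` and `MoveOnly` and a hypothetical `wrap` helper shaped like the `push_back` quoted above. Whether this fully explains the failure in the real test, whose module type also has a `torch::nn::Module` base, is an assumption.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <utility>

// Deleting the copy constructor suppresses the implicit move constructor.
struct CopyDeleted {
  int value;
  explicit CopyDeleted(int v) : value(v) {}
  CopyDeleted(const CopyDeleted&) = delete;
};

// Declaring the move constructor explicitly makes the type movable again.
struct MoveOnly {
  int value;
  explicit MoveOnly(int v) : value(v) {}
  MoveOnly(const MoveOnly&) = delete;
  MoveOnly(MoveOnly&&) = default;
};

static_assert(!std::is_move_constructible<CopyDeleted>::value,
              "rvalue construction resolves to the deleted copy constructor");
static_assert(std::is_move_constructible<MoveOnly>::value,
              "an explicitly defaulted move constructor is usable");

// Hypothetical helper with the same shape as the quoted push_back:
// strips the reference and forwards the argument into make_shared.
template <typename M>
std::shared_ptr<typename std::remove_reference<M>::type> wrap(M&& module) {
  using Type = typename std::remove_reference<M>::type;
  return std::make_shared<Type>(std::forward<M>(module));
}
```

With these definitions `wrap(MoveOnly(7))` compiles and moves its argument, while `wrap(CopyDeleted(7))` would not compile.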

@yf225 yf225 force-pushed the cpp_sequential_named_submodules_split branch from b11b92a to 0038784 Compare March 8, 2019 22:29
@yf225 yf225 mentioned this pull request Mar 11, 2019
@yf225 yf225 force-pushed the cpp_sequential_named_submodules_split branch from 0038784 to 64c4092 Compare March 11, 2019 19:59
yf225 commented Mar 11, 2019

I figured out how to have a simple API for making OrderedDict work, and updated the PR to reflect the new approach.

With the new named_submodules() API, we will be creating an OrderedDict of named submodules much like how we do it with the Python API, which is much better for API parity.

M(const M& other) : torch::nn::Module(other) {
// NOTE: The current implementation expects the module to be copied once
// when it's passed into `std::make_shared<T>()`.
// TODO: Find a way to avoid copying, and then delete the copy constructor.
@yf225 yf225 Mar 11, 2019

I filed an issue to document this problem: #17879.

@yf225 yf225 force-pushed the cpp_sequential_named_submodules_split branch from 1bb5884 to da55865 Compare March 25, 2019 17:05
@yf225 yf225 force-pushed the cpp_sequential_named_submodules_split branch from 3c9843e to 2924d2d Compare March 25, 2019 20:30
@goldsborough goldsborough left a comment

I left some comments

// `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`,
// if we use the second signature, at the template argument deduction step
// the compiler is not able to deduce the type of `ModuleType` to the type of
// `M(1)` or `M(2)`, since the compiler doesn't actually look into the

Did you try whether std::pair<std::string, AnyModule> works?

Contributor Author

I tried to change the modules_ordered_dict(...) function to torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(std::initializer_list<std::pair<std::string, AnyModule>> named_modules), but it does not compile:

../test/cpp/api/sequential.cpp: In member function ‘virtual void SequentialTest_ConstructsFromConcreteType_Test::TestBody()’:
../test/cpp/api/sequential.cpp:72:4: error: could not convert ‘{{"m1", SequentialTest_ConstructsFromConcreteType_Test::TestBody()::M(1)}, {std::__cxx11::basic_string<char>(((const char*)"m2"), std::allocator<char>()), SequentialTest_ConstructsFromConcreteType_Test::TestBody()::M(2)}, {"m3", SequentialTest_ConstructsFromConcreteType_Test::TestBody()::M(3)}}’ from ‘<brace-enclosed initializer list>’ to ‘std::initializer_list<std::pair<std::__cxx11::basic_string<char>, torch::nn::AnyModule> >’

The compiler is not able to match std::initializer_list<std::pair<std::string, AnyModule>> to the nested braced-init list {{"m1", M(1)}, {std::string("m2"), M(2)}, {"m3", M(3)}}. So I think the NamedAnyModule approach here is necessary.
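A minimal sketch of why a dedicated helper type can accept the nested braces where `std::pair` cannot, assuming the type-erased holder's converting constructor is `explicit` (an assumption about the cause, not something stated in this thread). `ToyAny`, `ToyNamedAny`, and `collect_names` are hypothetical stand-ins, not the PR's `AnyModule` or `NamedAnyModule`.

```cpp
#include <cassert>
#include <initializer_list>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical concrete module type.
struct M {
  int value;
  explicit M(int v) : value(v) {}
};

// Hypothetical type-erased holder. The converting constructor is explicit,
// so an M is not implicitly convertible to a ToyAny.
class ToyAny {
 public:
  template <typename T,
            typename = typename std::enable_if<
                !std::is_same<typename std::decay<T>::type, ToyAny>::value>::type>
  explicit ToyAny(T&& module)
      : ptr_(std::make_shared<typename std::decay<T>::type>(
            std::forward<T>(module))) {}

 private:
  std::shared_ptr<void> ptr_;
};

// Brace-initializing std::pair<std::string, ToyAny> from {"m1", M(1)}
// would need this implicit conversion, which does not exist.
static_assert(!std::is_convertible<M, ToyAny>::value,
              "explicit constructor blocks implicit conversion");

// Hypothetical helper: a non-explicit (name, module) constructor that
// performs the explicit ToyAny construction itself.
class ToyNamedAny {
 public:
  template <typename T>
  ToyNamedAny(std::string name, T&& module)
      : name_(std::move(name)), module_(std::forward<T>(module)) {}

  const std::string& name() const { return name_; }

 private:
  std::string name_;
  ToyAny module_;
};

// Accepts the nested braced list and returns the names in order.
inline std::vector<std::string> collect_names(
    std::initializer_list<ToyNamedAny> named_modules) {
  std::vector<std::string> names;
  for (const auto& named_module : named_modules) {
    names.push_back(named_module.name());
  }
  return names;
}
```

A call such as `collect_names({{"m1", M(1)}, {std::string("m2"), M(2)}, {"m3", M(3)}})` compiles because each inner brace pair selects the templated `ToyNamedAny` constructor directly.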

inline torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(
std::initializer_list<NamedAnyModule> named_modules) {
torch::OrderedDict<std::string, AnyModule> dict;
for (auto named_module : named_modules) {

Note that you're making copies here?

Contributor Author

I think we need to make named_module non-const one way or another, because std::initializer_list<T> only provides access to an array of objects of type const T, not T (according to https://en.cppreference.com/w/cpp/utility/initializer_list), but we need named_module to be of non-const type to be able to do std::move(named_module.module()).

In the latest commit, I changed std::move(named_module.module()) to std::move(const_cast<NamedAnyModule&>(named_module).module()) so that we can avoid doing copies in this line.
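The copying behavior discussed here can be sketched with a hypothetical `Tracked` type that counts copy constructions. The sketch compares iterating a `std::initializer_list` by value, iterating by const reference, and calling `std::move` on a const element. It does not reproduce the `const_cast` workaround.

```cpp
#include <cassert>
#include <initializer_list>
#include <utility>

// Hypothetical stand-in for a list element; counts copy constructions only.
struct Tracked {
  static int copies;
  int value;
  explicit Tracked(int v) : value(v) {}
  Tracked(const Tracked& other) : value(other.value) { ++copies; }
  Tracked(Tracked&& other) noexcept : value(other.value) {}
};
int Tracked::copies = 0;

// `for (auto item : items)` copy-constructs every element.
inline int copies_iterating_by_value(std::initializer_list<Tracked> items) {
  Tracked::copies = 0;
  int sum = 0;
  for (auto item : items) {
    sum += item.value;
  }
  (void)sum;
  return Tracked::copies;
}

// `for (const auto& item : items)` binds references and copies nothing.
inline int copies_iterating_by_const_ref(std::initializer_list<Tracked> items) {
  Tracked::copies = 0;
  int sum = 0;
  for (const auto& item : items) {
    sum += item.value;
  }
  (void)sum;
  return Tracked::copies;
}

// std::move on a const element yields `const Tracked&&`, which cannot bind
// to the move constructor, so the copy constructor is selected.
inline int copies_moving_from_const(std::initializer_list<Tracked> items) {
  Tracked::copies = 0;
  for (const auto& item : items) {
    Tracked taken(std::move(item));
    (void)taken;
  }
  return Tracked::copies;
}
```

The counters are reset inside each function, after the list has already been built, so they measure only what the loop itself does.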

/// or `push_back("name", module)`, since they should be handled by their respective
/// `push_back` functions.
template <typename First, typename Second, typename... Rest,
typename = torch::disable_if_t<std::is_same<First, std::string>::value ||

Why do we need the const char* guard? I think we're only using std::string in here?

@yf225 yf225 Mar 27, 2019

If we don't add the guard for const char* here, a call such as sequential->push_back("shared_m1", M(1)) will be template-matched to this push_back(First&& first, Second&& second, Rest&&... rest) method, which is not what we want (we want it to match to push_back(std::string name, M&& module) instead).
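The overload resolution issue can be sketched with two toy classes. `Unguarded`, `Guarded`, and the `is_name` trait are hypothetical; the guard here uses `std::decay` and `std::enable_if` rather than the PR's `torch::disable_if_t`, so it illustrates the idea rather than the merged code.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>

// Hypothetical module type.
struct M {
  int value;
  explicit M(int v) : value(v) {}
};

// True when T decays to std::string or const char* (string literals
// decay from `const char (&)[N]` to `const char*`).
template <typename T>
using is_name = std::integral_constant<
    bool,
    std::is_same<typename std::decay<T>::type, std::string>::value ||
        std::is_same<typename std::decay<T>::type, const char*>::value>;

// Without a guard, a string literal binds exactly to `First&&`, which
// beats the user-defined conversion to std::string.
struct Unguarded {
  template <typename Mod>
  std::string push_back(std::string, Mod&&) { return "named"; }

  template <typename First, typename Second, typename... Rest>
  std::string push_back(First&&, Second&&, Rest&&...) { return "variadic"; }
};

// With the guard, the variadic overload is removed from consideration
// whenever the first argument looks like a name.
struct Guarded {
  template <typename Mod>
  std::string push_back(std::string, Mod&&) { return "named"; }

  template <typename First, typename Second, typename... Rest,
            typename = typename std::enable_if<!is_name<First>::value>::type>
  std::string push_back(First&&, Second&&, Rest&&...) { return "variadic"; }
};
```

In this sketch a call with a string literal as the first argument reaches the named overload only in `Guarded`, while a call with two modules still reaches the variadic overload.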

}

/// Adds a type-erased `AnyModule` to the `Sequential`.
void push_back(AnyModule any_module) {

Is this method still used now?

Contributor Author

It is still used by this line:

clone->push_back(module.clone(device));

@yf225 yf225 force-pushed the cpp_sequential_named_submodules_split branch 3 times, most recently from 4c6a5cc to 0839730 Compare March 27, 2019 20:02
@yf225 yf225 force-pushed the cpp_sequential_named_submodules_split branch from 0839730 to 7875d06 Compare March 27, 2019 20:05
@goldsborough goldsborough left a comment

👍

@facebook-github-bot facebook-github-bot left a comment

@yf225 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@yf225 merged this pull request in 6ebfbdf.
