-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[C++ API] Add named submodule support to nn::Sequential #17552
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[C++ API] Add named submodule support to nn::Sequential #17552
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
|
||
| /// Adds a new named module to the `Sequential` container, with name of `std::string` type. | ||
| template <typename M> | ||
| void push_back(std::string name, M module) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand why the public string parameter versions of push_back don't match the public non-string versions. I would expect you would just have string overloads for the 3 existing ones?
|
|
||
| /// Adds a new named module to the `Sequential` container, with name of `const char*` type. | ||
| template <typename M> | ||
| void push_back(const char* name, M module) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do you have both string and const char* versions of push_back?
test/cpp/api/sequential.cpp
Outdated
|
|
||
| Sequential sequential( | ||
| std::make_shared<M>(1), | ||
| "m2", std::make_shared<M>(2), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this behavior is kind of crazy and doesn't match python.
| template <typename M> | ||
| void push_back(std::string name, M module) { | ||
| auto index = add_to_modules(module); | ||
| register_module(name, modules_[index].ptr()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can std::move(name) here.
|
|
||
| /// Adds a new named module to the `Sequential` container, with name of `const char*` type. | ||
| template <typename M> | ||
| void push_back(const char* name, M module) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe this does an extra copy of the module unnecessarily, and doesn't really follow the convention set above. I'm not sure how expensive the copy constructor is for Module.
It may be better to follow the construct given above, having a separate function for each of the possible types, i.e.
- std::shared_ptr
- const ModuleHolder&
- M&& with M being a module
Alternatively if you think the construct doesn't matter and the extra copy is whatever, then we can simplify the above code, but I think it's important to be consistent.
| template <typename M, typename = torch::detail::enable_if_module_t<M>> | ||
| size_t add_to_modules(M&& module) { | ||
| // Need to get rid of any reference components for make_unique. | ||
| using Type = typename std::remove_reference<M>::type; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also don't quite get the point of add_to_modules (are you always properly moving/forwarding parameters around)? Why don't the "final" non-string versions of push_back just call the string versions?
|
Note to self: make sure we are not making unnecessary copies (by adding tests for it) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think @gchanan captures my opinion here in that the changes to Sequential are more complicated than they need to be. Essentially you should just need one overload for every existing push_back that has a string as the first parameter. The existing methods can just call into the new methods, using modules_.size() (the next index) as the name. I would do it like this:
- Starting with the previous code, turn every push_back into a named version. All the logic should be in the named functions.
- Thread the name through to the final insertion into the map
- For every named version, add an unnamed overload (the previous signatures), but all they do is call
push_back(std::to_string(modules_.size()), <module>).
Then you effectively just have to add three methods and that's it
|
|
||
| /// Adds a new named module to the `Sequential` container, with name of `const char*` type. | ||
| template <typename M> | ||
| void push_back(const char* name, M module) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need const char* in the interface, since we eventually store std::string. It's ok to just have std::string as the argument type and move that properly into the eventual storage
|
|
||
| /// Matches `Sequential("m1", Module(1), ...)` case | ||
| template <typename Module, typename... Rest> | ||
| void push_back(const char* name, Module&& module, Rest&&... rest) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can remove this overload and the whole const char* business, it's not worth it
|
@goldsborough a few opinion questions for you:
|
|
Another issue I found is that if I remove the copy constructor (by adding (https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/modules/sequential.h#L202) actually copies the module and expects the copy constructor to exist. Ideally this copy should be avoided, and I am investigating how we can achieve so. |
b11b92a to
0038784
Compare
0038784 to
64c4092
Compare
|
I figured out how to have a simple API for making OrderedDict work, and updated the PR to reflect the new approach. With the new |
test/cpp/api/sequential.cpp
Outdated
| M(const M& other) : torch::nn::Module(other) { | ||
| // NOTE: The current implementation expects the module to be copied once | ||
| // when it's passed into `std::make_shared<T>()`. | ||
| // TODO: Find a way to avoid copying, and then delete the copy constructor. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I filed an issue to document this problem: #17879.
1bb5884 to
da55865
Compare
…quential_named_submodules_split
3c9843e to
2924d2d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left some comments
| // `modules_ordered_dict({{"m1", M(1)}, {"m2", M(2)}})`, | ||
| // if we use the second signature, at the template argument deduction step | ||
| // the compiler is not able to deduce the type of `ModuleType` to the type of | ||
| // `M(1)` or `M(2)`, since the compiler doesn't actually look into the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you try whether std::pair<std::string, AnyModule> works?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to change the modules_ordered_dict(...) function to torch::OrderedDict<std::string, AnyModule> modules_ordered_dict(std::initializer_list<std::pair<std::string, AnyModule>> named_modules), however it doesn't seem to work and throws:
../test/cpp/api/sequential.cpp: In member function ‘virtual void SequentialTest_ConstructsFromConcreteType_Test::TestBody()’:
../test/cpp/api/sequential.cpp:72:4: error: could not convert ‘{{"m1", SequentialTest_ConstructsFromConcreteType_Test::TestBody()::M(1)}, {std::__cxx11::basic_string<char>(((const char*)"m2"), std::allocator<char>()), SequentialTest_ConstructsFromConcreteType_Test::TestBody()::M(2)}, {"m3", SequentialTest_ConstructsFromConcreteType_Test::TestBody()::M(3)}}’ from ‘<brace-enclosed initializer list>’ to ‘std::initializer_list<std::pair<std::__cxx11::basic_string<char>, torch::nn::AnyModule> >’The compiler is not able to match std::initializer_list<std::pair<std::string, AnyModule>> to the nested braced-init list {{"m1", M(1)}, {std::string("m2"), M(2)}, {"m3", M(3)}}. So I think the NamedAnyModule approach here is necessary.
| inline torch::OrderedDict<std::string, AnyModule> modules_ordered_dict( | ||
| std::initializer_list<NamedAnyModule> named_modules) { | ||
| torch::OrderedDict<std::string, AnyModule> dict; | ||
| for (auto named_module : named_modules) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that you're making copies here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need to make named_module non-const one way or another, because std::initializer_list<T> only provides access to an array of objects of type const T, not T (according to https://en.cppreference.com/w/cpp/utility/initializer_list), but we need named_module to be of non-const type to be able to do std::move(named_module.module()).
In the latest commit, I changed std::move(named_module.module()) to std::move(const_cast<NamedAnyModule&>(named_module).module())) so that we can avoid doing copies in this line.
| /// or `push_back("name", module)`, since they should be handled by their respective | ||
| /// `push_back` functions. | ||
| template <typename First, typename Second, typename... Rest, | ||
| typename = torch::disable_if_t<std::is_same<First, std::string>::value || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need the const char* guard? I think we're only using std::string in here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we don't add the guard for const char* here, a call such as sequential->push_back("shared_m1", M(1)) will be template-matched to this push_back(First&& first, Second&& second, Rest&&... rest) method, which is not what we want (we want it to match to push_back(std::string name, M&& module) instead).
| } | ||
|
|
||
| /// Adds a type-erased `AnyModule` to the `Sequential`. | ||
| void push_back(AnyModule any_module) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this method still used now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is still used by this line:
| clone->push_back(module.clone(device)); |
…quential_named_submodules_split
4c6a5cc to
0839730
Compare
0839730 to
7875d06
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Previously, we were not able to assign names to
nn::Sequential's submodules. This PR adds this feature to match the Python API. Example use:It also enables loading parameters of Python
nn.Sequentialmodule with custom submodules names into C++ frontend, unblocking pytorch/vision#728 (comment).