Add train() / eval() / is_training() to C++ ScriptModule API by yf225 · Pull Request #16044 · pytorch/pytorch · GitHub

Conversation

@yf225
Contributor

@yf225 yf225 commented Jan 15, 2019

This PR aims to fix https://discuss.pytorch.org/t/how-to-change-a-loaded-model-to-evaluation-mode-in-c/32330 by adding train() / eval() / is_training() to the C++ ScriptModule API.
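For context, intended usage on a loaded module would look roughly like this (a sketch assuming the API this PR adds; the model path is only an example taken from the test below):

#include <torch/script.h>
#include <iostream>

int main() {
  // Load a serialized ScriptModule (PyTorch 1.0-era API: load() returns a shared_ptr).
  std::shared_ptr<torch::jit::script::Module> module =
      torch::jit::load("dropout_model.pt");

  module->eval();                                   // switch to inference behavior (e.g. dropout off)
  std::cout << module->is_training() << std::endl;  // prints 0

  module->train();                                  // back to training mode
  std::cout << module->is_training() << std::endl;  // prints 1
  return 0;
}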

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jan 15, 2019
@yf225 yf225 changed the title from "[WIP] Add train() / eval() / is_training() to C++ ScriptModule API" to "Add train() / eval() / is_training() to C++ ScriptModule API" on Jan 16, 2019
@yf225 yf225 requested review from gchanan, goldsborough and zdevito and removed request for goldsborough January 16, 2019 20:06
void testEvalModeForLoadedModule() {
  std::string module_path = "dropout_model.pt";
  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(module_path);
  // Test eval mode
Contributor

I think the comments here don't add too much

Contributor Author

removed

}
void train(bool on = true) {
  for (auto& submod : get_modules()) {
    submod.value().module->train(on);
Contributor

you can also write submod->module->train(on)

for (auto& submod : get_modules()) {
  submod.value().module->train(on);
}
auto t = autograd::make_variable(at::full({}, on ? 1 : 0, at::kLong));
Contributor

You can replace

auto t = autograd::make_variable(at::full({}, on ? 1 : 0, at::kLong));

with

auto t = torch::tensor(on ? 1 : 0, at::kLong);

To clarify: torch:: factory functions create variables, while at:: functions create tensors. torch::tensor / at::tensor is like torch.tensor in Python (it creates a tensor with the values you give it).
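A minimal standalone illustration of that point (a sketch, not code from this PR): the torch::tensor call already yields a Variable, so the explicit make_variable wrapping becomes unnecessary, and the stored flag can be read back with item<int64_t>().

#include <torch/torch.h>
#include <iostream>

int main() {
  bool on = true;
  // Like torch.tensor in Python: builds an int64 tensor holding the given value,
  // already wrapped as a Variable by the torch:: factory.
  auto flag = torch::tensor(on ? 1 : 0, at::kLong);
  std::cout << flag.item<int64_t>() << std::endl;  // prints 1
  return 0;
}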

Contributor

And then just embed it in the register_parameter call:

register_parameter("training", torch::tensor(on ? 1 : 0, at::kLong), /*is_buffer=*/true);
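Put together, the suggestions above would make train() look roughly like the following (a sketch assembled from the review snippets, not necessarily the exact merged code):

void train(bool on = true) {
  // Propagate the mode to all submodules first.
  for (auto& submod : get_modules()) {
    submod->module->train(on);
  }
  // Record the mode as an integer buffer named "training".
  register_parameter("training", torch::tensor(on ? 1 : 0, at::kLong), /*is_buffer=*/true);
}
void eval() {
  train(/*on=*/false);
}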

}
bool is_training() {
  if (auto p = find_parameter("training")) {
    return (*p->slot()).item().toLong() == 1;
Contributor

Here you are converting the tensor first to a Scalar, and then to a long. You can just use item<T> to get the value of the tensor without the intermediate Scalar:

(*p->slot()).item<int64_t>()
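Side by side, the two accessors read the same scalar value (a small sketch for illustration):

at::Tensor t = torch::tensor(1, at::kLong);
int64_t via_scalar = t.item().toLong();   // Tensor -> Scalar -> long
int64_t direct     = t.item<int64_t>();   // templated accessor, no intermediate Scalar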

@goldsborough goldsborough left a comment

Looks good! Some friendly nits which you can address if you feel like it, or not

Contributor

nit: p->slot()->

Contributor

nit: remove else block, just return true + comment (since you return in the if block)
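With those nits folded in, is_training() might read as follows (a sketch based on the snippet above; the fallback comment is an assumption about modules that lack the buffer):

bool is_training() {
  if (auto p = find_parameter("training")) {
    return p->slot()->item<int64_t>() == 1;
  }
  // No "training" buffer registered: assume training mode, matching the
  // Python nn.Module default.
  return true;
}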

@facebook-github-bot facebook-github-bot left a comment

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@yf225 yf225 force-pushed the scriptmodule_eval branch from 82be773 to 4c36017 on January 30, 2019 17:27
@yf225 yf225 force-pushed the scriptmodule_eval branch from f5a442d to c03b1e5 on January 31, 2019 05:34
@facebook-github-bot facebook-github-bot left a comment

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

