Add train() / eval() / is_training() to C++ ScriptModule API #16044
test/cpp/jit/tests.h
void testEvalModeForLoadedModule() {
  std::string module_path = "dropout_model.pt";
  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(module_path);
  // Test eval mode
I think the comments here don't add too much
removed
torch/csrc/jit/script/module.h
}
void train(bool on = true) {
  for (auto& submod : get_modules()) {
    submod.value().module->train(on);
you can also write submod->module->train(on)
torch/csrc/jit/script/module.h
  for (auto& submod : get_modules()) {
    submod.value().module->train(on);
  }
  auto t = autograd::make_variable(at::full({}, on ? 1 : 0, at::kLong));
You can replace
auto t = autograd::make_variable(at::full({}, on ? 1 : 0, at::kLong));
with
auto t = torch::tensor(on ? 1 : 0, at::kLong);
To clarify: torch:: factory functions create variables, while at:: functions create tensors. torch::tensor/at::tensor is like torch.tensor in Python (it creates a tensor with the values you give it).
And then just embed it in the register_parameter call:
register_parameter("training", torch::tensor(on ? 1 : 0, at::kLong), /*is_buffer=*/true);
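For readers without the diff at hand, the shape of train() discussed in this thread can be sketched with a small stand-in type. This is only an illustration: MockModule, its buffers map, and its submodules vector are hypothetical and are not the torch::jit::script::Module API. The eval() helper forwarding to train(false) is also an assumption, mirroring how eval() behaves in the Python API.

```cpp
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for a script module; NOT the torch::jit API.
struct MockModule {
  // Named buffers, each holding one integer (stand-in for a 0-dim long tensor).
  std::map<std::string, std::int64_t> buffers;
  // Direct children of this module.
  std::vector<std::shared_ptr<MockModule>> submodules;

  // Propagate the mode to every submodule, then record it on this module.
  void train(bool on = true) {
    for (auto& submod : submodules) {
      submod->train(on);
    }
    // Stored as 1 or 0, like the "training" buffer registered in the PR.
    buffers["training"] = on ? 1 : 0;
  }

  // Assumption: eval() is shorthand for train(false).
  void eval() { train(false); }
};
```

Calling eval() on a root MockModule sets the "training" entry to 0 on the root and on every nested child, which is the behavior the recursive loop over submodules is meant to provide.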
torch/csrc/jit/script/module.h
}
bool is_training() {
  if (auto p = find_parameter("training")) {
    return (*p->slot()).item().toLong() == 1;
Here you are converting the tensor first to a Scalar, and then to a long. You can just use item<T> to get the value of the tensor without the intermediate Scalar:
(*p->slot()).item<int64_t>()
Looks good! Some friendly nits which you can address if you feel like it, or not
torch/csrc/jit/script/module.h
nit: p->slot()->
torch/csrc/jit/script/module.h
nit: remove else block, just return true + comment (since you return in the if block)
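The early-return shape suggested in this nit might look roughly like the sketch below. It uses a plain map as a hypothetical stand-in for the module's buffers (BufferMap and the free function is_training are illustrative names, not the PyTorch API); the fallback of returning true when no flag is found follows the nit above.

```cpp
#include <cstdint>
#include <map>
#include <string>

// Hypothetical stand-in: a module's buffers as name -> integer value.
using BufferMap = std::map<std::string, std::int64_t>;

// Reports the training mode recorded in the buffers.
bool is_training(const BufferMap& buffers) {
  auto it = buffers.find("training");
  if (it != buffers.end()) {
    return it->second == 1;
  }
  // No "training" flag recorded: treat the module as being in training mode.
  return true;
}
```

Because the if block already returns, the trailing return needs no else branch, and the comment documents why true is the default.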
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This PR aims to fix https://discuss.pytorch.org/t/how-to-change-a-loaded-model-to-evaluation-mode-in-c/32330 by adding train() / eval() / is_training() to the C++ ScriptModule API.