-
Notifications
You must be signed in to change notification settings - Fork 25.7k
add reduce arg to PoissonNLLLoss #3770
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for tackling this, @kevinzakka. This looks great! I had a few minor comments
torch/nn/functional.py
Outdated
if size_average: | ||
return torch.mean(loss) | ||
if not reduce: | ||
return loss |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/functional.py
Outdated
if size_average: | ||
return torch.mean(loss) | ||
else: | ||
return torch.sum(loss) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
test/test_nn.py
Outdated
module_name='PoissonNLLLoss', | ||
input_size=(2, 3, 4, 5), | ||
target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(), | ||
desc='non_full_loss', |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
test/test_nn.py
Outdated
module_name='PoissonNLLLoss', | ||
constructor_args=(True, False, True, 1e-8, False), | ||
input_size=(2, 3, 4, 5), | ||
target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(), |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch - the reference function isn't necessary because PoissonNLLLoss is written in Python. One more minor comment below.
The test seems to be failing on the CI. You can run it directly with
python test/test_nn.py TestNN.test_poissonnllloss_no_reduce
and see what's up.
test/test_nn.py
Outdated
constructor=wrap_functional( | ||
lambda i: F.poisson_nll_loss(i, t.type_as(i), reduce=False)), | ||
input_fn=lambda: torch.rand(10, 10), | ||
reference_fn=lambda i, _: |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/functional.py
Outdated
mask = target > 1 | ||
loss[mask] += (target * torch.log(target) - target + 0.5 * torch.log(2 * math.pi * target))[mask] | ||
if not reduce: | ||
return torch.mean(loss, 1) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
test/common_nn.py
Outdated
), | ||
] | ||
|
||
def poissonnllloss_reference(input, target, log_input=True, full=False, |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me! Thanks @kevinzakka
Two very minor comments and it'll be good to go :)
test/test_nn.py
Outdated
dict( | ||
module_name='PoissonNLLLoss', | ||
constructor_args=(False, True, True), | ||
constructor_args=(False, True, True, 1e-8, True), |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/functional.py
Outdated
log_input=False. Default: 1e-8 | ||
reduce (bool, optional): By default, the losses are averaged | ||
over observations for each minibatch, or summed, depending on | ||
size_average. When reduce is False, returns a loss per batch |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
Done, thanks tons @zou3519 |
@pytorchbot test this please |
Thanks @kevinzakka ! |
* add reduce arg to PoissonNLLLoss * fixed comments except reference function * fixed unit test * small indentation fix * fixing last comments by richard * lint check * another linting issue
As per #264. When reduce is False, PoissonNLLLoss outputs a loss per element of the input tensor. When reduce is True (default), the current behavior is kept.
This did not require changing any C or CUDA files as PoissonNLLLoss is implemented purely in python.