Migrates nll_loss_forward from TH to Aten (CUDA) by thomasjpfan · Pull Request #60097 · pytorch/pytorch · GitHub

Conversation

@thomasjpfan
Contributor

@thomasjpfan thomasjpfan commented Jun 16, 2021

Fixes #24610
Aten Umbrella issue #24507
Related to #59765

Performance does not change between this PR and master, as measured with the following benchmark script:

Benchmark script
import torch
import torch.nn as nn
import time

torch.manual_seed(0)


def _time():
    torch.cuda.synchronize()
    MS_PER_SECOND = 1000
    return time.perf_counter() * MS_PER_SECOND


device = "cuda"
C = 30
softmax = nn.LogSoftmax(dim=1)
n_runs = 250

for reduction in ["none", "mean", "sum"]:
    for N in [100_000, 500_000, 1_000_000]:
        fwd_t = 0
        bwd_t = 0
        data = torch.randn(N, C, device=device)
        target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)
        loss = nn.NLLLoss(reduction=reduction)
        input = softmax(data)

        for i in range(n_runs):
            t1 = _time()
            result = loss(input, target)
            t2 = _time()
            fwd_t = fwd_t + (t2 - t1)
        fwd_avg = fwd_t / n_runs
        print(
            f"input size({N}, {C}), reduction: {reduction} "
            f"forward time is {fwd_avg:.2f} (ms)"
        )
    print()

master

input size(100000, 30), reduction: none forward time is 0.02 (ms)
input size(500000, 30), reduction: none forward time is 0.08 (ms)
input size(1000000, 30), reduction: none forward time is 0.15 (ms)

input size(100000, 30), reduction: mean forward time is 1.81 (ms)
input size(500000, 30), reduction: mean forward time is 8.24 (ms)
input size(1000000, 30), reduction: mean forward time is 16.46 (ms)

input size(100000, 30), reduction: sum forward time is 1.66 (ms)
input size(500000, 30), reduction: sum forward time is 8.24 (ms)
input size(1000000, 30), reduction: sum forward time is 16.46 (ms)

this PR

input size(100000, 30), reduction: none forward time is 0.02 (ms)
input size(500000, 30), reduction: none forward time is 0.08 (ms)
input size(1000000, 30), reduction: none forward time is 0.15 (ms)

input size(100000, 30), reduction: mean forward time is 1.80 (ms)
input size(500000, 30), reduction: mean forward time is 8.24 (ms)
input size(1000000, 30), reduction: mean forward time is 16.46 (ms)

input size(100000, 30), reduction: sum forward time is 1.66 (ms)
input size(500000, 30), reduction: sum forward time is 8.24 (ms)
input size(1000000, 30), reduction: sum forward time is 16.46 (ms)

@facebook-github-bot
Contributor

facebook-github-bot commented Jun 16, 2021

💊 CI failures summary and remediations

As of commit e7877ec (more details on the Dr. CI page and at hud.pytorch.org/pr/60097):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-scanned failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Jun 23 05:28:19 AssertionError: False is not tr... was 1.0 (1.0 vs. 0.0), which occurred at index 0.
Jun 23 05:28:19 ----------------------------------------------------------------------
Jun 23 05:28:19 Traceback (most recent call last):
Jun 23 05:28:19   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 397, in instantiated_test
Jun 23 05:28:19     result = test_fn(self, *args)
Jun 23 05:28:19   File "/var/lib/jenkins/workspace/xla/test/../../test/test_view_ops.py", line 458, in test_transpose_inplace_view
Jun 23 05:28:19     self.assertEqual(t[1, 0], v[0, 1])
Jun 23 05:28:19   File "/var/lib/jenkins/workspace/xla/test/pytorch_test_base.py", line 605, in assertEqual
Jun 23 05:28:19     return DeviceTypeTestBase.assertEqual(self, x, y, *args, **kwargs)
Jun 23 05:28:19   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1407, in assertEqual
Jun 23 05:28:19     super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg))
Jun 23 05:28:19 AssertionError: False is not true : Tensors failed to compare as equal!With rtol=0.001 and atol=0.001, found 1 element(s) (out of 1) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 1.0 (1.0 vs. 0.0), which occurred at index 0.
Jun 23 05:28:19 
Jun 23 05:28:19 ----------------------------------------------------------------------
Jun 23 05:28:19 Ran 138 tests in 3.230s
Jun 23 05:28:19 
Jun 23 05:28:19 FAILED (failures=2, skipped=102)
Jun 23 05:28:19 
Jun 23 05:28:19 Generating XML reports...
Jun 23 05:28:19 Generated XML report: test-reports/python-unittest/test.......test.test_view_ops/TEST-TestViewOpsXLA-20210623052816.xml
Jun 23 05:28:19 + cleanup
Jun 23 05:28:19 + retcode=1

This comment was automatically generated by Dr. CI.

@thomasjpfan thomasjpfan added module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jun 16, 2021
@thomasjpfan thomasjpfan changed the title Migrates nllloss_forward to aten Migrates nllloss_forward to aten (CUDA) Jun 16, 2021
@thomasjpfan thomasjpfan marked this pull request as draft June 16, 2021 21:46
@thomasjpfan thomasjpfan changed the title Migrates nllloss_forward to aten (CUDA) [WIP] Migrates nllloss_forward to aten (CUDA) Jun 16, 2021
@thomasjpfan thomasjpfan changed the title [WIP] Migrates nllloss_forward to aten (CUDA) Migrates nllloss_forward to aten (CUDA) Jun 17, 2021
@thomasjpfan thomasjpfan marked this pull request as ready for review June 17, 2021 00:38
const int64_t n_dims = input.dim();
TORCH_CHECK(n_dims > 0 && n_dims <= 2, "input tensor should be 1D or 2D");
TORCH_CHECK(
    target.dim() == 1,
Contributor

What if target.dim() is 0? Previous code errored on this too

Contributor Author

I think the original implementation did not raise this specific error when target.dim() == 0:

if (THCIndexTensor_(nDimension)(state, target) > 1) {
  THError("multi-target not supported");
}

I updated the check to TORCH_CHECK(target.dim() <= 1) to not change the original behavior.
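
To make the behavior under discussion concrete, here is a minimal Python-level sketch (not part of the PR; the shapes, device handling, and use of F.nll_loss are illustrative assumptions): a 1D target of class indices is accepted, while a multi-dimensional target is expected to be rejected by the check.

import torch
import torch.nn.functional as F

# Illustrative only: exercise the target-dimension checks discussed above.
device = "cuda" if torch.cuda.is_available() else "cpu"
log_probs = torch.randn(8, 5, device=device).log_softmax(dim=1)

# Supported: 2D input with a 1D target of class indices.
target_1d = torch.randint(0, 5, (8,), device=device)
print(F.nll_loss(log_probs, target_1d))

# Expected to be rejected: a multi-dimensional target should trip the check.
target_2d = torch.randint(0, 5, (8, 1), device=device)
try:
    F.nll_loss(log_probs, target_2d)
except RuntimeError as err:
    print("multi-dimensional target rejected:", err)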

Collaborator

Can you please beef up testing for supported/unsupported dims?

@thomasjpfan thomasjpfan changed the title Migrates nllloss_forward to aten (CUDA) Migrates nll_loss_forward to aten (CUDA) Jun 18, 2021
@thomasjpfan thomasjpfan changed the title Migrates nll_loss_forward to aten (CUDA) Migrates nll_loss_forward from TH to Aten (CUDA) Jun 18, 2021
}
if (total_weight_acc != 0) {
  *output = static_cast<scalar_t>(output_acc / total_weight_acc);
}
Collaborator

you need else here, otherwise output won't be set for total_weight_acc = 0, and please add tests for all these cases.

Contributor Author

@thomasjpfan thomasjpfan Jun 22, 2021

I redid the implementation a little to only set *output once and added tests for the nan case and the size_average && total_weight_acc == 0 case.
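
For reference, a small sketch of the edge case being discussed (illustrative only; the shapes and values are made up): with an all-zero class weight tensor, total_weight accumulates to zero and a mean reduction divides by it, so the output is expected to come out as NaN rather than being left unset.

import torch
import torch.nn.functional as F

# Illustrative edge case: all-zero class weights with reduction="mean".
device = "cuda" if torch.cuda.is_available() else "cpu"
log_probs = torch.randn(4, 3, device=device).log_softmax(dim=1)
target = torch.randint(0, 3, (4,), device=device)
zero_weight = torch.zeros(3, device=device)

out = F.nll_loss(log_probs, target, weight=zero_weight, reduction="mean")
print(out)  # expected to be nan (0 / 0) for this edge case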

@ngimel
Collaborator

ngimel commented Jun 22, 2021

Test errors are real

@facebook-github-bot
Contributor

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@thomasjpfan
Contributor Author

thomasjpfan commented Jun 23, 2021

While looking at a test error I saw that we already have tests for zero weights and empty tensors: test_nll_loss_empty_tensor_reduction_mean and test_nll_loss_total_weight_is_zero, so I removed the tests I added in this PR that covered the same cases.

The rest of the public test failures seem unrelated to this PR.

@facebook-github-bot
Contributor

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Jun 23, 2021
Summary:
Addresses a part of #59765

This PR adds byte support for nll_loss on the CPU for `input.dim() == 2`.

CUDA support will be implemented when `nll_loss` migration to CUDA is completed in #60299 and #60097

Pull Request resolved: #60308

Reviewed By: VitalyFedyunin

Differential Revision: D29329458

Pulled By: jbschlosser

fbshipit-source-id: d3585c4966030bc61e451f8aa817406a8a3acf47
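
A minimal sketch of the byte-target support described in that commit (illustrative only; it assumes a build where the CPU support has landed, and the shapes and values are made up):

import torch
import torch.nn.functional as F

# Illustrative only: CPU nll_loss with a uint8 (byte) target and a 2D input.
log_probs = torch.randn(6, 4).log_softmax(dim=1)
byte_target = torch.randint(0, 4, (6,), dtype=torch.uint8)

print(F.nll_loss(log_probs, byte_target))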
@facebook-github-bot
Contributor

@ngimel merged this pull request in 99b6411.
