MAINT Migrates multilabel_margin_loss from THC to ATen (CUDA) by thomasjpfan · Pull Request #60708 · pytorch/pytorch · GitHub

Conversation

@thomasjpfan
Contributor

@thomasjpfan thomasjpfan commented Jun 25, 2021

Fixes #24603
Fixes #24602

The implementation should be exactly the same, so it was strange that the initial benchmarks showed such a significant improvement in this PR.

Update: the benchmarks are now the same.
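For context on what the kernel computes: `multilabel_margin_loss` is a per-sample hinge between every target class score and every non-target class score, normalized by the number of classes `C`; targets are class indices, with the first `-1` terminating the list. A minimal pure-Python sketch of the per-sample formula (the function name is mine, not a PyTorch API):

```python
def multilabel_margin_loss_1d(x, y):
    """Per-sample multilabel margin loss.

    x: list of C class scores; y: list of C target class indices,
    terminated by the first -1 (as in torch's target encoding).
    """
    C = len(x)
    targets = []
    for t in y:
        if t < 0:
            break
        targets.append(t)
    target_set = set(targets)
    loss = 0.0
    for j in targets:          # every target class...
        for i in range(C):     # ...against every non-target class
            if i in target_set:
                continue
            loss += max(0.0, 1.0 - (x[j] - x[i]))
    return loss / C            # normalized by the class count

# Matches the documented example (up to float rounding):
# multilabel_margin_loss_1d([0.1, 0.2, 0.4, 0.8], [3, 0, -1, 1]) -> 0.85
```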

Benchmark script
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

torch.manual_seed(0)
MS_PER_SECOND = 1000


def _time():
    torch.cuda.synchronize()
    return time.perf_counter() * MS_PER_SECOND


device = "cuda"
C = 30
n_runs = 100
reductions = ["none", "sum", "mean"]
Ns = [1_000, 10_000, 100_000]

for reduction, N in product(reductions, Ns):
    total_fwd_time = 0
    total_back_time = 0
    grad_out = torch.randn(N, device=device)
    if reduction != "none":
        grad_out = grad_out[0]

    for _ in range(n_runs):
        input = torch.randn(N, C, device=device, requires_grad=True)
        target = torch.randint(0, C, size=input.size(), device=device)

        # forward
        start = _time()
        result = F.multilabel_margin_loss(input, target, reduction=reduction)
        total_fwd_time += _time() - start


    result = F.multilabel_margin_loss(input, target, reduction=reduction)
    for _ in range(n_runs):
        # backward
        start = _time()
        result.backward(grad_out, retain_graph=True)
        total_back_time += _time() - start

    fwd_avg = total_fwd_time / n_runs
    bwd_avg = total_back_time / n_runs
    print(
        f"input size({N}, {C}), reduction: {reduction}, fwd: {fwd_avg:.2f} (ms), back: {bwd_avg:.2f} (ms)"
    )

master

input size(1000, 30), reduction: none, fwd: 0.14 (ms), back: 0.41 (ms)
input size(10000, 30), reduction: none, fwd: 1.26 (ms), back: 3.58 (ms)
input size(100000, 30), reduction: none, fwd: 13.15 (ms), back: 34.68 (ms)
input size(1000, 30), reduction: sum, fwd: 0.14 (ms), back: 0.38 (ms)
input size(10000, 30), reduction: sum, fwd: 1.16 (ms), back: 3.53 (ms)
input size(100000, 30), reduction: sum, fwd: 13.04 (ms), back: 34.53 (ms)
input size(1000, 30), reduction: mean, fwd: 0.14 (ms), back: 0.38 (ms)
input size(10000, 30), reduction: mean, fwd: 1.17 (ms), back: 3.52 (ms)
input size(100000, 30), reduction: mean, fwd: 13.12 (ms), back: 34.54 (ms)

this PR

input size(1000, 30), reduction: none, fwd: 0.14 (ms), back: 0.35 (ms)
input size(10000, 30), reduction: none, fwd: 1.22 (ms), back: 2.98 (ms)
input size(100000, 30), reduction: none, fwd: 12.90 (ms), back: 29.32 (ms)
input size(1000, 30), reduction: sum, fwd: 0.14 (ms), back: 0.32 (ms)
input size(10000, 30), reduction: sum, fwd: 1.16 (ms), back: 2.97 (ms)
input size(100000, 30), reduction: sum, fwd: 13.00 (ms), back: 29.17 (ms)
input size(1000, 30), reduction: mean, fwd: 0.14 (ms), back: 0.32 (ms)
input size(10000, 30), reduction: mean, fwd: 1.17 (ms), back: 2.97 (ms)
input size(100000, 30), reduction: mean, fwd: 13.09 (ms), back: 28.91 (ms)

@thomasjpfan thomasjpfan added module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jun 25, 2021
@thomasjpfan thomasjpfan requested a review from ngimel June 25, 2021 01:19
@thomasjpfan thomasjpfan requested a review from ezyang as a code owner June 25, 2021 01:19
@facebook-github-bot
Contributor

facebook-github-bot commented Jun 25, 2021

💊 CI failures summary and remediations

As of commit a340ef2 (more details on the Dr. CI page and at hud.pytorch.org/pr/60708):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


Preview docs built from this PR

This comment was automatically generated by Dr. CI.

@ezyang ezyang removed their request for review June 25, 2021 14:18
@thomasjpfan thomasjpfan changed the title MAINT Migrates multilabel_margin_loss from THC to ATen MAINT Migrates multilabel_margin_loss from THC to ATen (CUDA) Jun 30, 2021
Collaborator

@ngimel ngimel left a comment


This looks good; please resolve conflicts and don't use legacy reduction functions. Also, such large performance gains are indeed suspicious; can you run correctness tests on some bigger sizes? (Tests are probably run only on very small inputs.)

namespace native {

namespace {
const int MULTILABELMARGIN_THREADS = 32;
Collaborator


You've changed the number of threads from 1024 to 32; maybe that's the reason for the perf improvement? (Usually 32 is too small; you need at least 64-128.)

Contributor Author


Changing the number of threads back to 1024 made the performance of this PR align with master.

Collaborator


But we want better perf :-) so if 32 produces correct results we should be using it (or, probably better, 64 or 128).
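A back-of-envelope way to see why the block size matters so much here: if, as the kernel snippets suggest, each block handles one sample and its threads stride across the `C` classes, then threads beyond `C` never touch data. A hypothetical utilization sketch (the helper name is mine):

```python
def block_utilization(num_classes, threads_per_block):
    # When threads stride over the class dimension, at most
    # `num_classes` threads per block ever do work.
    active = min(num_classes, threads_per_block)
    return active / threads_per_block

# For the C=30 benchmark above:
# block_utilization(30, 32)   -> 0.9375  (30/32 threads busy)
# block_utilization(30, 1024) -> ~0.029  (30/1024 threads busy)
```

Smaller blocks also let more blocks reside per SM at once, which would be consistent with the large timing gap between thread counts measured later in this thread.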


// reduce
using Op = ReduceAdd<accscalar_t>;
accscalar_t total_sum = reduceBlock<accscalar_t>(
Collaborator


instead of legacy reduceBlock it's better to use BlockReduceSum from block_reduce.cuh (it also doesn't require shared memory)
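For reference, both the legacy `reduceBlock` and `BlockReduceSum` compute a parallel tree reduction: each step halves the number of active lanes, with lane `i` adding in the value from lane `i + offset`. A small Python model of the pattern (the name is mine, not the CUDA helper API):

```python
def tree_reduce_sum(lane_vals):
    """Model of a tree reduction over a power-of-two lane count,
    e.g. the 32 lanes of a warp: log2(n) steps instead of n - 1
    sequential adds."""
    vals = list(lane_vals)
    offset = len(vals) // 2
    while offset > 0:
        for i in range(offset):  # on a GPU, all lanes step in parallel
            vals[i] += vals[i + offset]
        offset //= 2
    return vals[0]

# tree_reduce_sum(range(32)) -> 496, same as sum(range(32))
```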

(target_.size(0) == nframe) && (target_.size(1) == dim),
"inconsistent target size");
TORCH_CHECK(
(is_target_.dim() == 2) && (is_target.size(0) == nframe) &&
Collaborator

@ngimel ngimel Jul 9, 2021


nit: this check is slightly cleaner to write as target_.sizes() == is_target_.sizes()

namespace native {

namespace {
const int MULTILABELMARGIN_THREADS = 1024;
Collaborator


Great, can you check if reducing the number of threads here brings good perf and correct results, and we can land?

@thomasjpfan
Contributor Author

Here is the same benchmark but with bigger tensors. Using MULTILABELMARGIN_THREADS=128:

input size(1000, 100), reduction: none, fwd: 0.08 (ms), back: 0.20 (ms)
input size(10000, 100), reduction: none, fwd: 0.46 (ms), back: 1.03 (ms)
input size(100000, 100), reduction: none, fwd: 4.30 (ms), back: 9.19 (ms)              
input size(1000000, 100), reduction: none, fwd: 40.16 (ms), back: 92.52 (ms)
input size(1000, 100), reduction: sum, fwd: 0.08 (ms), back: 0.18 (ms)
input size(10000, 100), reduction: sum, fwd: 0.44 (ms), back: 0.97 (ms)
input size(100000, 100), reduction: sum, fwd: 4.10 (ms), back: 9.30 (ms)
input size(1000000, 100), reduction: sum, fwd: 40.53 (ms), back: 93.03 (ms)
input size(1000, 100), reduction: mean, fwd: 0.08 (ms), back: 0.18 (ms)
input size(10000, 100), reduction: mean, fwd: 0.45 (ms), back: 0.98 (ms)
input size(100000, 100), reduction: mean, fwd: 4.13 (ms), back: 9.39 (ms)
input size(1000000, 100), reduction: mean, fwd: 40.89 (ms), back: 93.66 (ms)

and with MULTILABELMARGIN_THREADS=1024:

input size(1000, 100), reduction: none, fwd: 0.43 (ms), back: 1.11 (ms)
input size(10000, 100), reduction: none, fwd: 4.36 (ms), back: 9.90 (ms)
input size(100000, 100), reduction: none, fwd: 40.40 (ms), back: 96.97 (ms)
input size(1000000, 100), reduction: none, fwd: 405.59 (ms), back: 979.68 (ms)
input size(1000, 100), reduction: sum, fwd: 0.40 (ms), back: 1.02 (ms)
input size(10000, 100), reduction: sum, fwd: 4.09 (ms), back: 9.93 (ms)
input size(100000, 100), reduction: sum, fwd: 40.90 (ms), back: 98.14 (ms)
input size(1000000, 100), reduction: sum, fwd: 408.41 (ms), back: 980.30 (ms)
input size(1000, 100), reduction: mean, fwd: 0.40 (ms), back: 1.02 (ms)
input size(10000, 100), reduction: mean, fwd: 4.10 (ms), back: 9.90 (ms)
input size(100000, 100), reduction: mean, fwd: 40.97 (ms), back: 98.10 (ms)
input size(1000000, 100), reduction: mean, fwd: 409.24 (ms), back: 980.26 (ms)

So MULTILABELMARGIN_THREADS=128 is faster overall. For completeness, these are the results on master:

input size(1000, 100), reduction: none, fwd: 0.43 (ms), back: 1.31 (ms)
input size(10000, 100), reduction: none, fwd: 4.37 (ms), back: 11.79 (ms)
input size(100000, 100), reduction: none, fwd: 40.69 (ms), back: 115.07 (ms)
input size(1000000, 100), reduction: none, fwd: 409.26 (ms), back: 1148.50 (ms)
input size(1000, 100), reduction: sum, fwd: 0.40 (ms), back: 1.22 (ms)
input size(10000, 100), reduction: sum, fwd: 4.08 (ms), back: 11.64 (ms)
input size(100000, 100), reduction: sum, fwd: 40.82 (ms), back: 115.29 (ms)
input size(1000000, 100), reduction: sum, fwd: 407.63 (ms), back: 1151.65 (ms)
input size(1000, 100), reduction: mean, fwd: 0.40 (ms), back: 1.21 (ms)
input size(10000, 100), reduction: mean, fwd: 4.09 (ms), back: 11.64 (ms)
input size(100000, 100), reduction: mean, fwd: 40.90 (ms), back: 115.32 (ms)
input size(1000000, 100), reduction: mean, fwd: 408.44 (ms), back: 1151.58 (ms)

As for correctness, I wrote a script that compares CPU and CUDA results with MULTILABELMARGIN_THREADS=128; it passes for this PR.

Correctness script
from itertools import product

import torch
import torch.nn.functional as F

torch.manual_seed(0)

C = 100
n_runs = 3
reductions = ["none", "sum", "mean"]
Ns = [10, 100, 1_000, 10_000]

for reduction, N in product(reductions, Ns):

    print(f"Checking {reduction}, ({N}, {C})")
    for _ in range(n_runs):
        grad_out_cpu = torch.randn(N, device="cpu")
        if reduction != "none":
            grad_out_cpu = grad_out_cpu[0]
        grad_out_gpu = grad_out_cpu.to("cuda")

        input_cpu = torch.randn(N, C, requires_grad=True)
        target_cpu = torch.randint(0, C, size=input_cpu.size())
        result_cpu = F.multilabel_margin_loss(input_cpu, target_cpu, reduction=reduction)
        result_cpu.backward(grad_out_cpu)

        # make input_gpu a leaf so its .grad is populated separately
        # from input_cpu.grad
        input_gpu = input_cpu.detach().to("cuda").requires_grad_(True)
        target_gpu = target_cpu.to("cuda")
        result_gpu = F.multilabel_margin_loss(input_gpu, target_gpu, reduction=reduction)
        result_gpu.backward(grad_out_gpu)

        assert torch.allclose(result_cpu, result_gpu.cpu())
        assert torch.allclose(input_cpu.grad, input_gpu.grad.cpu())

@ngimel
Collaborator

ngimel commented Jul 22, 2021

Thank you, results look great!

@facebook-github-bot
Contributor

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel merged this pull request in 9730d91.


Labels

cla signed Merged module: nn Related to torch.nn open source triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module


Development

Successfully merging this pull request may close these issues.

Migrate multilabel_margin_loss_forward from the TH to Aten (CUDA)
Migrate multilabel_margin_loss_backward from the TH to Aten (CUDA)

4 participants