MAINT Migrates rrelu_with_noise from THC to ATen on Cuda by thomasjpfan · Pull Request #57864 · pytorch/pytorch · GitHub

Conversation

@thomasjpfan
Contributor

Fixes #24618
Related to #24507

Benchmark script:
import torch
import torch.nn as nn
import time

torch.manual_seed(0)
def _time():
    torch.cuda.synchronize()
    return time.time()

device = "cuda"
m = nn.RReLU().cuda()

for n in [100, 10_000, 100_000]:
    fwd_t = 0
    bwd_t = 0  # placeholder; the backward pass is not timed in this run
    input = torch.randn(128, n, device=device)
    grad_output = torch.ones(128, n, device=device)  # unused here; kept for a backward benchmark
    for i in range(10000):
        t1 = _time()
        output = m(input)
        t2 = _time()
        fwd_t = fwd_t + (t2 - t1)
    fwd_avg = fwd_t / 10000 * 1000
    print(f"input size(128, {n}) forward time is {fwd_avg:.2f} (ms)")

Results from benchmark:

This PR

input size(128, 100) forward time is 0.01 (ms)
input size(128, 10000) forward time is 0.06 (ms)
input size(128, 100000) forward time is 0.54 (ms)

On master

input size(128, 100) forward time is 0.01 (ms)
input size(128, 10000) forward time is 0.08 (ms)
input size(128, 100000) forward time is 0.66 (ms)

@facebook-github-bot
Contributor

facebook-github-bot commented May 7, 2021

💊 CI failures summary and remediations

As of commit 88bdcd7 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@ngimel ngimel self-requested a review May 7, 2021 22:47
@ngimel ngimel added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 7, 2021
@ezyang ezyang removed their request for review May 10, 2021 15:24
@ezyang
Contributor

ezyang commented May 10, 2021

removed myself

@thomasjpfan thomasjpfan added the module: nn Related to torch.nn label May 11, 2021
inline scalar_t __device__ curand_uniform_type(curandStatePhilox4_32_10_t *state);

template <>
inline THHalf __device__ curand_uniform_type<THHalf>(curandStatePhilox4_32_10_t *state) {

Don't use legacy THHalf type, use at::Half instead. There are implicit conversions between at::Half and float, so ScalarConvert is not necessary
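A minimal sketch of what that change could look like (assuming the curand_uniform_type template quoted above; not the exact PR code):

template <>
inline at::Half __device__ curand_uniform_type<at::Half>(curandStatePhilox4_32_10_t *state) {
  auto rand = curand_uniform4(state);
  return rand.x;  // implicit float -> at::Half conversion, no ScalarConvert needed
}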

template <>
inline THHalf __device__ curand_uniform_type<THHalf>(curandStatePhilox4_32_10_t *state) {
auto rand = curand_uniform4(state);
return ScalarConvert<float, THHalf>::to(rand.x);

using only .x out of 4 generated numbers is wasteful, you can have an unroll loop in the kernel that would use all the values, you can take a look e.g. at the non-vectorized fused_dropout_kernel in Dropout.cu
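A rough sketch of that unrolling pattern, loosely modeled on the non-vectorized fused_dropout_kernel (variable names such as rounded_size, numel, and the grid geometry are assumptions, not the exact PR code):

constexpr int unroll_factor = 4;                 // curand_uniform4 yields four floats
int full_stride = blockDim.x * gridDim.x;
for (int linear_index = idx; linear_index < rounded_size;
     linear_index += full_stride * unroll_factor) {
  float4 rand = curand_uniform4(&state);
  #pragma unroll
  for (int ii = 0; ii < unroll_factor; ii++) {
    int li = linear_index + ii * full_stride;
    if (li >= numel) {
      continue;
    }
    auto r = static_cast<scalar_t>((&rand.x)[ii] * (upper - lower) + lower);
    if (input[li] <= 0) {
      output[li] = input[li] * r;
      noise[li] = r;
    } else {
      output[li] = input[li];
      noise[li] = scalar_t(1);
    }
  }
}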

if (input[i] <= 0)
{
scalar_t r = curand_uniform_type<scalar_t>(&state);
r = ScalarConvert<double, scalar_t>::to(r * (b - a) + a);

having double is usually perf penalty, it should be scalar_t or at most accscalar_t
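i.e., something along these lines (a sketch; a and b keep their meaning from the quoted hunk):

scalar_t r = curand_uniform_type<scalar_t>(&state);
r = r * static_cast<scalar_t>(b - a) + static_cast<scalar_t>(a);  // stays in scalar_t, no double round-trip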

else
{
output[i] = input[i];
noise[i] = ScalarConvert<int, scalar_t>::to(1);

No ScalarConvert please
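i.e., a plain assignment is enough here (sketch):

noise[i] = scalar_t(1);  // implicit conversion, no ScalarConvert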


CUDA_KERNEL_LOOP(i, n)
{
if (input[i] <= 0)

to avoid warp divergence, you should generate randoms for every input, and then only diverge on fast operations like computing output and noise
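A minimal sketch of that divergence-free structure (variable names assumed, not the exact PR code): the random number is drawn unconditionally, and only the cheap selects depend on the sign of the input.

scalar_t r = curand_uniform_type<scalar_t>(&state);
r = r * static_cast<scalar_t>(upper - lower) + static_cast<scalar_t>(lower);
bool negative = input[i] <= 0;
output[i] = negative ? input[i] * r : input[i];
noise[i] = negative ? r : scalar_t(1);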

@thomasjpfan thomasjpfan force-pushed the rrelu_inplace_aten_migrate branch from c04f62b to bdd6e87 Compare June 11, 2021 17:01
@thomasjpfan
Contributor Author

Thank you for the review @ngimel! I updated the PR to use unrolling. I ran the following benchmark:

Benchmark script:
import torch
import torch.nn as nn
import time

torch.manual_seed(0)
def _time():
    torch.cuda.synchronize()
    return time.time()

device = "cuda"
m = nn.RReLU().cuda()
n_runs = 1_000

for n in [10_000, 100_000, 1_000_000]:
    fwd_t = 0
    bwd_t = 0  # placeholder; the backward pass is not timed in this run
    input = torch.randn(128, n, device=device)
    grad_output = torch.ones(128, n, device=device)  # unused here; kept for a backward benchmark
    for i in range(n_runs):
        t1 = _time()
        output = m(input)
        t2 = _time()
        fwd_t = fwd_t + (t2 - t1)
    fwd_avg = fwd_t / n_runs * 1000
    print(f"input size(128, {n}) forward time is {fwd_avg:.2f} (ms)")

Results from benchmark:

This PR

input size(128, 10000) forward time is 0.06 (ms)
input size(128, 100000) forward time is 0.43 (ms)
input size(128, 1000000) forward time is 4.17 (ms)

On master

input size(128, 10000) forward time is 0.09 (ms)
input size(128, 100000) forward time is 0.69 (ms)
input size(128, 1000000) forward time is 6.66 (ms)


@ngimel ngimel left a comment


Thanks, this looks good, I left minor comments.

double range = upper - lower;

for (int linear_index = idx; linear_index < rounded_size; linear_index += grid_stride) {
auto rand = random_func(&state);

Can you please add static assert here that sizeof(rand)/sizeof(rand.x) == unroll_factor? Otherwise your (&rand.x)[ii] access is unsafe.

checkAllSameGPU("rrelu_with_noise_out_cuda", {self_arg, noise_arg, output_arg});

auto input = self.contiguous();
auto noise_ = noise.contiguous();

rrelu_with_noise_out_cuda is a user facing function, which means that output can also be discontiguous here.
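One possible way to handle that (a sketch, not necessarily what the PR ends up doing): run the kernel on a contiguous buffer and copy back into the user-provided output if it is strided.

auto output_ = output.contiguous();
// ... launch the kernel writing into output_ ...
if (!output.is_contiguous()) {
  output.copy_(output_);
}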

output, input, noise_, lower, upper, generator);
});
} else {
auto lower_tensor = scalar_to_tensor(lower);

you don't need to convert the Scalar to a tensor here; instead, convert it to a regular type (using .to<double>()) and convert negative_slope back to a Scalar
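Roughly like this (a sketch, assuming this is the non-training branch that falls through to leaky_relu; variable names are illustrative):

double lower_ = lower.to<double>();
double upper_ = upper.to<double>();
Scalar negative_slope = (lower_ + upper_) / 2.0;
// ... dispatch to the leaky_relu path with negative_slope ...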

auto rand = random_func(&state);

// ensure that (&rand.x)[ii] is safe
CUDA_KERNEL_ASSERT(sizeof(rand)/sizeof(rand.x) == unroll_factor);

it should be static_assert (to be done at compile time), not runtime assert.
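i.e. (sketch), move the check to compile time:

static_assert(sizeof(rand) / sizeof(rand.x) == unroll_factor,
              "random_func must return exactly unroll_factor values");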

@ngimel
Collaborator

ngimel commented Jun 16, 2021

Can you please try rebasing, to get CI signal?

@facebook-github-bot
Contributor

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel merged this pull request in a0ad4c2.


Labels

cla signed · Merged · module: nn (Related to torch.nn) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

Migrate rrelu_with_noise and rrelu_with_noise_ from the TH to Aten (CUDA)

5 participants