Improve numerical stability of LayerNorm by xiaomengy · Pull Request #59987 · pytorch/pytorch · GitHub
55 changes: 29 additions & 26 deletions aten/src/ATen/native/cpu/layer_norm_kernel.cpp
@@ -1,13 +1,14 @@
#include <ATen/native/layer_norm.h>

#include <cmath>
#include <tuple>

#include <ATen/ATen.h>
#include <ATen/CPUApplyUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/moments_utils.h>

namespace at {
namespace native {
@@ -29,30 +30,21 @@ void LayerNormKernelImplInternal(
DCHECK_EQ(X.numel(), M * N);
DCHECK(!gamma.defined() || gamma.numel() == N);
DCHECK(!beta.defined() || beta.numel() == N);
T* X_data = X.data_ptr<T>();
const T* X_data = X.data_ptr<T>();
const T* gamma_data = gamma.defined() ? gamma.data_ptr<T>() : nullptr;
const T* beta_data = beta.defined() ? beta.data_ptr<T>() : nullptr;
T* Y_data = Y->data_ptr<T>();
T* mean_data = mean->data_ptr<T>();
T* rstd_data = rstd->data_ptr<T>();
const T c = T(1) / static_cast<T>(N);
const bool gamma_null = gamma_data == nullptr;
const bool beta_null = beta_data == nullptr;
at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; ++i) {
T* X_ptr = X_data + i * N;
const T* X_ptr = X_data + i * N;
T* Y_ptr = Y_data + i * N;
T mean_val = vec::reduce_all<T>(
[](Vec& x, Vec& y) { return x + y; },
X_ptr,
N);
T rstd_val = vec::map_reduce_all<T>(
[](Vec x) { return x * x; },
[](Vec x, Vec y) { return x + y; },
X_ptr,
N);
mean_val *= c;
rstd_val = std::max(rstd_val * c - mean_val * mean_val, T(0));
T mean_val;
T rstd_val;
std::tie(mean_val, rstd_val) = utils::RowwiseMoments(X_ptr, N);
rstd_val = T(1) / std::sqrt(rstd_val + eps);
const T scale = rstd_val;
const T bias = -rstd_val * mean_val;
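The key change in the forward kernel: the variance is no longer computed as E[x²] − (E[x])², a form that suffers catastrophic cancellation once the mean is large relative to the standard deviation and accumulates error over long rows. The row statistics now come from utils::RowwiseMoments. A minimal standalone sketch of the difference, using a scalar Welford update rather than the vectorized implementation in moments_utils.h:

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>

// Naive two-accumulator moments: sum and sum of squares.
// var = E[x^2] - E[x]^2 cancels badly once mean^2 >> var.
template <typename T>
std::pair<T, T> NaiveMoments(const T* x, int64_t n) {
  T sum = T(0), sum_sq = T(0);
  for (int64_t i = 0; i < n; ++i) {
    sum += x[i];
    sum_sq += x[i] * x[i];
  }
  const T mean = sum / static_cast<T>(n);
  const T var = std::max(sum_sq / static_cast<T>(n) - mean * mean, T(0));
  return {mean, var};
}

// Welford's online update: carries (mean, m2) and never subtracts
// two large, nearly equal quantities.
template <typename T>
std::pair<T, T> WelfordMoments(const T* x, int64_t n) {
  T mean = T(0), m2 = T(0);
  for (int64_t i = 0; i < n; ++i) {
    const T delta = x[i] - mean;
    mean += delta / static_cast<T>(i + 1);
    m2 += delta * (x[i] - mean);
  }
  return {mean, m2 / static_cast<T>(n)};  // population variance (ddof = 0)
}
```

For float32 rows as large as the new test_LayerNorm_numeric case (normalized_shape = [256, 256, 144], roughly 9.4M elements per row), the old sum/sum-of-squares path can lose enough precision to exceed the test's 1e-5 tolerance; the moment-based path keeps the error bounded.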
@@ -117,7 +109,8 @@ void LayerNormBackwardKernelImplInternal(
const T* X_data = X.template data_ptr<T>();
const T* mean_data = mean.template data_ptr<T>();
const T* rstd_data = rstd.template data_ptr<T>();
const T* gamma_data = gamma.defined() ? gamma.template data_ptr<T>() : nullptr;
const T* gamma_data =
gamma.defined() ? gamma.template data_ptr<T>() : nullptr;
T* dX_data = dX->defined() ? dX->template data_ptr<T>() : nullptr;
T* dgamma_data = dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;
@@ -133,7 +126,8 @@ void LayerNormBackwardKernelImplInternal(
// Parallel along dim0 and reduce dY and X along dim0 to buffer.
// Second path: parallel along dim1 and reduce buffer to dgamma and dbeta.
//
// 2. Fuse first path of dgamma/dbeta with dX to reuse X[i] and dY[i] in L1 cache.
// 2. Fuse first path of dgamma/dbeta with dX to reuse X[i] and dY[i] in L1
// cache.
//
int num_threads = at::get_num_threads();
Tensor buffer = at::empty({0}, X.options());
@@ -147,10 +141,15 @@ void LayerNormBackwardKernelImplInternal(
// First path of dgamma/dbeta and dX
at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
int tid = at::get_thread_num();
TORCH_CHECK(tid < num_threads,
"expect thread id smaller than ", num_threads, ", got thread id ", tid);
TORCH_CHECK(
tid < num_threads,
"expect thread id smaller than ",
num_threads,
", got thread id ",
tid);
T* dgamma_buffer_ptr = dgamma_null ? nullptr : buffer_data + tid * N;
T* dbeta_buffer_ptr = dbeta_null ? nullptr : buffer_data + num_threads * N + tid * N;
T* dbeta_buffer_ptr =
dbeta_null ? nullptr : buffer_data + num_threads * N + tid * N;
for (int64_t i = start; i < end; ++i) {
const T* dY_ptr = dY_data + i * N;
const T* X_ptr = X_data + i * N;
@@ -162,7 +161,9 @@ void LayerNormBackwardKernelImplInternal(
// dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
// }
vec::map3<T>(
[a, b](Vec dgamma, Vec dy, Vec x) { return dgamma + dy * (Vec(a) * x + Vec(b)); },
[a, b](Vec dgamma, Vec dy, Vec x) {
return dgamma + dy * (Vec(a) * x + Vec(b));
},
dgamma_buffer_ptr,
dgamma_buffer_ptr,
dY_ptr,
@@ -199,9 +200,7 @@ void LayerNormBackwardKernelImplInternal(
X_ptr,
N);
db = vec::reduce_all<T>(
[](Vec& x, Vec& y) { return x + y; },
dY_ptr,
N);
[](Vec& x, Vec& y) { return x + y; }, dY_ptr, N);
} else {
ds = vec::map3_reduce_all<T>(
[](Vec x, Vec y, Vec z) { return x * y * z; },
@@ -227,14 +226,18 @@ void LayerNormBackwardKernelImplInternal(
// }
if (gamma_null) {
vec::map2<T>(
[a, b, c](Vec dy, Vec x) { return Vec(a) * dy + Vec(b) * x + Vec(c); },
[a, b, c](Vec dy, Vec x) {
return Vec(a) * dy + Vec(b) * x + Vec(c);
},
dX_ptr,
dY_ptr,
X_ptr,
N);
} else {
vec::map3<T>(
[a, b, c](Vec dy, Vec gamma, Vec x) { return Vec(a) * dy * gamma + Vec(b) * x + Vec(c); },
[a, b, c](Vec dy, Vec gamma, Vec x) {
return Vec(a) * dy * gamma + Vec(b) * x + Vec(c);
},
dX_ptr,
dY_ptr,
gamma_data,
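For context on the constants a, b and c used in the fused dX update above: they follow from differentiating Y = (X − mean) · rstd · gamma + beta per row. Below is a scalar reference of what the vec::map2/map3 lambdas compute (my own sketch, not code from the PR, using the kernel's variable names and scale = 1/N):

```cpp
#include <cstdint>

// Scalar reference for one row of the backward pass (gamma defined).
// ds and db are the two reductions computed by the vec::*_reduce_all calls.
template <typename T>
void LayerNormRowBackward(
    const T* dY, const T* X, const T* gamma,
    T mean, T rstd, int64_t N, T* dX) {
  const T scale = T(1) / static_cast<T>(N);
  T ds = T(0), db = T(0);
  for (int64_t j = 0; j < N; ++j) {
    ds += dY[j] * gamma[j] * X[j];
    db += dY[j] * gamma[j];
  }
  // Same coefficients as the kernel: dX = a * dY * gamma + b * X + c.
  const T a = rstd;
  const T b = (db * mean - ds) * rstd * rstd * rstd * scale;
  const T c = -b * mean - db * rstd * scale;
  for (int64_t j = 0; j < N; ++j) {
    dX[j] = a * dY[j] * gamma[j] + b * X[j] + c;
  }
}
```

In the gamma_null branch the same formula applies with gamma ≡ 1, which is what the map2 lambda in the diff computes.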
16 changes: 8 additions & 8 deletions aten/src/ATen/native/cpu/moments_utils.h
@@ -55,7 +55,7 @@ void AddMomentsVec(
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
// https://en.wikipedia.org/wiki/Pairwise_summation
template <typename T, int64_t kMaxDepth>
std::pair<T, T> RowwiseMomentsImpl(const T* X, int64_t N) {
std::pair<T, T> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
using Vec = vec::Vectorized<T>;

constexpr int64_t kVecSize = Vec::size();
@@ -119,26 +119,26 @@ std::pair<T, T> RowwiseMomentsImpl(const T* X, int64_t N) {
AddMoments(n, m1_arr[i], m2_arr[i], m0, m1, m2);
}

return std::make_pair(m1, m2 / static_cast<T>(N));
return std::make_pair(m1, m2 / static_cast<T>(N - ddof));
}

template <typename T>
std::pair<T, T> RowwiseMoments(const T* X, int64_t N) {
std::pair<T, T> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
using Vec = vec::Vectorized<T>;
constexpr int64_t kVecSize = Vec::size();
const int64_t n = N / kVecSize;
const int64_t m = divup(n, kChunkSize);
const int64_t depth = CeilLog2(m);
if (depth <= 4) {
return RowwiseMomentsImpl<T, 4>(X, N);
return RowwiseMomentsImpl<T, 4>(X, N, ddof);
} else if (depth <= 8) {
return RowwiseMomentsImpl<T, 8>(X, N);
return RowwiseMomentsImpl<T, 8>(X, N, ddof);
} else if (depth <= 16) {
return RowwiseMomentsImpl<T, 16>(X, N);
return RowwiseMomentsImpl<T, 16>(X, N, ddof);
} else if (depth <= 32) {
return RowwiseMomentsImpl<T, 32>(X, N);
return RowwiseMomentsImpl<T, 32>(X, N, ddof);
} else {
return RowwiseMomentsImpl<T, 64>(X, N);
return RowwiseMomentsImpl<T, 64>(X, N, ddof);
}
}

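RowwiseMomentsImpl keeps per-chunk (count, mean, m2) triples and merges them pairwise up a small tree (depth picked from CeilLog2(m)); the new ddof argument only changes the final divisor to N − ddof, so ddof = 1 yields the Bessel-corrected sample variance while the default keeps the old behaviour. A sketch of the merge rule that AddMoments-style helpers implement (Chan et al.'s parallel variance update, written from the formula rather than copied from the header):

```cpp
#include <cstdint>

// Merge two partial moment states (n, mean, m2), where m2 is the sum of
// squared deviations from that partition's own mean.
template <typename T>
void CombineMoments(
    int64_t n_a, T mean_a, T m2_a,
    int64_t n_b, T mean_b, T m2_b,
    int64_t& n, T& mean, T& m2) {
  n = n_a + n_b;
  if (n == 0) {
    mean = T(0);
    m2 = T(0);
    return;
  }
  const T delta = mean_b - mean_a;
  const T nb_over_n = static_cast<T>(n_b) / static_cast<T>(n);
  mean = mean_a + delta * nb_over_n;
  m2 = m2_a + m2_b + delta * delta * static_cast<T>(n_a) * nb_over_n;
}
```

The merged m2 is then divided by N − ddof, so existing callers (which leave ddof at 0) are unaffected.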
62 changes: 44 additions & 18 deletions aten/src/ATen/native/cuda/layer_norm_kernel.cu
@@ -1,5 +1,9 @@
#include <ATen/native/layer_norm.h>

#include <type_traits>

#include <thrust/tuple.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
@@ -27,24 +31,35 @@ __global__ void RowwiseMomentsCUDAKernel(
T* mean,
T* rstd) {
using T_ACC = acc_type<T, true>;
__shared__ T_ACC m_shared[C10_WARP_SIZE];
__shared__ T_ACC v_shared[C10_WARP_SIZE];
using WelfordType = WelfordData<T_ACC, int64_t, T_ACC>;
using WelfordOp =
WelfordOps<T_ACC, T_ACC, int64_t, T_ACC, thrust::pair<T_ACC, T_ACC>>;

__shared__
typename std::aligned_storage<sizeof(WelfordType), alignof(WelfordType)>::
type val_shared[C10_WARP_SIZE];
WelfordType* val_shared_ptr = reinterpret_cast<WelfordType*>(val_shared);

const int64_t i = blockIdx.x;
T_ACC sum1 = 0;
T_ACC sum2 = 0;
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
WelfordType val(0, 0, 0, 0);

for (int64_t j = threadIdx.x; j < N; j += blockDim.x) {
const int64_t index = i * N + j;
sum1 += static_cast<T_ACC>(X[index]);
sum2 += static_cast<T_ACC>(X[index]) * static_cast<T_ACC>(X[index]);
val = welford_op.reduce(val, static_cast<T_ACC>(X[index]), index);
}
sum1 = cuda_utils::BlockReduceSum<T_ACC>(sum1, m_shared);
sum2 = cuda_utils::BlockReduceSum<T_ACC>(sum2, v_shared);
val = cuda_utils::BlockReduce(
val,
welford_op,
/*identity_element=*/WelfordType(0, 0, 0, 0),
val_shared_ptr);

if (threadIdx.x == 0) {
const T_ACC scale = T_ACC(1) / static_cast<T_ACC>(N);
sum1 *= scale;
sum2 = c10::cuda::compat::max(sum2 * scale - sum1 * sum1, T_ACC(0));
mean[i] = sum1;
rstd[i] = c10::cuda::compat::rsqrt(sum2 + static_cast<T_ACC>(eps));
T_ACC m1;
T_ACC m2;
thrust::tie(m2, m1) = welford_op.project(val);
mean[i] = m1;
rstd[i] = c10::cuda::compat::rsqrt(m2 + static_cast<T_ACC>(eps));
}
}
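The CUDA kernel previously block-reduced two scalars (Σx and Σx²) and formed the variance from them; it now block-reduces a single Welford state per thread with welford_op as the combiner, which is why the shared buffer becomes C10_WARP_SIZE WelfordType slots instead of two T_ACC arrays. A simplified host-side C++ sketch of that reduction shape (hypothetical names, not the kernel itself; the real code delegates to cuda_utils::BlockReduce):

```cpp
#include <cstdint>
#include <vector>

struct Welford {
  double mean = 0, m2 = 0;
  int64_t n = 0;
};

// reduce(): fold one element into a thread-local state.
Welford Reduce(Welford w, double x) {
  ++w.n;
  const double delta = x - w.mean;
  w.mean += delta / static_cast<double>(w.n);
  w.m2 += delta * (x - w.mean);
  return w;
}

// combine(): merge two partial states (same rule as the CPU path).
Welford Combine(const Welford& a, const Welford& b) {
  Welford out;
  out.n = a.n + b.n;
  if (out.n == 0) return out;
  const double delta = b.mean - a.mean;
  const double nb_over_n = static_cast<double>(b.n) / static_cast<double>(out.n);
  out.mean = a.mean + delta * nb_over_n;
  out.m2 = a.m2 + b.m2 + delta * delta * static_cast<double>(a.n) * nb_over_n;
  return out;
}

// One "block": each thread folds a strided slice of the row, then the
// partial states are merged; thread 0 would apply project() to get
// (mean, variance) and write mean and rstd.
Welford BlockMoments(const std::vector<double>& row, int num_threads) {
  std::vector<Welford> partial(num_threads);
  for (int t = 0; t < num_threads; ++t) {
    for (size_t j = t; j < row.size(); j += num_threads) {
      partial[t] = Reduce(partial[t], row[j]);
    }
  }
  Welford total;
  for (const Welford& p : partial) {
    total = Combine(total, p);
  }
  return total;
}
```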

@@ -294,8 +309,12 @@ void LayerNormKernelImpl(
Tensor* Y,
Tensor* mean,
Tensor* rstd) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
X.scalar_type(), "LayerNormKernelImpl", [&]() {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
X.scalar_type(),
"LayerNormKernelImpl",
[&]() {
LayerNormKernelImplInternal<scalar_t>(
X, gamma, beta, M, N, static_cast<scalar_t>(eps), Y, mean, rstd);
});
@@ -328,7 +347,10 @@ void LayerNormBackwardKernelImplInternal(
T* dX_data = dX->defined() ? dX->template data_ptr<T>() : nullptr;
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
if (dX_data != nullptr) {
const auto kAccType = (X.scalar_type() == kHalf || X.scalar_type() == kBFloat16) ? kFloat : X.scalar_type();
const auto kAccType =
(X.scalar_type() == kHalf || X.scalar_type() == kBFloat16)
? kFloat
: X.scalar_type();
Tensor ds = at::empty({M}, X.options().dtype(kAccType));
Tensor db = at::empty({M}, X.options().dtype(kAccType));
Tensor scale = at::empty({M}, X.options().dtype(kAccType));
@@ -413,8 +435,12 @@ void LayerNormBackwardKernelImpl(
Tensor* dX,
Tensor* dgamma,
Tensor* dbeta) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
X.scalar_type(),
"LayerNormBackwardKernelImpl",
[&]() {
LayerNormBackwardKernelImplInternal<scalar_t>(
dY, X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
});
30 changes: 28 additions & 2 deletions test/test_nn.py
@@ -12555,6 +12555,32 @@ def test_LayerNorm_general(self, device):
if self.device_type == 'cuda':
self._test_LayerNorm_cuda_half(device)

@onlyOnCPUAndCUDA
def test_LayerNorm_numeric(self, device):
def layer_norm_ref(X, gamma, beta, normalized_shape, eps):
feature_size = np.prod(normalized_shape)
X_view = X.view(-1, feature_size)
mean = X_view.mean(dim=-1, keepdim=True)
var = X_view.var(dim=-1, unbiased=False, keepdim=True)
Y = (X_view - mean) / torch.sqrt(var + eps)
Y = Y * gamma.view(-1) + beta.view(-1)
return Y.view(*X.size())

normalized_shape = [256, 256, 144]
layer_norm = nn.LayerNorm(normalized_shape).float().to(device)
X = torch.rand(2, *normalized_shape, dtype=torch.float32,
device=device)

Y = layer_norm(X)
Y_ref = layer_norm_ref(X, layer_norm.weight.data, layer_norm.bias.data,
normalized_shape, layer_norm.eps)
self.assertEqual(Y, Y_ref, rtol=0, atol=1e-5)

if self.device_type == 'cuda':
layer_norm.cpu()
Y_cpu = layer_norm(X.cpu())
self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5)

@onlyOnCPUAndCUDA
def test_GroupNorm_general(self, device):
self._test_GroupNorm_general(device)
@@ -12588,8 +12614,8 @@ def group_norm_ref(X, gamma, beta, groups, channels, eps):
return Y.view(*X.size())

batch_size = 1
groups = 4
channels = 32
groups = 2
channels = 8
group_norm = nn.GroupNorm(groups, channels).float().to(device)
X = torch.rand(batch_size, channels, 256, 256, 72,
dtype=torch.float32, device=device)
9 changes: 9 additions & 0 deletions torch/testing/_internal/common_nn.py
@@ -1566,6 +1566,15 @@ def fractional_max_pool3d_test(test_case):
check_eval=True,
desc='3d_no_elementwise_affine',
),
dict(
module_name='LayerNorm',
constructor_args=([56, 56, 56], 1e-5, False),
cpp_constructor_args='torch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false)',
input_size=(4, 56, 56, 56),
cudnn=True,
check_eval=True,
desc='3d_no_affine_large_feature',
),
dict(
module_name='LayerNorm',
constructor_args=([5], 1e-3),