Closed
Labels: module: correctness (silent), module: cuda, triaged
Description
Hi,
In this MultiMarginLoss_backward_kernel:
template <int P, typename scalar_t>
__global__ void MultiMarginLoss_backward_kernel(
    scalar_t *gradInput, const scalar_t *gradOutput, const scalar_t *input, const int64_t *target,
    const scalar_t *weights, int nframe, int dim, bool sizeAverage, scalar_t margin,
    bool reduce) {
  using acc_t = at::acc_type<scalar_t, true>;
  __shared__ acc_t buffer[MULTIMARGIN_THREADS];
  int k = blockIdx.x;
  const scalar_t *input_k = input + k*dim;
  scalar_t *gradInput_k = gradInput + k*dim;
  int target_k = static_cast<int>(target[k]);
  scalar_t input_target_k = input_k[target_k];

  const scalar_t *gradOutput_k = gradOutput;
  if (!reduce) {
    gradOutput_k += k;
  }

  const int denom = sizeAverage && reduce ? nframe * dim : dim;
  const acc_t g = acc_t(1) / static_cast<acc_t>(denom);

  int i_start = threadIdx.x;
  int i_end = dim;
  int i_step = blockDim.x;

  buffer[threadIdx.x] = 0;
  for (int i = i_start; i < i_end; i += i_step) {
    scalar_t z = margin - input_target_k + input_k[i];
    if (i == target_k) {
      continue;
    }
    if (z > 0) {
      acc_t h = (P == 1) ? g : 2*g*z;
      if (weights) {
        h *= weights[target_k];
      }
      buffer[threadIdx.x] -= static_cast<scalar_t>(h);
      gradInput_k[i] = static_cast<scalar_t>(h);
    } else {
      gradInput_k[i] = static_cast<scalar_t>(0);
    }
  }

  __syncthreads();

  // reduce
  if (threadIdx.x == 0) {
    acc_t gradInput_target_k = 0;
    for (int i = 0; i < blockDim.x; i++) {
      gradInput_target_k += buffer[i];
    }
    gradInput_k[target_k] = static_cast<scalar_t>(gradInput_target_k);
  }

  // NOTE: no barrier here before the scaling loop below (see question)
  for (int i = i_start; i < i_end; i += i_step) {
    gradInput_k[i] *= *gradOutput_k;
  }
}
Should a __syncthreads() be added before the final for (int i = i_start; i < i_end; i += i_step) loop that scales by *gradOutput_k? Thread 0 writes gradInput_k[target_k] after the shared-memory reduction, but the thread that scales that element in the final loop is a different thread whenever target_k % blockDim.x != 0, and thread execution order within a block is not guaranteed. So gradInput_k[target_k] may be scaled before thread 0 has written it, silently producing a wrong gradient for that element.
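For concreteness, here is a minimal sketch of the placement I have in mind; this is my suggestion, not the current upstream code, and only the added __syncthreads() line is new:

  // reduce (unchanged)
  if (threadIdx.x == 0) {
    acc_t gradInput_target_k = 0;
    for (int i = 0; i < blockDim.x; i++) {
      gradInput_target_k += buffer[i];
    }
    gradInput_k[target_k] = static_cast<scalar_t>(gradInput_target_k);
  }

  __syncthreads();  // proposed: ensure thread 0's write to gradInput_k[target_k]
                    // is visible before any thread scales that element below

  for (int i = i_start; i < i_end; i += i_step) {
    gradInput_k[i] *= *gradOutput_k;
  }

With the barrier, every thread observes the reduced value before the elementwise scaling begins; without it, the scaling of gradInput_k[target_k] races with thread 0's write.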