
nn.MultiheadAttention causes gradients to become NaN under some use cases #41508

@wgale

🐛 Bug

Using key_padding_mask and attn_mask with nn.MultiheadAttention causes gradients to become NaN under some use cases.

To Reproduce

Steps to reproduce the behavior:

Backward pass through an nn.MultiheadAttention layer whose forward pass used:

  1. attn_mask limiting context in both directions (e.g. bucketed attention)
  2. key_padding_mask where there is padding for at least one sequence (and there is also at least one valid entry for every sequence, as expected)
  3. The dimensions that were masked are not used to calculate the loss
  4. The loss is a real number (not NaN)
import torch

torch.manual_seed(0)

# Create attention layer
attn = torch.nn.MultiheadAttention(embed_dim=1, num_heads=1)

# Create dummy input (seq_len=3, batch=2, embed_dim=1)
x = torch.rand(3, 2, 1)

# Padding mask: the second sequence can only see its first embedding
key_padding_mask = torch.as_tensor([[False, False, False],
                                    [False, True, True]], dtype=torch.bool)

# Attention mask: bucket attention to the current and previous time steps
attn_mask = torch.as_tensor([[0., float('-inf'), float('-inf')],
                             [0., 0., float('-inf')],
                             [float('-inf'), 0., 0.]])

# Generate attention embedding
output, scores = attn(x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
print("scores")
print(scores)

# Dummy loss that only uses the first embedding, which is defined for all sequences
loss = output[0, :].sum()
print("loss")
print(loss)

# Backward pass and gradients
loss.backward()
print("grads")
for n, p in attn.named_parameters():
    print(n, p.grad)


> scores
> tensor([[[1.0000, 0.0000, 0.0000],
>          [0.4468, 0.5532, 0.0000],
>          [0.0000, 0.5379, 0.4621]],
>
>         [[1.0000, 0.0000, 0.0000],
>          [1.0000, 0.0000, 0.0000],
>          [   nan,    nan,    nan]]], grad_fn=<DivBackward0>)
> loss
> tensor(0.0040, grad_fn=<SumBackward0>)
> grads
> in_proj_weight tensor([[nan],
>         [nan],
>         [nan]])
> in_proj_bias tensor([nan, nan, nan])
> out_proj.weight tensor([[nan]])
> out_proj.bias tensor([2.])
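
The NaN row corresponds to the third query of the second (padded) sequence: attn_mask blocks its first key and key_padding_mask blocks the other two, so every logit in that attention row is -inf and the softmax over the row is NaN. A minimal illustration of that step (plain torch.softmax, not the module's internals):

import torch

# All keys for this query are masked, so every attention logit is -inf.
logits = torch.tensor([float('-inf'), float('-inf'), float('-inf')])
print(torch.softmax(logits, dim=0))  # tensor([nan, nan, nan])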

Expected behavior

Gradients should not be NaN
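
For context on how the NaN reaches the parameter gradients even though the loss only uses output[0, :]: in the backward pass of the attention matmul, the NaN weights of the fully masked row are multiplied by that row's (zero) upstream gradient, and nan * 0 is still nan. A minimal sketch of that mechanism with stand-in tensors (not the module's internals):

import torch

w = torch.tensor([[1.0], [float('nan')]])      # stand-in attention weights; the second row is the fully masked query
v = torch.tensor([[2.0]], requires_grad=True)  # stand-in values
out = w @ v                                    # [[2.], [nan]]
loss = out[0].sum()                            # 2.0 -- the NaN entry never enters the loss
loss.backward()
print(v.grad)                                  # tensor([[nan]]): the zero upstream gradient times nan is still nan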

Environment

PyTorch version: 1.5.1
Is debug build: No
CUDA used to build PyTorch: None

OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.5.1
[conda] blas 1.0 mkl
[conda] cpuonly 1.0 0 pytorch
[conda] mkl 2020.1 217
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.1.0 py37h23d657b_0
[conda] mkl_random 1.1.1 py37h0573a6f_0
[conda] numpy 1.18.5 py37ha1c710e_0
[conda] numpy-base 1.18.5 py37hde5b4d6_0
[conda] pytorch 1.5.1 py3.7_cpu_0 [cpuonly] pytorch

Also fails when using GPU.
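
One possible mitigation until this is fixed, sketched under the assumption that a fully padded query may simply attend to itself: merge both masks into a single 3D attn_mask and unmask the diagonal of any all -inf row. merge_masks below is a hypothetical helper, not a PyTorch API:

import torch

def merge_masks(attn_mask, key_padding_mask, num_heads):
    # Combine the 2D float attn_mask with the boolean key_padding_mask into one
    # (batch * num_heads, tgt_len, src_len) float mask, and let any fully masked
    # query attend to itself so softmax never sees an all -inf row.
    bsz, src_len = key_padding_mask.shape
    tgt_len = attn_mask.shape[0]
    padding = torch.zeros(bsz, 1, src_len).masked_fill(key_padding_mask.unsqueeze(1), float('-inf'))
    combined = attn_mask.unsqueeze(0) + padding                     # (bsz, tgt_len, src_len)
    fully_masked = torch.isinf(combined).all(dim=-1, keepdim=True)  # (bsz, tgt_len, 1)
    diag = torch.eye(tgt_len, src_len, dtype=torch.bool)
    combined = combined.masked_fill(fully_masked & diag, 0.0)
    return combined.repeat_interleave(num_heads, dim=0)

# Usage with the repro above: pass only the merged mask, not key_padding_mask.
# output, scores = attn(x, x, x, attn_mask=merge_masks(attn_mask, key_padding_mask, num_heads=1))

The fully padded query then produces a meaningless but finite output, and since it is excluded from the loss the gradients stay finite.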

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @mruberry @zhangguanheng66


    Labels

    high priority, module: NaNs and Infs, module: nn, triaged
