🐛 Bug
Using key_padding_mask and attn_mask together with nn.MultiheadAttention causes gradients to become NaN in some cases.
To Reproduce
Steps to reproduce the behavior:
Run a backward pass through an nn.MultiheadAttention layer whose forward pass used:
- an attn_mask limiting context in both directions (e.g. bucketed attention)
- a key_padding_mask with padding for at least one sequence (and, as expected, at least one valid entry in every sequence)
- a loss that does not use the masked positions
- a loss that is a real number (not NaN)
import torch

torch.manual_seed(0)

# Create the attention layer
attn = torch.nn.MultiheadAttention(embed_dim=1, num_heads=1)

# Dummy input of shape (seq_len=3, batch=2, embed_dim=1)
x = torch.rand(3, 2, 1)

# Padding mask: the second sequence can only see its first embedding
key_padding_mask = torch.as_tensor([[False, False, False],
                                    [False, True, True]], dtype=torch.bool)

# Attention mask bucketing attention to the current and previous time steps
attn_mask = torch.as_tensor([[0., float('-inf'), float('-inf')],
                             [0., 0., float('-inf')],
                             [float('-inf'), 0., 0.]])

# Generate the attention embedding
output, scores = attn(x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
print("scores")
print(scores)

# Dummy loss: use only the first time step, which is valid for all sequences
loss = output[0, :].sum()
print("loss")
print(loss)

# Backward pass and gradients
loss.backward()
print("grads")
for n, p in attn.named_parameters():
    print(n, p.grad)
> scores
> tensor([[[1.0000, 0.0000, 0.0000],
> [0.4468, 0.5532, 0.0000],
> [0.0000, 0.5379, 0.4621]],
> [[1.0000, 0.0000, 0.0000],
> [1.0000, 0.0000, 0.0000],
> [ nan, nan, nan]]], grad_fn=<DivBackward0>)
> loss
> tensor(0.0040, grad_fn=<SumBackward0>)
> grads
> in_proj_weight tensor([[nan],
> [nan],
> [nan]])
> in_proj_bias tensor([nan, nan, nan])
> out_proj.weight tensor([[nan]])
> out_proj.bias tensor([2.])
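The NaN appears to come from the softmax: for the last time step of the second sequence, the combination of key_padding_mask and attn_mask leaves no attendable key, so the softmax is taken over a row that is entirely -inf (visible as the nan row in the scores above). A minimal sketch of that suspected mechanism:

import torch

# Softmax over a fully masked (all -inf) row is 0/0, which yields nan;
# the nan then propagates into every gradient during backward.
row = torch.full((3,), float('-inf'))
print(torch.softmax(row, dim=0))  # tensor([nan, nan, nan])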
Expected behavior
Gradients should not be NaN.
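One possible workaround, sketched under the assumption that fully masked query rows never contribute to the loss: fold key_padding_mask into a 3D attn_mask, then unmask any row that would otherwise be entirely -inf, so the softmax stays finite. The outputs at those rows are meaningless but no longer poison the gradients. The shapes below assume the repro script above (batch=2, seq_len=3, num_heads=1).

# Fold the padding mask into a per-batch (batch, tgt_len, src_len) float mask
float_kpm = torch.zeros(key_padding_mask.shape)
float_kpm.masked_fill_(key_padding_mask, float('-inf'))
combined = attn_mask.unsqueeze(0) + float_kpm.unsqueeze(1)

# Find query rows with no attendable key (an all -inf softmax input)
fully_masked = torch.isinf(combined).all(dim=-1)

# Unmask those rows entirely: their outputs are garbage but finite,
# and they are excluded from the loss anyway
safe_mask = combined.masked_fill(fully_masked.unsqueeze(-1), 0.0)

# A 3D attn_mask of shape (batch * num_heads, tgt_len, src_len) is accepted,
# so key_padding_mask no longer needs to be passed separately
output, scores = attn(x, x, x, attn_mask=safe_mask)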
Environment
PyTorch version: 1.5.1
Is debug build: No
CUDA used to build PyTorch: None
OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2
Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.5.1
[conda] blas 1.0 mkl
[conda] cpuonly 1.0 0 pytorch
[conda] mkl 2020.1 217
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.1.0 py37h23d657b_0
[conda] mkl_random 1.1.1 py37h0573a6f_0
[conda] numpy 1.18.5 py37ha1c710e_0
[conda] numpy-base 1.18.5 py37hde5b4d6_0
[conda] pytorch 1.5.1 py3.7_cpu_0 [cpuonly] pytorch
The same failure also occurs when using a GPU.
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @mruberry @zhangguanheng66