MultiheadAttention returns NaNs when need_weights=False for long sequences with a mask that ignores old tokens · Issue #127055 · pytorch/pytorch · GitHub

@twoertwein

Description


🐛 Describe the bug

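With `need_weights=False`, `torch.nn.MultiheadAttention` returns NaNs for the later positions of a long (600-token) sequence when the additive attention mask blocks both future tokens and tokens 10 or more steps in the past. The same call with `need_weights=True` returns finite values. It works as expected for shorter sequences and when all past tokens are allowed (i.e., with a plain causal mask).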
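To make the mask's semantics concrete, here is a small NumPy sketch of the same construction the repro uses (the helper names `sliding_window_causal_mask` and `allowed_keys` are hypothetical, introduced only for illustration). It shows that query `i` may attend to keys `i-9` through `i`, and that every row keeps at least its diagonal entry. Since no row is fully masked, the NaNs cannot be explained by a softmax over an all `-inf` row.

```python
import numpy as np


def sliding_window_causal_mask(n: int, window: int) -> np.ndarray:
    """Additive float mask: 0.0 where query i may attend to key j, -inf elsewhere.

    Same construction as the repro: triu(..., k=1) blocks future keys (j > i),
    tril(..., k=-window) blocks keys `window` or more steps old (j <= i - window).
    """
    full = np.full((n, n), -np.inf)
    return np.triu(full, k=1) + np.tril(full, k=-window)


def allowed_keys(mask: np.ndarray, i: int) -> list:
    """Key indices that query i can attend to (finite mask entries)."""
    return np.flatnonzero(np.isfinite(mask[i])).tolist()


mask = sliding_window_causal_mask(12, window=10)
print(allowed_keys(mask, 0))   # [0]
print(allowed_keys(mask, 11))  # [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

# Every row keeps its diagonal entry, so no row is fully masked.
print(all(np.isfinite(np.diag(mask))))  # True
```

The two `-inf` regions never overlap (that would need `j > i` and `j <= i - 10` at once), so the sum never hits `-inf + -inf` or `inf - inf`.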

```python
import torch

model = torch.nn.MultiheadAttention(embed_dim=2, num_heads=1)

n = 600
sequence = torch.ones(n, 2)  # unbatched input of shape (seq_len, embed_dim)

# do not attend to future tokens (j > i) or to tokens 10+ steps old (j <= i - 10)
full = torch.full((n, n), float("-inf"))
mask = torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-10)

# the printed values depend on the random initialization of `model`
print(model(sequence, sequence, sequence, attn_mask=mask, need_weights=False)[0])
#tensor([[0.0519, 0.1435],
#        [0.0519, 0.1435],
#        [0.0519, 0.1435],
#        ...,
#        [   nan,    nan],
#        [   nan,    nan],
#        [   nan,    nan]], grad_fn=<SqueezeBackward1>)
print(model(sequence, sequence, sequence, attn_mask=mask, need_weights=True)[0])
#tensor([[0.0519, 0.1435],
#        [0.0519, 0.1435],
#        [0.0519, 0.1435],
#        ...,
#        [0.0519, 0.1435],
#        [0.0519, 0.1435],
#        [0.0519, 0.1435]], grad_fn=<SqueezeBackward1>)
```
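To narrow down where the NaNs come from, the sketch below calls `torch.nn.functional.scaled_dot_product_attention` directly with the same mask and compares it against an explicit masked-softmax reference. It rests on an assumption taken from reading the PyTorch 2.x source, not from this issue: `F.multi_head_attention_forward` calls `scaled_dot_product_attention` when `need_weights=False` and computes the softmax explicitly when `need_weights=True`, which would explain why only the first call produces NaNs. The random inputs, the seed, and the helper names (`sliding_window_mask`, `reference_attention`, `compare`) are illustrative choices, not part of the original report.

```python
import torch
import torch.nn.functional as F


def sliding_window_mask(n: int, window: int) -> torch.Tensor:
    """Same additive mask as the repro: 0 where allowed, -inf elsewhere."""
    full = torch.full((n, n), float("-inf"))
    return torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-window)


def reference_attention(q, k, v, mask):
    """Explicit softmax(q @ k^T / sqrt(d) + mask) @ v, without a fused kernel."""
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores + mask, dim=-1) @ v


def compare(n: int = 600, window: int = 10, d: int = 2, seed: int = 0):
    """Run SDPA and the explicit reference on the same random inputs."""
    torch.manual_seed(seed)
    # Shape (batch=1, heads=1, seq_len=n, head_dim=d); the (n, n) mask broadcasts.
    q, k, v = (torch.randn(1, 1, n, d) for _ in range(3))
    mask = sliding_window_mask(n, window)
    ref = reference_attention(q, k, v, mask)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return ref, out


ref, out = compare()
print("reference has NaNs:", torch.isnan(ref).any().item())
print("SDPA has NaNs:     ", torch.isnan(out).any().item())
if not torch.isnan(out).any():
    # Both paths should agree up to floating-point error.
    print("max abs diff:", (out - ref).abs().max().item())
```

If the reference is finite but the SDPA output is not, that would point at the CPU SDPA kernel selected for this input rather than at the `MultiheadAttention` projections. Which kernel runs depends on the PyTorch version and build, so results may differ from 2.3.0.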

Versions

Collecting environment information...
PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2 Pro

Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] onnx==1.16.0
[pip3] tf2onnx==1.16.1
[pip3] torch==2.3.0
[pip3] torchaudio==2.3.0
[conda] numpy 1.24.3 pypi_0 pypi
[conda] torch 2.3.0 pypi_0 pypi
[conda] torchaudio 2.3.0 pypi_0 pypi

cc @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10


Labels: intel, module: cpu (CPU-specific problem, e.g., perf, algorithm)
