
MultiheadAttention set(-inf) causes NaN in loss computation #40932

@HelenaHlz

Description


🐛 Bug

I am reimplementing a transformer variant and import MultiheadAttention from torch.nn.modules.activation.
In the encoder, specifically in the self multi-head attention, if the whole input sequence is padding, the key_padding_mask parameter is all True.

"When the value is True, the corresponding value on the attention layer will be filled with -inf."

Because every attention score in such a row is set to -inf, the softmax evaluates to 0/0 and produces NaN, which propagates into the model parameters and eventually raises ValueError("nan loss encountered").
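
The mechanism is easy to demonstrate in isolation (my own sketch, not code from the issue): softmax over a row that is entirely -inf is 0/0, hence NaN.

    import torch

    row = torch.full((4,), float("-inf"))  # a fully masked attention row
    print(torch.softmax(row, dim=0))       # tensor([nan, nan, nan, nan])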

To Reproduce

Steps to reproduce the behavior:

  1. Initialize a MultiheadAttention:
    self.self_attn = MultiheadAttention(embed_dim=embed_dim, num_heads=nhead, dropout=dropout)
  2. Call it in the forward() function:
    src, attn = self.self_attn(src, src, src, attn_mask=src_mask,
    key_padding_mask=src_key_padding_mask)
  3. Pass an input whose src_key_padding_mask is all True, i.e. the original sentence in src is <pad> * max_seq_length (see the repro sketch after this list).
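
A self-contained sketch of the reproduction (my reconstruction of the steps above, with illustrative names and sizes, not the reporter's code):

    import torch
    from torch.nn import MultiheadAttention

    embed_dim, nhead, seq_len, batch = 8, 2, 4, 1
    self_attn = MultiheadAttention(embed_dim=embed_dim, num_heads=nhead)

    # One example that is entirely <pad>: every key position is masked out.
    src = torch.zeros(seq_len, batch, embed_dim)                         # (L, N, E)
    src_key_padding_mask = torch.ones(batch, seq_len, dtype=torch.bool)  # (N, S), all True

    out, attn = self_attn(src, src, src, key_padding_mask=src_key_padding_mask)
    print(attn)  # attention weights are all NaN, and NaN propagates into `out`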

I use AllenNLP, which raises "ValueError: nan loss encountered".
I found that one example in the batch is entirely <pad>, which causes this issue; a possible user-side workaround is sketched below.
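
One user-side workaround (a hypothetical sketch, not an official fix) is to guarantee that no row of the mask is completely True, e.g. by unmasking a single position of fully padded examples and ignoring their output downstream:

    fully_padded = src_key_padding_mask.all(dim=1)   # (N,) rows where everything is <pad>
    src_key_padding_mask[fully_padded, 0] = False    # let those rows attend to one dummy position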

Expected behavior

I found a description of almost the same problem in fairseq.
fairseq
line 103:
# anything in original attn_mask = 1, becomes -1e8
# anything in original attn_mask = 0, becomes 0
# Note that we cannot use -inf here, because at some edge cases,
# the attention weight (before softmax) for some padded element in query
# will become -inf, which results in NaN in model parameters

Hope you can learn from their practice.
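
A sketch of the fairseq-style workaround quoted above (illustrative; fairseq's actual code differs): build the additive mask from a large finite negative value instead of -inf, so a fully masked row softmaxes to a uniform, finite distribution rather than NaN.

    import torch

    def to_additive_mask(bool_mask: torch.Tensor) -> torch.Tensor:
        # True (padded) becomes -1e8, False becomes 0 -- finite, so no 0/0 in softmax.
        return bool_mask.float() * -1e8

    scores = torch.zeros(1, 4)                    # pre-softmax attention scores
    mask = torch.ones(1, 4, dtype=torch.bool)     # every key position padded
    weights = torch.softmax(scores + to_additive_mask(mask), dim=-1)
    print(weights)  # tensor([[0.2500, 0.2500, 0.2500, 0.2500]]) -- finite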

Environment

  • PyTorch Version (e.g., 1.0): 1.5.1
  • OS (e.g., Linux): Ubuntu 16.04.6 LTS
  • How you installed PyTorch (conda, pip, source): pip
  • Python version: 3.7
  • CUDA/cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.4
  • GPU models and configuration: TITAN Xp
  • Any other relevant information:
    numpy==1.18.5

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki


    Labels

    module: NaNs and Infs (Problems related to NaN and Inf handling in floating point)
    module: nn (Related to torch.nn)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
