
[FR] Safe softmax #55056

@jbschlosser

Description

🚀 Feature

Expand softmax to support "safe softmax" behavior (i.e. output 0 instead of NaN when called with all -inf inputs).

Motivation

When softmax is called with all -inf inputs, the output is NaN: each exp(-inf) evaluates to 0, so the normalizing sum is also 0 and the computation reduces to 0/0. In certain cases, an output of 0 may be preferred instead.
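A quick illustration of the current behavior (a minimal sketch; the printed representation may vary by version, but the all -inf input yields NaNs):

import torch
import torch.nn.functional as F

x = torch.full((5,), float('-inf'))
print(F.softmax(x, dim=-1))  # tensor([nan, nan, nan, nan, nan]) -- exp(-inf) = 0, so 0/0 = NaN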

For example, due to masking in MHA, a batch may include a sequence that is "masked out" / consists entirely of padding (masked out by setting its attention scores to -inf). This can happen as a result of utilizing attn_mask and key_padding_mask. While this case currently results in NaNs, leading to training divergence, it is often preferable to effectively ignore these sequences so training can continue. See #41508 and #40932 for more info.
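For concreteness, a minimal sketch of how a fully padded batch element triggers this (the shapes and sizes here are arbitrary, illustrative choices):

import torch
from torch import nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
x = torch.randn(4, 2, 8)  # (seq_len, batch, embed_dim)
# The second batch element is entirely padding, so every attention row for it is all -inf after masking.
key_padding_mask = torch.tensor([[False, False, False, False],
                                 [True,  True,  True,  True]])  # (batch, seq_len), True = ignore
out, _ = mha(x, x, x, key_padding_mask=key_padding_mask)
print(out[:, 1].isnan().any())  # tensor(True): NaNs propagate from the fully masked rows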

Pitch

Add a new, optional eps parameter with default None to the various softmax forms:

import torch
from torch import nn
from torch.nn import functional as F

m = nn.Softmax(dim=-1, eps=1e-5)
x = torch.ones(5) * float('-inf')
assert torch.all(m(x) == 0) and torch.all(F.softmax(x, dim=-1, eps=1e-5) == 0)

If specified, eps will be included in the denominator of the softmax computation, resulting in 0/(0+eps) = 0 instead of 0/0 = NaN for the all -inf input case. See #41508 (comment).
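As a rough reference for the proposed semantics (not an actual implementation; softmax_with_eps, the max-shifting, and the handling of all -inf rows are illustrative assumptions):

import torch

def softmax_with_eps(x, dim, eps):
    # Hypothetical sketch of the proposed eps behavior: shift by the row max for
    # numerical stability, then add eps to the normalizing denominator.
    x_max = x.max(dim=dim, keepdim=True).values
    # An all -inf row has a -inf max; replace it with 0 so the shift stays finite.
    x_shifted = x - x_max.masked_fill(x_max == float('-inf'), 0.0)
    exp_x = x_shifted.exp()
    return exp_x / (exp_x.sum(dim=dim, keepdim=True) + eps)

x = torch.full((5,), float('-inf'))
print(softmax_with_eps(x, dim=-1, eps=1e-5))  # tensor([0., 0., 0., 0., 0.]) instead of NaN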

Alternatives

Leave as-is; calling softmax with all -inf inputs will unconditionally result in NaNs. It is left to users to ensure this doesn't happen if it's undesired, and failure to handle this correctly will still be considered "user error".

For the MHA problem of NaNs resulting from padded-out inputs, a workaround that has been employed requires hacking MHA internals to apply masked_fill(). Additionally, explicitly checking for the padded-out input case (e.g. to warn) can be expensive.
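For reference, a sketch of the kind of workaround described above, applied to raw attention scores (tensor names and shapes are illustrative; inside nn.MultiheadAttention this requires modifying internal code):

import torch
import torch.nn.functional as F

attn_scores = torch.randn(2, 4, 4)  # (batch, tgt_len, src_len)
key_padding_mask = torch.tensor([[False, False, False, False],
                                 [True,  True,  True,  True]])  # second element fully padded
attn_scores = attn_scores.masked_fill(key_padding_mask[:, None, :], float('-inf'))
attn_weights = F.softmax(attn_scores, dim=-1)
# Rows whose keys are entirely masked are now NaN; overwrite them with zeros.
fully_masked = key_padding_mask.all(dim=-1)[:, None, None]  # (batch, 1, 1)
attn_weights = attn_weights.masked_fill(fully_masked, 0.0)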

cc @albanD @mruberry @jbschlosser

Labels: enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix), module: nn (Related to torch.nn), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
