🚀 Feature
Expand softmax to support "safe softmax" behavior (i.e. output 0 instead of NaN when called with all -inf inputs).
Motivation
When softmax is called with all -inf inputs, the output is NaN: exp(-inf) = 0 for every element, so the denominator is 0 and the computation reduces to 0/0. In certain cases, an output of 0 may be preferred instead.
For example, due to masking in MHA, a batch may include a sentence that is "masked out", i.e. consists entirely of padding (masked out by setting to -inf). This can happen as a result of using attn_mask and key_padding_mask together. While this case currently results in NaNs, leading to training divergence, it is often preferable to effectively ignore these rows so training can continue. See #41508 and #40932 for more info.
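A minimal reproduction of the current behavior:

import torch
import torch.nn.functional as F

# A row that is entirely masked out (e.g. a fully padded sentence) ends up all -inf.
x = torch.full((5,), float('-inf'))

# exp(-inf) = 0 for every element, so the denominator is 0 and each entry is 0/0 = NaN.
print(F.softmax(x, dim=-1))  # tensor([nan, nan, nan, nan, nan])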
Pitch
Add a new, optional eps parameter with default None to the various softmax forms:
import torch
from torch import nn
from torch.nn import functional as F

m = nn.Softmax(dim=-1, eps=1e-5)
x = torch.ones(5) * float('-inf')
assert torch.all(m(x) == 0) and torch.all(F.softmax(x, dim=-1, eps=1e-5) == 0)

If specified, eps will be included in the denominator of the softmax computation, so the all -inf input case yields 0/(0+eps) = 0 instead of 0/0 = NaN. See #41508 (comment).
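As a rough sketch of the proposed semantics in terms of existing ops (the eps parameter above does not exist yet, and the safe_softmax helper below is purely illustrative, not the suggested implementation):

import torch

def safe_softmax(x, dim=-1, eps=1e-5):
    # Illustrative sketch only: shift by the row max for numerical stability,
    # then add eps to the denominator so an all -inf row evaluates to
    # 0/(0 + eps) = 0 instead of 0/0 = NaN.
    x_max = x.max(dim=dim, keepdim=True).values
    # For an all -inf row, x - x_max would be (-inf) - (-inf) = NaN, so drop
    # the shift for those rows.
    shifted = x - x_max.masked_fill(x_max == float('-inf'), 0.0)
    num = shifted.exp()
    return num / (num.sum(dim=dim, keepdim=True) + eps)

x = torch.full((2, 4), float('-inf'))
print(safe_softmax(x))  # tensor of zeros, no NaNs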
Alternatives
Leave as-is; calling softmax with all -inf inputs will unconditionally result in NaNs. It is left to users to ensure this doesn't happen if it's undesired, and failure to handle this correctly will still be considered "user error".
For the MHA problem of NaNs resulting from fully padded-out inputs, the workaround that has been employed requires hacking MHA internals to apply masked_fill() (see the sketch below). Additionally, explicitly checking for the padded-out input case (e.g. to emit a warning) can be expensive.
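As an illustration of what such a workaround can look like (a hypothetical sketch, not the actual MHA-internal change), fully masked rows can be zeroed out after the softmax:

import torch
import torch.nn.functional as F

def masked_softmax_workaround(scores, mask, dim=-1):
    # Hypothetical sketch of a masked_fill()-based workaround: mask with -inf,
    # run softmax, then zero out any row in which every position was masked so
    # fully padded queries do not propagate NaNs.
    scores = scores.masked_fill(mask, float('-inf'))
    attn = F.softmax(scores, dim=dim)
    all_masked = mask.all(dim=dim, keepdim=True)
    return attn.masked_fill(all_masked, 0.0)

scores = torch.randn(1, 2, 4)
mask = torch.tensor([[[False, False, True, True],
                      [True,  True,  True, True]]])  # second query fully masked
print(masked_softmax_workaround(scores, mask))        # second row is all zeros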