🐛 Bug
Gumbel distributions with relatively modest parameters return -inf from log_prob in situations where the true log-probability is nowhere near exceeding the range of float32.
To Reproduce
Steps to reproduce the behavior:
$ pip3 install torch
Collecting torch
Downloading https://files.pythonhosted.org/packages/7e/60/66415660aa46b23b5e1b72bc762e816736ce8d7260213e22365af51e8f9c/torch-1.0.0-cp36-cp36m-manylinux1_x86_64.whl (591.8MB)
100% |████████████████████████████████| 591.8MB 22kB/s
tcmalloc: large alloc 1073750016 bytes == 0x621c2000 @ 0x7fceac4752a4 0x591a07 0x5b5d56 0x502e9a 0x506859 0x502209 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x507641 0x502209 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x507641 0x504c28 0x502540 0x502f3d 0x507641
Installing collected packages: torch
Successfully installed torch-1.0.0
In python3, run:
import torch
g = torch.distributions.Gumbel(loc=0.0, scale=1.0)
print(g.log_prob(-5.0))
The result is
tensor(-inf).
Expected behavior
However, the expected value is a number between -143 and -144, namely 5.0 - exp(5.0) ≈ -143.41.
(This correct value can be determined by writing out the closed form of the Gumbel distribution's PDF and taking its log.)
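As a sanity check, -143.41 is comfortably representable in float32; the -inf presumably arises because an intermediate density value underflows to zero before the log is taken. A minimal sketch of that failure mode (standalone arithmetic, not a claim about PyTorch internals):
import torch
z = torch.tensor(5.0)       # (loc - x) / scale for loc=0, scale=1, x=-5
log_pdf = z - torch.exp(z)  # -143.4132, perfectly representable in float32
pdf = torch.exp(log_pdf)    # exp(-143.4) underflows to 0.0 in float32
print(torch.log(pdf))       # tensor(-inf), matching the buggy output
print(log_pdf)              # tensor(-143.4132), the correct answer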
In the ideal case, I expect that the output for g.log_prob(x) should be equal to:
-torch.log(g.scale) + (g.loc - x)/g.scale - torch.exp((g.loc - x)/g.scale)
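For illustration, a hedged sketch of that direct evaluation as a standalone helper (gumbel_log_prob is a hypothetical name, not part of torch.distributions):
import torch
# Hypothetical helper: evaluate the Gumbel log-density in closed form,
# so no intermediate quantity underflows for moderate arguments.
def gumbel_log_prob(loc, scale, x):
    z = (loc - x) / scale
    return -torch.log(scale) + z - torch.exp(z)
print(gumbel_log_prob(torch.tensor(0.0), torch.tensor(1.0), torch.tensor(-5.0)))
# tensor(-143.4132)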
Environment
Please copy and paste the output from our environment collection script (or fill out the checklist below manually). You can get the script and run it with:
wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
Collecting environment information...
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.12.0
Python version: 3.6
Is CUDA available: No
CUDA runtime version: 9.2.148
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.4.2
/usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
Versions of relevant libraries:
[pip] Could not collect
[conda] Could not collect
Additional context
This came up in the process of computing the kl_divergence of one Gumbel distribution against another, using the closed form in:
pytorch/torch/distributions/kl.py
Lines 266 to 273 in d86cc3e
def _kl_gumbel_gumbel(p, q):
    ct1 = p.scale / q.scale
    ct2 = q.loc / q.scale
    ct3 = p.loc / q.scale
    t1 = -ct1.log() - ct2 + ct3
    t2 = ct1 * _euler_gamma
    t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
    return t1 + t2 + t3 - (1 + _euler_gamma)
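For context, this is the closed form that torch.distributions.kl_divergence dispatches to for a pair of Gumbel distributions. A minimal invocation, with arbitrary illustrative parameters:
import torch
from torch.distributions import Gumbel, kl_divergence
p = Gumbel(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
q = Gumbel(loc=torch.tensor(1.0), scale=torch.tensor(2.0))
print(kl_divergence(p, q))  # evaluates _kl_gumbel_gumbel above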