Description
🐛 Bug
torch.multinomial occasionally samples elements with zero weight. This should never happen.
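The invariant can be stated concretely. Inverse-CDF sampling is one standard way to draw from a categorical distribution. In that scheme, a zero-weight category occupies an empty interval of the cumulative distribution, so a correct sampler can never return it. The pure-Python sketch below illustrates this. `inverse_cdf_sample` is a hypothetical helper, not PyTorch's implementation, and it makes no claim about how the CUDA kernel behind `torch.multinomial` actually works.

```python
import random
from bisect import bisect_right
from itertools import accumulate


def inverse_cdf_sample(weights, num_samples, rng=None):
    """Draw num_samples indices with replacement, with P(i) proportional to weights[i].

    Hypothetical illustration only; this is not how torch.multinomial is implemented.
    """
    if not weights:
        raise ValueError("weights must be non-empty")
    if any(w < 0 for w in weights):
        raise ValueError("weights must be nonnegative")
    cdf = list(accumulate(weights))  # running sums; a zero weight repeats the previous value
    total = cdf[-1]
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    rng = rng or random.Random()
    samples = []
    for _ in range(num_samples):
        u = rng.random() * total  # random() is in [0, 1), so u is in [0, total)
        # bisect_right returns the first index i with cdf[i] > u. If weights[i] == 0,
        # then cdf[i] == cdf[i - 1], so index i - 1 would already pass that test
        # (and for i == 0, cdf[0] == 0 cannot exceed u >= 0). Zero-weight indices
        # are therefore never returned.
        samples.append(bisect_right(cdf, u))
    return samples
```

For example, `inverse_cdf_sample([0.0, 3.0, 0.0, 1.5], 1000)` can only ever return indices 1 and 3, in roughly a 2:1 ratio.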
To Reproduce
I've been unable to reproduce this issue with randomly generated weights, so I've included a specific weights tensor from my application that triggers this behavior:
```
wget https://cs.stanford.edu/people/jcjohns/weights.pt
```

These weights are all nonnegative (though many are zero), have a nonzero sum, and contain no NaNs or Infs.
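The preconditions listed above (finite values, no negatives, positive sum) can be checked before sampling. Below is a minimal pure-Python sketch. `check_weights` is a hypothetical helper that operates on a plain sequence of floats.

```python
import math


def check_weights(weights):
    """Return a list of problems with a 1-D sequence of sampling weights (empty if none).

    Hypothetical helper mirroring the properties stated for weights.pt.
    """
    problems = []
    if not all(math.isfinite(w) for w in weights):
        problems.append("contains NaN or Inf")
    if any(w < 0 for w in weights):
        problems.append("contains negative values")
    # Sum only the finite entries so a NaN/Inf is reported once, above.
    if not sum(w for w in weights if math.isfinite(w)) > 0:
        problems.append("sum is not positive")
    return problems
```

Assuming `weights` is a 1-D tensor, it could be checked with `check_weights(weights.tolist())`. An empty result means the tensor satisfies all three properties.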
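```python
import torch

torch.manual_seed(1)
weights = torch.load('weights.pt')
N, S = weights.shape[0], 4096
num_trials = 100
for trial in range(1, num_trials + 1):
    print('Starting trial %d / %d' % (trial, num_trials))
    # Clamp any negative entries to zero (the loaded weights are already nonnegative).
    weights[weights < 0] = 0.0
    # Draw S indices with replacement, with probability proportional to the weights.
    samples = weights.multinomial(S, replacement=True)
    # Look up the weight of every sampled index; none of them should be zero.
    sampled_weights = weights[samples]
    assert sampled_weights.min() > 0
```

The assertion fails on trial 6.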
Environment
PyTorch version: 1.0.0.dev20181112
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 16.04.4 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 396.51
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] Could not collect
[conda] pytorch 0.4.1 py37_py36_py35_py27__9.0.176_7.1.2_2 pytorch
[conda] pytorch-nightly 1.0.0.dev20181112 py3.7_cuda9.0.176_cudnn7.1.2_0 pytorch
[conda] torchvision 0.2.1
[conda] torchvision 0.2.1 py37_1 pytorch