Description
🐛 Bug
torch.nn.functional.affine_grid on GPU crashes when batch size >= 256*256
To Reproduce
import torch
batch_size = 256*256
transform_parameters = torch.cuda.FloatTensor([[1,0,0],[0,1,0]])
transform_parameters = torch.stack([transform_parameters] * batch_size, 0).contiguous()
resampling_grids = torch.nn.functional.affine_grid(transform_parameters, torch.Size((batch_size, 1, 2, 2)))
print(resampling_grids.size())
The script crashes with the following error:
Traceback (most recent call last):
  File "bug.py", line 5, in <module>
    resampling_grids = torch.nn.functional.affine_grid(transform_parameters, torch.Size((batch_size, 1, 2, 2)))
  File "/home/aosokin/local/software/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/functional.py", line 2615, in affine_grid
    return vision.affine_grid_generator(theta, size)
  File "/home/aosokin/local/software/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/_functions/vision.py", line 10, in affine_grid_generator
    ret = torch.cudnn_affine_grid_generator(theta, N, C, H, W)
RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED
If I set batch_size to 256*256 - 1 or run the operation on the CPU, everything works fine.
The behavior is very similar to the torch.inverse bug #13276, but with a different function and a different error message.
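Since batches of 256*256 - 1 reportedly work, one possible stopgap is to split the batch into pieces of at most 65535 and concatenate the results. The sketch below is an untested workaround, not a fix: the helper names chunk_bounds and affine_grid_chunked are hypothetical, and it assumes the failure depends only on the batch dimension and that the batch is non-empty.

```python
def chunk_bounds(total, max_chunk=65535):
    """Return (start, stop) pairs that cover range(total) in pieces of at most max_chunk."""
    if max_chunk <= 0:
        raise ValueError("max_chunk must be positive")
    return [(start, min(start + max_chunk, total))
            for start in range(0, total, max_chunk)]


def affine_grid_chunked(theta, size, max_chunk=65535):
    """Hypothetical wrapper: call affine_grid on sub-batches and concatenate.

    theta has shape (N, 2, 3) and size is (N, C, H, W), as in the repro above.
    Assumes N > 0 and that each sub-batch below 256*256 avoids the crash.
    """
    # Imported here so chunk_bounds stays usable without torch installed.
    import torch

    n, c, h, w = size
    grids = [
        torch.nn.functional.affine_grid(
            theta[start:stop], torch.Size((stop - start, c, h, w)))
        for start, stop in chunk_bounds(n, max_chunk)
    ]
    return torch.cat(grids, 0)
```

With the repro script, this would be called as affine_grid_chunked(transform_parameters, (batch_size, 1, 2, 2)), which for batch_size = 256*256 processes one chunk of 65535 and one chunk of 1.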
Expected behavior
The code should print torch.Size([65536, 2, 2, 2])
Environment
PyTorch version: 1.0.0 (checked with pytorch-nightly of 24.01.2019)
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: Could not collect
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 410.79
cuDNN version: 7401 (installed by pytorch itself)
Versions of relevant libraries:
[pip3] maskrcnn-benchmark (0.1, /media/aosokin/kingston2tb/software/pytorch/maskrcnn-benchmark)
[pip3] numpy (1.14.0)
[pip3] numpydoc (0.7.0)
[pip3] torch (1.0.0a0+db5d313)
[pip3] torchfile (0.1.0)
[pip3] torchvision (0.2.1)
[conda] blas 1.0 mkl
[conda] cuda100 1.0 0 pytorch
[conda] mkl 2019.1 144
[conda] mkl_fft 1.0.6 py37hd81dba3_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.0.0 py3.7_cuda10.0.130_cudnn7.4.1_1 [cuda100] pytorch
[conda] torchvision 0.2.1 py_2 pytorch