Labels: module: autograd, module: cuda, module: determinism, module: padding, triaged
🐛 Bug
It seems that when `torch.stft(return_complex=True)` is followed by `torch.abs`, `gradgradcheck` fails, but neither operation fails the check on its own.
To Reproduce
Steps to reproduce the behavior:
```python
import torch
from torch.autograd import gradgradcheck


def stft_with_abs(tensor):
    tensor = torch.stft(input=tensor, n_fft=256, return_complex=True)
    tensor = tensor.abs()
    return tensor


def abs_(tensor):
    return tensor.abs()


def stft(tensor):
    return torch.stft(tensor, n_fft=256, return_complex=True)


def test_stft_with_abs():
    for i in range(100):
        print(i, '\r', end='')
        tensor = torch.randn([2, 250])
        tensor.requires_grad = True
        tensor = tensor.to(dtype=torch.float64, device='cuda')
        assert gradgradcheck(stft_with_abs, [tensor])


def test_stft_only():
    for i in range(100):
        print(i, '\r', end='')
        tensor = torch.randn([2, 250])
        tensor.requires_grad = True
        tensor = tensor.to(dtype=torch.float64, device='cuda')
        assert gradgradcheck(stft, [tensor])


def test_abs_only():
    for i in range(100):
        print(i, '\r', end='')
        tensor = torch.randn([2, 250])
        tensor = tensor.to(dtype=torch.float64, device='cuda')
        tensor = torch.stft(input=tensor, n_fft=256, return_complex=True)
        tensor.requires_grad = True
        assert gradgradcheck(abs_, [tensor])


# test_stft_only()  # does not fail
# test_abs_only()   # does not fail
test_stft_with_abs()
```
`test_stft_with_abs()` fails with the following message:
```
Traceback (most recent call last):
  File "foo.py", line 63, in <module>
    test_stft_with_abs()
  File "foo.py", line 39, in test_stft_with_abs
    assert gradgradcheck(stft_with_abs, [tensor])
  File "/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 674, in gradgradcheck
    return gradcheck(
  File "/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 479, in gradcheck
    return not_reentrant_error()
  File "/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 476, in not_reentrant_error
    return fail_test(error_msg)
  File "/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 367, in fail_test
    raise RuntimeError(msg)
RuntimeError: Backward is not reentrant, i.e., running backward with same input and grad_output multiple times gives different values, although analytical gradient matches numerical gradient. The tolerance for nondeterminism was 0.0.
```
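For reference, a minimal sketch of the check that fails (simplified; `gradgradcheck`'s internal procedure differs in details): the double-backward gradient is computed twice with identical inputs and `grad_outputs`, and the two results are compared with zero tolerance. A CUDA device is assumed, as in the repro.

```python
import torch

# Minimal reentrancy check: run double backward twice with the same
# grad_outputs and compare the results bitwise.
x = torch.randn(2, 250, dtype=torch.float64, device='cuda', requires_grad=True)
y = torch.stft(x, n_fft=256, return_complex=True).abs()

grad_y = torch.randn_like(y)
(gx,) = torch.autograd.grad(y, x, grad_y, create_graph=True)

grad_gx = torch.randn_like(gx)
(ggx1,) = torch.autograd.grad(gx, x, grad_gx, retain_graph=True)
(ggx2,) = torch.autograd.grad(gx, x, grad_gx, retain_graph=True)

# gradgradcheck demands bitwise equality here (tolerance 0.0); this issue
# suggests the two runs can differ.
print(torch.equal(ggx1, ggx2))
```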
I also tried with `return_complex=False`, but `gradgradcheck` did not fail.
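For completeness, the `return_complex=False` variant could be written along these lines (an illustrative reconstruction, not the exact code that was run; the magnitude is computed manually because the output is a real tensor with a trailing real/imaginary dimension):

```python
def stft_with_abs_real(tensor):
    # return_complex=False yields a real tensor shaped (..., freq, frames, 2)
    # holding the real and imaginary parts; compute the magnitude manually.
    real = torch.stft(input=tensor, n_fft=256, return_complex=False)
    return real.pow(2).sum(-1).sqrt()
```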
Expected behavior
`gradgradcheck` should pass for the `stft` + `abs` case.
Environment
```
Collecting environment information...
PyTorch version: 1.9.0.dev20210316
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.18.4
Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 450.80.02
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] pytorch-sphinx-theme==0.0.24
[pip3] torch==1.9.0.dev20210316
[pip3] torchaudio==0.9.0a0+ba61c9b
[pip3] torchtext==0.9.0a0+c072ba6
[conda] blas                 1.0                mkl
[conda] cudatoolkit          10.1.243           h6bb024c_0
[conda] magma-cuda101        2.5.2              1                  pytorch
[conda] mkl                  2020.2             256
[conda] mkl-include          2020.4             h726a3e6_304       conda-forge
[conda] mkl-service          2.3.0              py38he904b0f_0
[conda] mkl_fft              1.3.0              py38h54f3939_0
[conda] mkl_random           1.1.1              py38h0573a6f_0
[conda] numpy                1.19.2             py38h54aff64_0
[conda] numpy-base           1.19.2             py38hfa32c7d_0
[conda] pytorch              1.9.0.dev20210316  py3.8_cuda10.1_cudnn7.6.3_0  pytorch-nightly
[conda] pytorch-sphinx-theme 0.0.24             dev_0              <develop>
[conda] torch                1.7.1              pypi_0             pypi
[conda] torchaudio           0.9.0a0+ba61c9b    dev_0              <develop>
[conda] torchtext            0.9.0a0+c072ba6    dev_0              <develop>
```
Additional context
In pytorch/audio#1340, I was adding a test that runs `gradgradcheck` on `torchaudio.transforms.Spectrogram`. The CI reported a nondeterminism error.
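For context, `Spectrogram(power=1)` boils down to `torch.stft(...).abs()`, so the transform exercises the same code path as the repro above. A sketch of a failing test along those lines (the `n_fft`/`power` values are illustrative, not necessarily the ones from the PR):

```python
import torch
import torchaudio
from torch.autograd import gradgradcheck

# power=1 makes Spectrogram compute the STFT magnitude, i.e. effectively
# torch.stft(...).abs(); the arguments here are illustrative only.
spectrogram = torchaudio.transforms.Spectrogram(n_fft=256, power=1)
spectrogram = spectrogram.to(device='cuda', dtype=torch.float64)

waveform = torch.randn(2, 250, dtype=torch.float64, device='cuda',
                       requires_grad=True)
# Expected to hit the same nondeterminism as the stft+abs repro above.
gradgradcheck(spectrogram, [waveform])
```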
cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @ngimel @mruberry @kurtamohler