Status: Open
Labels: module: autograd, module: correctness (silent), triaged
Description
🐛 Bug
In pytorch/audio#1420, we are adding an autograd check to torchaudio.functional.phase_vocoder, and we noticed that it fails the autograd test (gradgradcheck) when rate=0.7, while it passes for rate=0.8 to 1.3.
Note that phase_vocoder changes the length of the spectrogram along the time axis, so this might be a legitimate result.
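To make the time-axis behaviour concrete: in the minimized repro below, the number of output frames is the length of `torch.arange(0, n_frames, rate)`, which is `ceil(n_frames / rate)`. The pure-Python sketch below (the helper `num_output_frames` is hypothetical and only mirrors that one computation; it is not torchaudio code) shows that rate < 1 yields more frames than the input and rate > 1 fewer.

```python
import math


def num_output_frames(n_frames: int, rate: float) -> int:
    """Hypothetical helper: length of torch.arange(0, n_frames, rate).

    torch.arange with a float step produces ceil((end - start) / step)
    elements; here start is 0, so this is ceil(n_frames / rate).
    """
    return math.ceil(n_frames / rate)


# Compare a few rates for a 100-frame spectrogram.
for rate in [0.7, 1.0, 1.3]:
    print(rate, num_output_frames(100, rate))
```

For example, 100 input frames become 143 output frames at rate=0.7 and 77 at rate=1.3.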
To Reproduce
The following is a minimized version of torchaudio.functional.phase_vocoder:
```python
import torch
from torch.autograd import gradgradcheck


def func(spectrogram, rate):
    # Fractional time steps at which the complex spectrogram is resampled.
    time_steps = torch.arange(
        0,
        spectrogram.size(-1),
        rate,
        device=spectrogram.device,
        dtype=torch.real(spectrogram).dtype)
    # Append two zero frames so that the (time_steps + 1) indices stay in range.
    spectrogram = torch.nn.functional.pad(spectrogram, [0, 2])
    spectrogram_0 = spectrogram.index_select(-1, time_steps.long())
    spectrogram_1 = spectrogram.index_select(-1, (time_steps + 1).long())
    # Phase difference between neighbouring frames.
    angle_0 = spectrogram_0.angle()
    angle_1 = spectrogram_1.angle()
    phase = angle_1 - angle_0
    return phase


def main():
    n_fft = 400
    for rate in [0.8, 0.7]:
        print('testing:', rate)
        spectrogram = torch.stft(torch.randn(2, 256), n_fft=n_fft, return_complex=True).to(torch.complex128)
        spectrogram.requires_grad = True
        assert gradgradcheck(func, [spectrogram, rate])


main()
```

```
$ python foo.py
Traceback (most recent call last):
  File "foo.py", line 35, in <module>
    main()
  File "foo.py", line 32, in main
    assert gradgradcheck(func, [spectrogram, rate])
  File "/home/moto/conda/envs/PY3.8-cuda101/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 732, in gradgradcheck
    return gradcheck(
  File "/home/moto/conda/envs/PY3.8-cuda101/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 616, in gradcheck
    return fail_test(get_notallclose_msg(analytical_from_imag_grad_out[j],
  File "/home/moto/conda/envs/PY3.8-cuda101/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 577, in fail_test
    raise RuntimeError(msg)
RuntimeError: Gradients failed to compare equal for grad output = 1j. Jacobian mismatch for output 0 with respect to input 1,
numerical:tensor([[-0.0214,  0.1287,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0214,  0.1287,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.1287,  0.1012,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0707, -0.1014],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.1014],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.1014]],
       dtype=torch.float64)
analytical:tensor([[-0.0214,  0.1287,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0214,  0.1287,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.1287,  0.1012,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0707, -0.1014],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan]],
       dtype=torch.float64)
```
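One observation that may help narrow this down (not a confirmed diagnosis): with 256 samples, n_fft=400 and the torch.stft defaults (hop_length = n_fft // 4 = 100, center=True), the spectrogram has only 1 + 256 // 100 = 3 frames, and `pad(..., [0, 2])` appends two all-zero columns at indices 3 and 4. The pure-Python sketch below mirrors the index arithmetic in `func` (the helper `gather_indices` is hypothetical, not part of torch or torchaudio). It shows that for rate=0.8 the zero column at index 3 is gathered once into `spectrogram_1`, while for rate=0.7 it is gathered twice. Whether this difference is what produces the NaN rows in the analytical Jacobian is not established here.

```python
import math


def gather_indices(n_frames: int, rate: float):
    """Hypothetical helper mirroring the index arithmetic in `func`.

    torch.arange(0, n_frames, rate) has ceil(n_frames / rate) elements with
    values i * rate; .long() truncates toward zero, as int() does for
    non-negative values.
    """
    steps = [i * rate for i in range(math.ceil(n_frames / rate))]
    idx_0 = [int(t) for t in steps]      # indices used for spectrogram_0
    idx_1 = [int(t + 1) for t in steps]  # indices used for spectrogram_1
    return idx_0, idx_1


n_fft, n_samples = 400, 256
hop_length = n_fft // 4                 # torch.stft default hop length
n_frames = 1 + n_samples // hop_length  # center=True (default) gives 3 frames

for rate in [0.8, 0.7]:
    idx_0, idx_1 = gather_indices(n_frames, rate)
    # Indices >= n_frames point at the zero columns added by pad(..., [0, 2]).
    zero_col_uses = sum(i >= n_frames for i in idx_1)
    print(rate, idx_0, idx_1, zero_col_uses)
```

This prints `[0, 0, 1, 2] [1, 1, 2, 3] 1` for rate=0.8 and `[0, 0, 1, 2, 2] [1, 1, 2, 3, 3] 2` for rate=0.7.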
Expected behavior
gradgradcheck should pass for rate=0.7, as it does for rate=0.8.
Environment
Collecting environment information...
PyTorch version: 1.9.0.dev20210405
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.18.4
Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 450.80.02
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] pytorch-sphinx-theme==0.0.24
[pip3] torch==1.9.0.dev20210405
[pip3] torchaudio==0.9.0a0+9a0e70e
[pip3] torchtext==0.9.0a0+c072ba6
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.1.243 h6bb024c_0
[conda] magma-cuda101 2.5.2 1 pytorch
[conda] mkl 2020.2 256
[conda] mkl-include 2020.4 h726a3e6_304 conda-forge
[conda] mkl-service 2.3.0 py38he904b0f_0
[conda] mkl_fft 1.3.0 py38h54f3939_0
[conda] mkl_random 1.1.1 py38h0573a6f_0
[conda] numpy 1.19.2 py38h54aff64_0
[conda] numpy-base 1.19.2 py38hfa32c7d_0
[conda] pytorch 1.9.0.dev20210405 py3.8_cuda10.1_cudnn7.6.3_0 pytorch-nightly
[conda] pytorch-sphinx-theme 0.0.24 dev_0 <develop>
[conda] torch 1.7.1 pypi_0 pypi
[conda] torchaudio 0.9.0a0+9a0e70e dev_0 <develop>
[conda] torchtext 0.9.0a0+c072ba6 dev_0 <develop>
cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer