
stft + abs is non-deterministic in backward path #54093

Description

@mthrok

🐛 Bug

It seems that when torch.stft(return_complex=True) is followed by torch.abs, gradgradcheck fails, yet neither operation fails when checked on its own.

To Reproduce

Steps to reproduce the behavior:

import torch

from torch.autograd import gradgradcheck


def stft_with_abs(tensor):
    # The failing composition: complex STFT followed by complex magnitude.
    tensor = torch.stft(input=tensor, n_fft=256, return_complex=True)
    tensor = tensor.abs()
    return tensor


def abs_(tensor):
    return tensor.abs()


def stft(tensor):
    return torch.stft(tensor, n_fft=256, return_complex=True)


def test_stft_with_abs():
    # Run many iterations: the failure is nondeterministic and may take
    # several attempts to surface.
    for i in range(100):
        print(i, '\r', end='')
        tensor = torch.randn([2, 250])
        tensor.requires_grad = True

        tensor = tensor.to(dtype=torch.float64, device='cuda')
        assert gradgradcheck(stft_with_abs, [tensor])


def test_stft_only():
    for i in range(100):
        print(i, '\r', end='')
        tensor = torch.randn([2, 250])
        tensor.requires_grad = True

        tensor = tensor.to(dtype=torch.float64, device='cuda')
        assert gradgradcheck(stft, [tensor])


def test_abs_only():
    for i in range(100):
        print(i, '\r', end='')
        tensor = torch.randn([2, 250])
        tensor = tensor.to(dtype=torch.float64, device='cuda')
        tensor = torch.stft(input=tensor, n_fft=256, return_complex=True)

        tensor.requires_grad = True
        assert gradgradcheck(abs_, [tensor])


# test_stft_only()  # does not fail
# test_abs_only()  # does not fail
test_stft_with_abs()

test_stft_with_abs() fails with the following message:

Traceback (most recent call last):
  File "foo.py", line 63, in <module>
    test_stft_with_abs()
  File "foo.py", line 39, in test_stft_with_abs
    assert gradgradcheck(stft_with_abs, [tensor])
  File "/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 674, in gradgradcheck
    return gradcheck(
  File "/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 479, in gradcheck
    return not_reentrant_error()
  File "/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 476, in not_reentrant_error
    return fail_test(error_msg)
  File "/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 367, in fail_test
    raise RuntimeError(msg)
RuntimeError: Backward is not reentrant, i.e., running backward with same input and grad_output multiple times gives different values, although analytical gradient matches numerical gradient. The tolerance for nondeterminism was 0.0.
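For context, "not reentrant" means the backward computation itself is nondeterministic: running it twice with identical inputs and grad_outputs yields different results. A minimal sketch of this check applied to the double-backward graph (my own code, not the actual gradcheck internals):

import torch

x = torch.randn(2, 250, dtype=torch.float64, device='cuda', requires_grad=True)
out = torch.stft(x, n_fft=256, return_complex=True).abs()

# First-order gradient, keeping the graph so it can be differentiated again.
g, = torch.autograd.grad(out, x, torch.ones_like(out), create_graph=True)

# Differentiate the gradient twice with identical inputs; any difference
# means the double backward is not reentrant.
gg1, = torch.autograd.grad(g, x, torch.ones_like(g), retain_graph=True)
gg2, = torch.autograd.grad(g, x, torch.ones_like(g), retain_graph=True)
print((gg1 - gg2).abs().max())  # expected to be 0 for a deterministic backward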

I also tried with return_complex=False, and gradgradcheck did not fail (a sketch of that variant follows).
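The report does not show the exact return_complex=False variant; a plausible sketch, computing the magnitude by hand from the trailing real/imaginary dimension, would be:

import torch
from torch.autograd import gradgradcheck


def stft_with_abs_real(tensor):
    # return_complex=False yields a real tensor whose last dimension of
    # size 2 holds the real and imaginary parts.
    spec = torch.stft(input=tensor, n_fft=256, return_complex=False)
    return spec.pow(2).sum(-1).sqrt()  # magnitude computed manually


tensor = torch.randn(2, 250, dtype=torch.float64, device='cuda', requires_grad=True)
assert gradgradcheck(stft_with_abs_real, [tensor])  # reportedly passes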

Expected behavior

gradgradcheck should pass for the stft + abs case.
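As an aside (not from the report): since the failure is CUDA-specific nondeterminism, PyTorch's deterministic mode can sometimes identify the offending kernel, because it raises a RuntimeError when an op without a deterministic implementation runs. Whether it fires here depends on whether the relevant backward kernel is registered as nondeterministic:

import torch

torch.use_deterministic_algorithms(True)

x = torch.randn(2, 250, dtype=torch.float64, device='cuda', requires_grad=True)
out = torch.stft(x, n_fft=256, return_complex=True).abs()

# If the nondeterministic backward kernel is registered, this raises a
# RuntimeError naming it; otherwise it runs silently.
out.sum().backward()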

Environment

Collecting environment information...
PyTorch version: 1.9.0.dev20210316
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.18.4

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 450.80.02
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] pytorch-sphinx-theme==0.0.24
[pip3] torch==1.9.0.dev20210316
[pip3] torchaudio==0.9.0a0+ba61c9b
[pip3] torchtext==0.9.0a0+c072ba6
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               10.1.243             h6bb024c_0
[conda] magma-cuda101             2.5.2                         1    pytorch
[conda] mkl                       2020.2                      256
[conda] mkl-include               2020.4             h726a3e6_304    conda-forge
[conda] mkl-service               2.3.0            py38he904b0f_0
[conda] mkl_fft                   1.3.0            py38h54f3939_0
[conda] mkl_random                1.1.1            py38h0573a6f_0
[conda] numpy                     1.19.2           py38h54aff64_0
[conda] numpy-base                1.19.2           py38hfa32c7d_0
[conda] pytorch                   1.9.0.dev20210316 py3.8_cuda10.1_cudnn7.6.3_0    pytorch-nightly
[conda] pytorch-sphinx-theme      0.0.24                    dev_0    <develop>
[conda] torch                     1.7.1                    pypi_0    pypi
[conda] torchaudio                0.9.0a0+ba61c9b           dev_0    <develop>
[conda] torchtext                 0.9.0a0+c072ba6           dev_0    <develop>

Additional context

In pytorch/audio#1340, I was adding a test that runs gradgradcheck on torchaudio.transforms.Spectrogram, and the CI reported a nondeterminism error.
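For completeness, a sketch of the kind of check that test adds (illustrative parameters, not the exact CI test); Spectrogram with power=1.0 computes essentially the stft + abs composition above:

import torch
import torchaudio
from torch.autograd import gradgradcheck

# Spectrogram(power=1.0) computes an STFT followed by a magnitude,
# i.e. roughly the same stft + abs composition as in the repro.
transform = torchaudio.transforms.Spectrogram(n_fft=256, power=1.0)
transform = transform.to(dtype=torch.float64, device='cuda')

waveform = torch.randn(2, 250, dtype=torch.float64, device='cuda', requires_grad=True)
assert gradgradcheck(transform, [waveform])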

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @ngimel @mruberry @kurtamohler

Labels

module: autograd, module: cuda, module: determinism, module: padding, triaged
