Adam raises RuntimeError("Adam does not support sparse gradients, ...") for masked gradients due to API discrepancy · Issue #104574 · pytorch/pytorch · GitHub

@tillahoffmann

🐛 Describe the bug

There is an API discrepancy between torch.masked.MaskedTensor.is_sparse and torch.Tensor.is_sparse. It causes torch.optim.Adam to raise "RuntimeError: Adam does not support sparse gradients, please consider SparseAdam instead" for masked gradients, even if the underlying data are dense.

In particular, for torch.Tensor, is_sparse is an attribute:

>>> x = torch.randn(4)
>>> x.__class__.is_sparse
<attribute 'is_sparse' of 'torch._C._TensorBase' objects>

However, for torch.masked.MaskedTensor, is_sparse is a method:

>>> y = torch.masked.as_masked_tensor(x, x < 0)
>>> y.__class__.is_sparse
<function torch.masked.maskedtensor.core.MaskedTensor.is_sparse(self)>

Now torch.optim.Adam checks for sparse gradients like so:

if p.grad.is_sparse:
    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
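
The effect of this check on a method-valued is_sparse can be reproduced without PyTorch. In the minimal sketch below, DenseLike, MaskedLike, and check_like_adam are hypothetical stand-ins (not PyTorch classes or functions); they mimic only the attribute-versus-method difference described above.

```python
class DenseLike:
    """Hypothetical stand-in for torch.Tensor: is_sparse is a plain attribute."""
    is_sparse = False


class MaskedLike:
    """Hypothetical stand-in for MaskedTensor: is_sparse is a method."""

    def is_sparse(self):
        return False  # the underlying data are dense


def check_like_adam(grad):
    """Mirror the truthiness test quoted above; True means it would raise."""
    return bool(grad.is_sparse)


# The attribute evaluates to False, so a dense gradient passes the check.
dense_rejected = check_like_adam(DenseLike())
# The bound method is never called; the method object itself is truthy.
masked_rejected = check_like_adam(MaskedLike())
print(dense_rejected, masked_rejected)
```

The method is never invoked in the second case, so its return value plays no role in the outcome.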

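On a MaskedTensor, p.grad.is_sparse evaluates to a bound method rather than a bool. Callables are truthy, so the condition always holds and the RuntimeError is raised. The discrepancy could be fixed by adding a @property decorator to is_sparse in the torch.masked.MaskedTensor implementation shown below, but doing so would break backwards compatibility for code that calls is_sparse().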

# Update later to support more sparse layouts
def is_sparse(self):
    return self.is_sparse_coo() or self.is_sparse_csr()
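
The trade-off can be sketched with hypothetical stand-in classes rather than the real MaskedTensor. Making is_sparse a property lets a truthiness check behave as intended, but existing callers that write is_sparse() would then fail with a TypeError. The helper grad_is_sparse (a hypothetical name, not a PyTorch function) shows one possible caller-side check that accepts either convention.

```python
class MaskedWithMethod:
    """Hypothetical stand-in for the current behaviour: is_sparse is a method."""

    def is_sparse(self):
        return False


class MaskedWithProperty:
    """Hypothetical stand-in for the proposed @property variant."""

    def is_sparse_coo(self):
        return False

    def is_sparse_csr(self):
        return False

    @property
    def is_sparse(self):
        return self.is_sparse_coo() or self.is_sparse_csr()


def grad_is_sparse(grad):
    """Return a bool whether is_sparse is an attribute, property, or method."""
    flag = grad.is_sparse
    return bool(flag() if callable(flag) else flag)


prop = MaskedWithProperty()
property_value = prop.is_sparse  # a bool, safe to use in `if`

# Code written against the method form breaks once is_sparse is a property,
# because it ends up calling a bool.
try:
    prop.is_sparse()
    old_call_still_works = True
except TypeError:
    old_call_still_works = False

method_form_result = grad_is_sparse(MaskedWithMethod())
property_form_result = grad_is_sparse(prop)
```

A check of this shape avoids depending on which form a given tensor class exposes, at the cost of one extra callable() test per gradient.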

Versions

Collecting environment information...
PyTorch version: 2.0.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.4 (arm64)
GCC version: Could not collect
Clang version: 14.0.3 (clang-1403.0.22.14.1)
CMake version: version 3.25.2
Libc version: N/A

Python version: 3.10.9 (main, Feb 10 2023, 12:03:15) [Clang 14.0.0 (clang-1400.0.29.202)] (64-bit runtime)
Python platform: macOS-13.4-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] mypy==1.4.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.25.0
[pip3] torch==2.0.1
[conda] Could not collect

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar @george-qi

Labels: actionable; module: masked operators; module: optimizer; module: sparse; triaged
