NotImplementedError: Cannot access storage of SparseTensorImpl #58755

🐛 Bug

When using CUDA RPC and returning a sparse tensor from a DDP communication hook to update gradients, the .backward() call throws NotImplementedError: Cannot access storage of SparseTensorImpl.
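
For context, a sparse COO tensor has no single dense storage to hand out, so anything that calls Tensor.storage() on one fails with this message. The assumption that this is the exact call the autograd engine hits in the repro below is mine and is not verified here; this is just a minimal standalone sketch of the error class:

import torch

# Sparse COO tensors keep indices and values instead of one dense storage,
# so Tensor.storage() cannot be satisfied for them.
indices = torch.tensor([[0, 1], [1, 0]])
values = torch.tensor([1.0, 2.0])
sparse = torch.sparse_coo_tensor(indices, values, (2, 2)).coalesce()

try:
    sparse.storage()
except (NotImplementedError, RuntimeError) as e:
    # Expected to print something like "Cannot access storage of SparseTensorImpl"
    print(type(e).__name__, e)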

To Reproduce

Steps to reproduce the behavior:

Run the example program below with cuda_rpc=True set in if __name__ == "__main__":

import os

import torch
import torch.distributed as c10d
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.rpc import TensorPipeRpcBackendOptions
from torch.nn.parallel import DistributedDataParallel as DDP


def sparse_tensor_to_rpc_format(sparse_tensor):
    sparse_tensor = sparse_tensor.coalesce()
    return [sparse_tensor.indices(), sparse_tensor.values(), sparse_tensor.size()]


def sparse_rpc_format_to_tensor(sparse_rpc_format):
    return torch.sparse_coo_tensor(
        sparse_rpc_format[0], sparse_rpc_format[1], sparse_rpc_format[2]
    ).coalesce()


class Server:

    @staticmethod
    @rpc.functions.async_execution
    def identity(tensor):
        fut = torch.futures.Future()
        fut.set_result(tensor)
        return fut


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.EmbeddingBag(10, 10, sparse=True)

    def forward(self, x):
        return self.embedding(x)


def run_trainer(rank, rref, cuda_rpc):

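    # basic_hook is a DDP communication hook: it sends the bucket's gradient
    # tensor to the server via RPC (converted to [indices, values, size] when
    # sparse) and returns the round-tripped tensor as the bucket's result.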
    def basic_hook(state, bucket):

        def callback(fut):
            tensor = fut.wait()
            if type(tensor) is list:
                tensor = sparse_rpc_format_to_tensor(tensor)
            if not cuda_rpc:
                tensor = tensor.cuda(rank)
            return [tensor]

        tensor = bucket.get_tensor()
        if not cuda_rpc:
            tensor = tensor.cpu()
        if tensor.is_sparse:
            tensor = sparse_tensor_to_rpc_format(tensor)
        fut = rref.rpc_async().identity(tensor).then(callback)
        return fut

    model = Model().cuda(rank)
    store = c10d.FileStore("/tmp/tmpn_k_8so02", 1)
    process_group = c10d.ProcessGroupGloo(store, rank, 1)
    ddp_model = DDP(model, device_ids=[rank], process_group=process_group)
    ddp_model.register_comm_hook(None, basic_hook)
    loss_fn = nn.MSELoss()
    input = torch.randint(5, (10, 10)).cuda(rank)
    loss_fn(ddp_model(input), input.to(torch.float)).backward()


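# Process roles: rank 0 is the trainer, rank 1 is the server, and rank 2 is
# the master that creates the server RRef and launches run_trainer.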
def run_test(rank, cuda_rpc):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29501'
    opts = TensorPipeRpcBackendOptions()
    if rank == 2:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=3,
            rpc_backend_options=opts
        )
        rref = rpc.remote("server", Server)
        rpc.rpc_async(
            "trainer",
            run_trainer,
            args=(0, rref, cuda_rpc,)
        )
    elif rank == 0:
        if cuda_rpc:
            opts.set_device_map("server", {rank: 1})
        rpc.init_rpc(
            "trainer",
            rank=rank,
            world_size=3,
            rpc_backend_options=opts
        )
    else:
        rpc.init_rpc(
            "server",
            rank=rank,
            world_size=3,
            rpc_backend_options=opts
        )
    rpc.shutdown()


if __name__ == "__main__":
    cuda_rpc = True
    mp.spawn(
        run_test,
        nprocs=3,
        args=(cuda_rpc,),
        join=True
    )

This produces the following error on the trainer:

On WorkerInfo(id=0, name=trainer):
NotImplementedError('Cannot access storage of SparseTensorImpl',)
Traceback (most recent call last):
  File "/fsx/users/gcramer/work/pytorch/torch/distributed/rpc/internal.py", line 210, in _run_function
    result = python_udf.func(*python_udf.args, **python_udf.kwargs)
  File "/fsx/users/gcramer/work/pytorch/benchmarks/distributed/rpc/parameter_server/test.py", line 71, in run_trainer
    loss_fn(out, input.to(torch.float)).backward()
  File "/fsx/users/gcramer/work/pytorch/torch/_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/fsx/users/gcramer/work/pytorch/torch/autograd/__init__.py", line 149, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
NotImplementedError: Cannot access storage of SparseTensorImpl

Expected behavior

.backward() succeeds (no storage-access error is raised for the SparseTensorImpl) and the gradients are updated.
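
For reference, here is a minimal local sketch (no RPC, no DDP; my own reduction of the model and loss above) of the behavior the hook round-trip is expected to preserve: EmbeddingBag(sparse=True) produces a sparse COO gradient and backward() completes.

import torch
import torch.nn as nn

# Same model and loss as the repro, but purely local: the sparse gradient is
# produced without any storage-access error.
emb = nn.EmbeddingBag(10, 10, sparse=True)
inp = torch.randint(5, (10, 10))
loss_fn = nn.MSELoss()
loss_fn(emb(inp), inp.to(torch.float)).backward()
print(emb.weight.grad.is_sparse)  # True: the gradient is a sparse COO tensor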

Environment

Collecting environment information...
PyTorch version: 1.9.0a0+git7bc46bc
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.19.6

Python version: 3.6 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.1.105
GPU models and configuration: 
GPU 0: Tesla V100-SXM2-16GB
GPU 1: Tesla V100-SXM2-16GB
GPU 2: Tesla V100-SXM2-16GB
GPU 3: Tesla V100-SXM2-16GB

Nvidia driver version: 450.80.02
cuDNN version: Probably one of the following:
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.6.5
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn.so.7.6.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.9.0a0+git7bc46bc
[conda] magma-cuda110             2.5.2                         1    pytorch
[conda] mkl                       2021.2.0           h726a3e6_389    conda-forge
[conda] mkl-include               2021.2.0           h726a3e6_389    conda-forge
[conda] numpy                     1.19.2           py36h6163131_0  
[conda] numpy-base                1.19.2           py36h75fe3a5_0  
[conda] torch                     1.9.0a0+git7bc46bc           dev_0    <develop>

Additional context

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @aazzolini @rohan-varma @jjlilley @osalpekar @jiayisuse @mrzzd @agolynski @SciPioneer @H-Huang @cbalioglu @gcramer23

Labels

module: autograd, module: rpc, oncall: distributed, triaged
