
Multiprocess DataLoader doesn't work with sparse tensor as it'll try to access the underlying storage #106837

@yundai424

Description

🐛 Describe the bug

Summary

When using a DataLoader with multiprocess loading (num_workers > 0) on a dataset whose elements are sparse tensors, the loader tries to access the underlying storage of each tensor, but sparse tensors (COO, CSR, etc.) do not support storage access.

I've put a minimal reproduction in this Colab notebook: https://colab.research.google.com/drive/16q_tzyUz5ylZSCcpzhJ52pxVSUpMZX-M#scrollTo=o0KeaWnVz9Hm&uniqifier=1
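
For convenience, here is a minimal sketch along the same lines (the dataset and its contents are illustrative, not from the notebook):

```python
import torch
from torch.utils.data import DataLoader, Dataset

# Hypothetical dataset whose items are sparse COO tensors.
class SparseDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 2.0], (4, 4))

if __name__ == "__main__":
    # Any num_workers > 0 triggers the failure: the default collate function
    # tries to allocate shared storage for the batch inside the worker.
    loader = DataLoader(SparseDataset(), batch_size=2, num_workers=2)
    next(iter(loader))  # NotImplementedError: Cannot access storage of SparseTensorImpl
```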

Case 1: default collation

When using the default collate function (i.e. with auto-collation enabled), collating a batch of sparse tensors reaches collate_tensor_fn, which calls elem._typed_storage() to allocate shared memory for the batch, thus hitting:


NotImplementedError: Caught NotImplementedError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 119, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 160, in collate_tensor_fn
    storage = elem._typed_storage()._new_shared(numel, device=elem.device)
  File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 238, in _typed_storage
    untyped_storage = self.untyped_storage()
NotImplementedError: Cannot access storage of SparseTensorImpl

Case 2: manual collation

Without auto-collation (set batch_size=None so that the default_convert path is used instead, or provide a custom collate_fn), we get around the default_collate issue from case 1. But when the worker process feeds the loaded data into worker_result_queue, pickling the sparse tensor again attempts to access the underlying storage.
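
A sketch of this variant, reusing the hypothetical SparseDataset from above:

```python
# batch_size=None disables auto-collation: each item passes through
# default_convert in the worker and is then pickled into worker_result_queue,
# where torch.multiprocessing's reduce_tensor touches the tensor's storage.
loader = DataLoader(SparseDataset(), batch_size=None, num_workers=2)
next(iter(loader))
```

This fails with: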

Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 244, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/usr/lib/python3.10/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "/usr/local/lib/python3.10/dist-packages/torch/multiprocessing/reductions.py", line 152, in reduce_tensor
    storage = tensor._typed_storage()
  File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 238, in _typed_storage
    untyped_storage = self.untyped_storage()
NotImplementedError: Cannot access storage of SparseTensorImpl

So either way, multiprocess data loading does not currently work with sparse tensors.
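
One possible workaround (a sketch of my own, not an official API) is to keep sparse tensors out of the worker-to-main-process pickling path entirely: have the dataset return the COO components as ordinary strided tensors, which pickle fine, and rebuild the sparse tensor in the main process:

```python
import torch
from torch.utils.data import DataLoader, Dataset

# Hypothetical dataset that ships indices/values/shape instead of the
# sparse tensor itself; strided tensors go through shared memory normally.
class SparseComponentsDataset(Dataset):
    def __init__(self, sparse_samples):
        self.samples = [s.coalesce() for s in sparse_samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        return s.indices(), s.values(), list(s.shape)

if __name__ == "__main__":
    samples = [
        torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 2.0], (4, 4))
        for _ in range(8)
    ]
    loader = DataLoader(SparseComponentsDataset(samples), batch_size=None, num_workers=2)
    for indices, values, shape in loader:
        # Rebuild the sparse tensor in the main process only.
        sample = torch.sparse_coo_tensor(indices, values, shape)
```

Setting num_workers=0 also sidesteps the queue pickling, at the cost of single-process loading.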

Versions

PyTorch version: 2.0.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.25.2
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.109+-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   46 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          2
On-line CPU(s) list:             0,1
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family:                      6
Model:                           79
Thread(s) per core:              2
Core(s) per socket:              1
Socket(s):                       1
Stepping:                        0
BogoMIPS:                        4399.99
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       32 KiB (1 instance)
L1i cache:                       32 KiB (1 instance)
L2 cache:                        256 KiB (1 instance)
L3 cache:                        55 MiB (1 instance)
NUMA node(s):                    1
NUMA node0 CPU(s):               0,1
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Mitigation; PTE Inversion
Vulnerability Mds:               Vulnerable; SMT Host state unknown
Vulnerability Meltdown:          Vulnerable
Vulnerability Mmio stale data:   Vulnerable
Vulnerability Retbleed:          Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:        Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Vulnerable

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.1+cu118
[pip3] torchaudio==2.0.2+cu118
[pip3] torchdata==0.6.1
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.15.2
[pip3] torchvision==0.15.2+cu118
[pip3] triton==2.0.0
[conda] Could not collect

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @ssnl @VitalyFedyunin @ejguan @dzhulgakov

Labels

module: dataloader (Related to torch.utils.data.DataLoader and Sampler)
module: sparse (Related to torch.sparse)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
