
Weird behavior of cuda::cub::sort_pairs #63427

@ymwangg

Description


🐛 Bug

I saw a dramatic increase in memory requirements after #62495.
The maximum batch_size the following script can tolerate drops from 13 to 6 after this PR on an Nvidia T4 GPU.

from transformers import BertForMaskedLM
import torch

batch_size = 7
device = torch.device("cuda:0")

model = BertForMaskedLM.from_pretrained("bert-base-uncased")
model.to(device)
model.train()

input_ids = torch.ones((batch_size, 512)).to(torch.int64).to(device)
attention_mask = torch.ones((batch_size, 512)).to(torch.int64).to(device)
labels = torch.ones((batch_size, 512)).to(torch.int64).to(device)
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()
print(loss)

If batch_size > 6, I get the following error:

Traceback (most recent call last):
  File "debug.py", line 16, in <module>
    loss.backward()
  File "/home/ubuntu/anaconda3/lib/python3.8/site-packages/torch/_tensor.py", line 313, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/ubuntu/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: unique_by_key: failed to synchronize: cudaErrorIllegalAddress: an illegal memory access was encountered

After some investigation, it looks like the cuda::cub::sort_pairs call in embedding_dense_backward_cuda (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Embedding.cu#L277) returns garbage data in sorted_indices for batch_size > 6.
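
For what it is worth, the sort path can be exercised without going through transformers. The sketch below is my own reduction, not the exact failing call: the 30522 x 768 weight is just a stand-in sized like bert-base-uncased's word embedding, and the all-ones indices mirror the repro so the dense embedding backward has to sort batch_size * 512 duplicate indices.

import torch

device = torch.device("cuda:0")
batch_size = 7

# Stand-in embedding weight (sizes assumed to match bert-base-uncased) and
# the same all-ones index layout as the repro above.
weight = torch.randn(30522, 768, device=device, requires_grad=True)
indices = torch.ones((batch_size, 512), dtype=torch.int64, device=device)

# Dense embedding backward on CUDA goes through embedding_dense_backward_cuda,
# which is where the cub::sort_pairs call in question lives.
out = torch.nn.functional.embedding(indices, weight)
out.sum().backward()

# With correct sorting, all gradient mass lands in row 1 and every entry of
# that row equals batch_size * 512 (3584 here); garbage sorted_indices should
# show up as a wrong gradient or a CUDA error instead.
print(weight.grad[1])
print((weight.grad.abs().sum() - weight.grad[1].abs().sum()).item())  # expected 0.0

I have not checked whether this minimal version trips the illegal access at the same sizes, but it may be a quicker way to bisect than the full BERT script.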

To Reproduce

Steps to reproduce the behavior:

  1. Run the above script and modify the batch size (a rough sweep that automates this is sketched below).
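
In case it helps others reproduce, here is a rough sweep (my own sketch, not part of the original report) that automates that step. Each batch size runs in a fresh Python process, because once the illegal memory access fires the CUDA context in that process is no longer usable:

import subprocess
import sys

SCRIPT = """
import torch
from transformers import BertForMaskedLM

batch_size = {batch_size}
device = torch.device("cuda:0")

model = BertForMaskedLM.from_pretrained("bert-base-uncased")
model.to(device)
model.train()

input_ids = torch.ones((batch_size, 512)).to(torch.int64).to(device)
attention_mask = torch.ones((batch_size, 512)).to(torch.int64).to(device)
labels = torch.ones((batch_size, 512)).to(torch.int64).to(device)
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
outputs.loss.backward()
"""

# Report which batch sizes complete the forward + backward pass.
for batch_size in range(1, 16):
    result = subprocess.run(
        [sys.executable, "-c", SCRIPT.format(batch_size=batch_size)],
        capture_output=True,
    )
    status = "ok" if result.returncode == 0 else "failed"
    print(f"batch_size={batch_size}: {status}")

Per the numbers above, the build before #62495 should report ok up to batch_size = 13, while after it the failures start at 7.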

Expected behavior

The above script should run with batch_size = 13.

Environment

PyTorch version: 1.10.0a0+gitc4aeeca
Is debug build: False
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.21.0
Libc version: glibc-2.27

Python version: 3.8.8 (default, Apr 13 2021, 19:58:26)  [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-1054-aws-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.0.221
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 450.142.00
cuDNN version: Probably one of the following:
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.6.5
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn.so.7.6.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.20.1
[pip3] numpydoc==1.1.0
[pip3] torch==1.10.0a0+gitb0396e3
[pip3] torch-xla==1.10
[conda] blas                      1.0                         mkl  
[conda] mkl                       2021.2.0           h06a4308_296  
[conda] mkl-service               2.3.0            py38h27cfd23_1  
[conda] mkl_fft                   1.3.0            py38h42c9631_2  
[conda] mkl_random                1.2.1            py38ha9443f7_2  
[conda] mypy_extensions           0.4.3                    py38_0  
[conda] numpy                     1.20.1           py38h93e21f0_0  
[conda] numpy-base                1.20.1           py38h7d8b39e_0  
[conda] numpydoc                  1.1.0              pyhd3eb1b0_1  
[conda] torch                     1.10.0a0+gitb0396e3          pypi_0    pypi
[conda] torch-xla                 1.10                     pypi_0    pypi

Additional context

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @ngimel @heitorschueroff

Metadata

Assignees

No one assigned

    Labels

    high priority
    module: cuda (Related to torch.cuda, and CUDA support in general)
    module: sorting and selection
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
