Description
🐛 Describe the bug
Hi there!
I am trying to manually reproduce the following scheme, which corresponds to a row-parallel model-parallelism paradigm: the input and the weight matrix are split along the hidden dimension, and the two partial products are summed to recover the full output.

I have written a small snippet of code to reproduce the scheme above. The assertion passes for small values of batch_size, hidden_dim, and output_dim (for example batch_size = 1, hidden_dim = 4, output_dim = 2), but strangely it no longer passes for larger values (specifically 10, 40, 20). What could explain this behavior?
Code snippet (tried on Google Colab):
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size = 10
hidden_dim = 40
output_dim = 20
sliced_dim = hidden_dim // 2

# Reference layer: y = x @ W.T, with W of shape (output_dim, hidden_dim)
dummy_mlp = nn.Linear(hidden_dim, output_dim, bias=False)
parameters = torch.nn.Parameter(torch.randn(output_dim, hidden_dim))
dummy_mlp.weight = parameters

dummy_input = torch.randn(batch_size, hidden_dim)

# Split the input (and the weight's input dimension) into two halves
sliced_input_1, sliced_input_2 = torch.split(dummy_input, sliced_dim, dim=-1)
sliced_output_1 = F.linear(sliced_input_1, dummy_mlp.weight[:, :sliced_dim])
sliced_output_2 = F.linear(sliced_input_2, dummy_mlp.weight[:, sliced_dim:])

# The sum of the partial outputs should match the full forward pass
final_output = dummy_mlp(dummy_input)
assert torch.equal(final_output, sliced_output_1 + sliced_output_2)
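A plausible explanation (an assumption on my part, not confirmed above): floating-point addition is not associative, so accumulating the matmul in two halves and then summing can differ from the single fused matmul in the last bits. torch.equal requires bitwise equality, whereas torch.allclose compares within a tolerance. A minimal sketch of that check:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, hidden_dim, output_dim = 10, 40, 20
sliced_dim = hidden_dim // 2

weight = torch.randn(output_dim, hidden_dim)
x = torch.randn(batch_size, hidden_dim)

# Row-parallel decomposition: split x and W along the hidden dimension
x1, x2 = torch.split(x, sliced_dim, dim=-1)
partial = F.linear(x1, weight[:, :sliced_dim]) + F.linear(x2, weight[:, sliced_dim:])
full = F.linear(x, weight)

# Bitwise equality may fail because the summation order differs
print(torch.equal(full, partial))      # may be False for larger shapes
print(torch.allclose(full, partial))   # passes within default tolerances
print((full - partial).abs().max())    # difference on the order of float32 epsilon
```

If torch.allclose passes while torch.equal fails, the discrepancy is numerical rounding rather than an incorrect decomposition.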
Thank you very much in advance for your help!!
Versions
Collecting environment information...
PyTorch version: 1.10.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.12.0
Libc version: glibc-2.26
Python version: 3.7.13 (default, Mar 16 2022, 17:37:17) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: True
CUDA runtime version: 11.1.105
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 460.32.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.21.6
[pip3] torch==1.10.0+cu111
[pip3] torchaudio==0.10.0+cu111
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.11.0
[pip3] torchvision==0.11.1+cu111
[conda] Could not collect