Distributed 'gather' with the NCCL backend returns wrong results on noncontiguous tensors. · Issue #159548 · pytorch/pytorch · GitHub

@tiandeyu-cs

Description

🐛 Describe the bug

Calling torch.distributed.gather on a noncontiguous tensor with the NCCL backend returns wrong results.

This bug is similar to #158902, but the reproduction script in #158902 does not trigger this one.
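
For context, the tensor passed to `dist.gather` in the reproduction is a column slice of a 2x2 tensor, which PyTorch represents as a strided view of the original storage rather than a copy. The sketch below inspects that layout on the CPU (no process group or GPU needed), using only standard tensor methods; contiguity depends on the layout, not on the device.

```python
import torch

# Same 2x2 tensor as in main.py, kept on the CPU for inspection.
base = torch.tensor([[0, 0],
                     [1, 1]])
t = base[:, 0]  # column 0 as a view: logical values [0, 1]

print(t.tolist())                       # [0, 1]
print(t.stride())                       # (2,): consecutive elements are 2 slots apart
print(t.is_contiguous())                # False
print(t.data_ptr() == base.data_ptr())  # True: the view shares base's storage
print(t.contiguous().is_contiguous())   # True: .contiguous() makes a dense copy
```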

The following Python script reproduces the bug.
main.py:

```python
import argparse
import os
import torch
import torch.distributed as dist

args = None
def get_args():
    global args
    if args is not None:
        return args

    parser = argparse.ArgumentParser()
    parser.add_argument("--world-size", "--world_size", type=int, required=True)
    parser.add_argument("--local-rank", "--local_rank", type=int, required=True)
    parser.add_argument("--master-addr", "--master_addr", type=str, default="127.0.0.1")
    parser.add_argument("--master-port", "--master_port", type=int, default=31111)
    parser.add_argument("--backend", type=str, choices=["nccl", "gloo"], default="nccl")
    args = parser.parse_args()
    return args

def run(local_rank, world_size):
    device = torch.device(f"cuda:{local_rank}")

    t = [[0, 0],
         [1, 1]]
    t = torch.tensor(t).to(device)
    # Column 0 as a strided view: values [0, 1], stride (2,),
    # so t.is_contiguous() is False.
    t = t[:, 0]

    # Only the destination rank allocates receive buffers, one per rank.
    arr = None
    if local_rank == 0:
        arr = [torch.zeros_like(t) for _ in range(world_size)]

    dist.gather(t, arr, dst=0)

    if local_rank == 0:
        # Expected: [0 1 0 1] (rank 0's t followed by rank 1's t).
        a = torch.cat(arr, dim=0)
        a = a.detach().cpu().numpy()
        print(a)

def init_process(local_rank, world_size, backend, fn):
    args = get_args()
    os.environ["MASTER_ADDR"] = args.master_addr
    os.environ["MASTER_PORT"] = str(args.master_port)
    dist.init_process_group(backend, rank=local_rank, world_size=world_size)
    fn(local_rank, world_size)

def destroy_process():
    dist.destroy_process_group()

if __name__ == "__main__":
    args = get_args()
    init_process(args.local_rank, args.world_size, args.backend, run)
    destroy_process()
```

The following bash commands reproduce the bug; their output on my machine is included.

```
$ python main.py --backend nccl --world-size 2 --local-rank 0 &
[1] 4820
$ python main.py --backend nccl --world-size 2 --local-rank 1
[W731 04:32:38.248105769 socket.cpp:200] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W731 04:32:39.341066889 socket.cpp:200] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W731 04:32:48.259092976 socket.cpp:200] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W731 04:32:58.269091009 socket.cpp:200] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[0 1 0 0]
[1]+  Done                    python main.py --backend nccl --world-size 2 --local-rank 0
```

The expected output is [0 1 0 1], but the program prints [0 1 0 0]: the first half (rank 0's own tensor) is correct, while the half gathered from rank 1 is wrong.
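
One plausible explanation, which is an assumption and not something established in this issue: the sending rank's buffer is read as `t.numel()` consecutive elements starting at `t`'s position in storage, ignoring its stride. The sketch below emulates that flat read on the CPU with `torch.as_strided` and shows that it reproduces the observed output exactly, assuming rank 0's local contribution is copied with strides respected (its slice is correct in the report).

```python
import torch

base = torch.tensor([[0, 0],
                     [1, 1]])
t = base[:, 0]  # same noncontiguous view as in main.py; logical values [0, 1]

# HYPOTHESIS (not confirmed): the send buffer is read as t.numel() consecutive
# elements from t's start in storage, ignoring t.stride().
# as_strided with stride 1 emulates that flat read: storage [0, 0, 1, 1] -> [0, 0].
naive_read = torch.as_strided(t, size=(t.numel(),), stride=(1,),
                              storage_offset=t.storage_offset())

# Rank 0 is the destination and its slice is correct in the report, so model it
# with a stride-aware copy; rank 1's slice uses the hypothesized flat read.
observed = torch.cat([t.contiguous(), naive_read])
expected = torch.cat([t.contiguous(), t.contiguous()])

print(observed.tolist())  # [0, 1, 0, 0]
print(expected.tolist())  # [0, 1, 0, 1]
```

If this hypothesis holds, passing `t.contiguous()` to `dist.gather` on every rank should avoid the faulty path; this is an untested suggestion, not a confirmed fix.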

Versions

undisclosed

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta

Metadata

Assignees: No one assigned

Labels: module: nccl, oncall: distributed
