
Repeated torch.distributed.broadcast calls lead to OOM (with NCCL) #19219

@colesbury

🐛 Bug

Repeated calls to torch.distributed.broadcast using the NCCL backend lead to out-of-memory errors, even when the "live" set of Tensors is small.

To Reproduce

import torch
import torch.distributed as dist

def worker(rank):
    for itr in range(1000):
        x = torch.randn(int(25 * 1024 * 1024), device='cuda')  # 25M float32 elements = 100 MiB
        dist.broadcast(x, src=1, async_op=False)
        del x

def main(rank, init_method, world_size):
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", init_method, rank=rank, world_size=world_size)
    worker(rank)

if __name__ == '__main__':
    init_method = 'tcp://127.0.0.1:23123'
    world_size = 2
    torch.multiprocessing.spawn(main, (init_method, world_size), nprocs=world_size)

Running this fails partway through the loop with:

RuntimeError: CUDA out of memory. Tried to allocate 100.00 MiB (GPU 1; 11.93 GiB total capacity; 845.00 KiB already allocated; 56.00 MiB free; 11.52 GiB cached)

(On PyTorch master 4/12/2019 1c836e7)
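
The OOM can be avoided from user code by forcing the pending NCCL work to finish before the next allocation, so that the events the caching allocator recorded for the broadcast complete and the block becomes reusable. A minimal sketch of that workaround, as a drop-in replacement for worker in the repro above (the per-iteration synchronize is the only change):

def worker(rank):
    for itr in range(1000):
        x = torch.randn(int(25 * 1024 * 1024), device='cuda')  # 25M float32 elements = 100 MiB
        dist.broadcast(x, src=1)
        # Wait for the broadcast's NCCL stream to drain so the event recorded
        # against x completes; the next iteration can then reuse this block
        # instead of carving a new one out of free memory.
        torch.cuda.synchronize()
        del x

This keeps peak usage at roughly one 100 MiB block per rank, at the cost of serializing the broadcast with everything else on the device.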

Suggested fix part 1

In addition to freeing the available cached blocks, the CUDA caching allocator should also free blocks that are being held only because they were used on multiple streams. One way to achieve this is to loop over all the events in cuda_events, synchronize on each event, and cudaFree the corresponding block once its event_count reaches zero. This is effectively a modified version of process_events, hooked into the retry path of cuda_malloc_retry shown below.

cudaError_t cuda_malloc_retry(int device, void** devPtr, size_t size)
{
  // Try cudaMalloc. If cudaMalloc fails, frees all non-split cached blocks
  // and retries.
  cudaError_t err = cudaMalloc(devPtr, size);
  if (err != cudaSuccess) {
    cudaGetLastError();  // reset the last CUDA error
    free_cached_blocks(device);
    // Suggested addition: also synchronize the outstanding events in cuda_events
    // and cudaFree their blocks (the modified process_events described above)
    // before retrying the allocation.
    err = cudaMalloc(devPtr, size);
    if (err != cudaSuccess) {
      return err;
    }
  }
  return cudaSuccess;
}

void process_events()  // the existing event-processing loop that the modified, synchronizing version would be based on

Suggested fix part 2 (needs discussion)

The above fix avoids the out-of-memory error, but the tight loop around dist.broadcast will still consume all available memory before the cached blocks get freed, which is not desirable. To avoid this, I think torch.distributed.broadcast should operate solely on the caller's current stream by default, and only use a background stream when one is explicitly requested.
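
For illustration only, here is a sketch of what the opt-in side-stream pattern might look like from the caller's side under that proposal. The helper below is hypothetical (it assumes broadcast would follow the ambient current stream, which is not what the NCCL backend does today), but the stream bookkeeping uses only existing torch.cuda primitives:

import torch
import torch.distributed as dist

def broadcast_on_side_stream(x, src, comm_stream):
    # Hypothetical caller-side pattern under the proposed semantics: broadcast
    # would run on the current stream, so using a background stream (and doing
    # the cross-stream lifetime bookkeeping) becomes an explicit, opt-in choice.
    comm_stream.wait_stream(torch.cuda.current_stream())  # order after producers of x
    with torch.cuda.stream(comm_stream):
        dist.broadcast(x, src=src)
    x.record_stream(comm_stream)  # tell the caching allocator x is also used on comm_stream
    return comm_stream

With current-stream semantics as the default, the repro's loop should keep reusing a single cached block, since no cross-stream event would be recorded against x when it is freed.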

Metadata

Labels

high priority · module: cuda · oncall: distributed · triaged
