🚀 The feature, motivation and pitch
Motivations:
- This feature proposes a solution from the PyTorch Distributed team to the following use case: customers want their users' code to "just work" on machines that have GPUs as well as machines that lack GPUs, without changing the c10d backend specification.
- The proposed solution also aims to better support customers that use both GPU collectives and CPU collectives.
- Existing backend specifications in user code will be honored and will not require changes.
- Improve PyTorch Distributed collective extensibility and support for third-party extensions.
Proposal:
The PyTorch Distributed team proposes to dispatch collective operations based on tensor device type to meet the above requirements.
Details:
The c10d library of PyTorch has relied on the init_process_group(backend='xyz') call to know which backend to prepare for the user. While this gave users a choice when multiple backends offered similar functionality (such as NCCL and Gloo for CUDA tensors), a choice that is less needed now, it also ties the user's later operations to the capability of the specified backend. For example, if the user specifies 'nccl' as the backend, it is expected that all later collective operations are on CUDA tensors.
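To illustrate the current limitation, here is a minimal sketch (the exact failure mode is illustrative, not exact):
import torch
import torch.distributed as dist
# Assumes MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set as in the Example section below
dist.init_process_group(backend="nccl")  # NCCL handles CUDA tensors only
t_cpu = torch.ones(1)
dist.all_reduce(t_cpu)  # Today this fails or misbehaves: the NCCL backend cannot serve CPU tensors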
To make backend developers' lives easier while supporting a growing diversity of collective needs, PyTorch needs to dispatch collectives to the correct backend based on the tensor type.
The infrastructure for supporting dispatchability already exists today. PyTorch core has an internal dispatcher that figures out which kernel implementation to call for a given tensor type. For example, it can switch between the CPU and CUDA implementations of a torch.add operator, depending on the torch.device attribute of the input tensor ('cpu' or 'cuda'). While this capability is mainly built for ATen operators today, it can be extended to c10d operations, and there has been effort toward achieving that.
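As a rough illustration of device-based dispatch (plain ATen, not c10d):
import torch
a = torch.ones(2)
print(torch.add(a, a))  # dispatched to the CPU kernel
if torch.cuda.is_available():
    b = torch.ones(2, device="cuda")
    print(torch.add(b, b))  # dispatched to the CUDA kernel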
The dispatch capability makes it possible for PyTorch to have a default solution for each tensor type, rather than fully relying on the user to supply that knowledge via the init_process_group(backend) call. We expect the backend argument of this API to become optional once PyTorch has the dispatching capability. Note that this does not break backward compatibility with respect to current usage of this API.
Users may still use the backend argument to specify their preference. The difference is that the effect of backend='xyz' would be limited to the tensor type(s) that backend xyz can support. For example, backend='nccl' would be understood by PyTorch as: "for CUDA tensors (rather than all tensors), the user's preferred backend is NCCL." This leaves the non-CUDA preference unspecified. If the user later passes in a CPU tensor, PyTorch can still use its default backend for CPU. This helps us achieve zero lines of code change for the first use case identified in the Motivation section.
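A sketch of that scenario under the proposal (exact behavior is subject to the final design):
import torch
import torch.distributed as dist
# Assumes MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set as in the Example section below
dist.init_process_group(backend="nccl")  # read as: NCCL preferred for CUDA tensors
t_cpu = torch.ones(1)
dist.all_reduce(t_cpu)  # would be dispatched to PyTorch's default CPU backend (e.g. Gloo)
t_cuda = torch.ones(1, device="cuda:0")
dist.all_reduce(t_cuda)  # would be dispatched to NCCL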
For design completeness, users can also specify multiple backends in a single call, purely for usability: backend='cuda:backend1,cpu:backend2'.
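For instance, mapping CUDA tensors to NCCL and CPU tensors to Gloo (the exact string format is subject to change):
dist.init_process_group(backend="cuda:nccl,cpu:gloo")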
Example
Here is a basic example we would be able to support. Note: this is just a sample and the specifics are subject to change.
import torch
import torch.distributed as dist
import os

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

if __name__ == "__main__":
    dist.init_process_group()  # The backend argument is optional
    # We will use the same process group for both collectives
    t = torch.ones(1)
    output = [torch.zeros(1)]
    dist.all_gather(output, t)  # Dispatched to the CPU backend (default: Gloo)
    print(output)

    t_cuda = torch.ones(1, device="cuda:0")
    cuda_output = [torch.zeros(1, device="cuda:0")]
    dist.all_gather(cuda_output, t_cuda)  # Dispatched to the CUDA backend (default: NCCL)
    print(cuda_output)

# Output:
# >> [tensor([1.])]
# >> [tensor([1.], device='cuda:0')]
Timeline
We are targeting this feature for the next release after 1.13. Please follow along with this issue: changes will be checked into trunk and will appear in nightly builds, so the feature may be usable before the official release.
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @kwen2501