Together with @0mp, @VirrageS and @jytug, we're developing a torch.distributed package for PyTorch. All work is done in a fork, on a thd branch (we didn't want to make a lot of unnecessary noise in the main repo). We're creating this issue so we can gather feedback on our API designs from all of you.
We plan to make the package have two modes. The user has to choose one of them as part of the initialisation.
Process group mode
This is very similar to the API defined in MPI. We assume all processes are equal, assign them ranks, and later allow them to use a well-known set of communication collectives like reduce, broadcast, allReduce, gather, scatter, etc.
Example:
import torch.distributed
torch.distributed.init_process_group(backend='tcp')
my_rank = torch.distributed.get_rank()
num_processes = torch.distributed.get_num_processes()
...
if my_rank == 0:
    torch.distributed.send(tensor, 1)
else:
    tensor = torch.distributed.recv(0)
...
result = torch.distributed.all_reduce(tensor)
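For readers less familiar with MPI, the semantics of the collectives named above can be sketched in plain Python. This is only a single-process simulation, not the proposed implementation: each list stands for the values held by ranks 0..N-1, and the function names and signatures are illustrative assumptions rather than the package's API.

```python
from functools import reduce as fold
import operator


def broadcast(per_rank, src):
    """Every rank ends up with the value held by rank `src`."""
    return [per_rank[src] for _ in per_rank]


def reduce(per_rank, dst, op=operator.add):
    """Only rank `dst` receives the combined value; the others keep their own."""
    combined = fold(op, per_rank)
    return [combined if rank == dst else value
            for rank, value in enumerate(per_rank)]


def all_reduce(per_rank, op=operator.add):
    """Like reduce, but every rank receives the combined value."""
    combined = fold(op, per_rank)
    return [combined for _ in per_rank]


def gather(per_rank, dst):
    """Rank `dst` collects one value from every rank; the others get nothing."""
    return [list(per_rank) if rank == dst else None
            for rank in range(len(per_rank))]


def scatter(chunks):
    """The source rank holds one chunk per rank; rank i receives chunks[i]."""
    return list(chunks)
```

For example, with three ranks holding 1, 2 and 3, `all_reduce([1, 2, 3])` leaves every rank with the sum, while `reduce([1, 2, 3], dst=0)` changes only rank 0.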
Master-worker mode
This would provide an API very similar to the torch.cuda package. At the beginning of your script you would have to call torch.distributed.init_master_worker(backend='mpi').
Operation execution is asynchronous w.r.t. the master process. We'll implement a CUDA-like concurrency model (streams + events); until then, the only sync points are copies between master and workers.
Example:
import torch.distributed
torch.distributed.init_master_worker(backend='tcp')
x = torch.distributed.FloatTensor(20, 20).fill_(4)
y = torch.randn(20, 20).dist_send()
z = x + y
# z.get_node(), z.get_device() == 0, -1 (i.e. CPU)
cuda_x = x.cuda()
# cuda_x.get_node(), cuda_x.get_device() == 0, 0
with torch.distributed.node(1):
    a = torch.distributed.FloatTensor(10, device=1)
    # a.get_node(), a.get_device() == 1, 1
cuda_y = y.cuda()
# cuda_y.get_node(), cuda_y.get_device() == 0, 0
q = cuda_x + cuda_y
# q.get_node(), q.get_device() == 0, 0
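The asynchronous behaviour described above (operations return immediately on the master, and only a copy back to the master blocks) can be illustrated with a small standard-library sketch. It is a toy model under stated assumptions: `Worker`, `RemoteValue`, `send` and `to_master` are hypothetical names, and a single background thread stands in for a remote node.

```python
from concurrent.futures import ThreadPoolExecutor


class Worker:
    """Hypothetical stand-in for a remote node: runs queued operations in order."""

    def __init__(self):
        # One thread, so operations execute in submission order.
        self._queue = ThreadPoolExecutor(max_workers=1)

    def submit(self, fn, *args):
        return self._queue.submit(fn, *args)

    def shutdown(self):
        self._queue.shutdown()


class RemoteValue:
    """Handle held by the master; the actual data lives on the worker."""

    def __init__(self, worker, future):
        self.worker = worker
        self._future = future

    def __add__(self, other):
        # Returns immediately; the addition is only queued on the worker.
        # Both operands were queued earlier, so they are ready when this runs.
        future = self.worker.submit(
            lambda a, b: a.result() + b.result(), self._future, other._future)
        return RemoteValue(self.worker, future)

    def to_master(self):
        # The only synchronization point: block until the value exists.
        return self._future.result()


def send(worker, value):
    """Copy a master-side value to the worker and return a handle to it."""
    return RemoteValue(worker, worker.submit(lambda: value))
```

A typical use would be `z = send(w, 4) + send(w, 3)` followed by `z.to_master()`; nothing on the master waits until that last call.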
How to launch the jobs
We'll provide a pytorch_exec utility that will spawn the process groups in a similar fashion to mpiexec.
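To give a rough idea of what such a launcher does, here is a minimal local-only sketch. It is not pytorch_exec: the `launch` function and the `RANK` and `WORLD_SIZE` environment variable names are assumptions made for illustration, and a real utility would also have to start processes on remote hosts.

```python
import os
import subprocess
import sys


def launch(num_processes, command):
    """Start `num_processes` copies of `command`, telling each one its rank.

    RANK and WORLD_SIZE are hypothetical variable names used for illustration.
    Returns the exit code of every process, ordered by rank.
    """
    procs = []
    for rank in range(num_processes):
        env = dict(os.environ, RANK=str(rank), WORLD_SIZE=str(num_processes))
        procs.append(subprocess.Popen(command, env=env))
    # Wait for the whole group to finish, as mpiexec does.
    return [proc.wait() for proc in procs]
```

For example, `launch(4, [sys.executable, "train.py"])` would run four copies of a (hypothetical) `train.py`, each able to read its own rank from the environment.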
Decoupling data backends from other logic
You might have noticed that both init_process_group and init_master_worker accept a backend argument. We're aware that the best strategy for sending data might be different for every user, and that picking a good one will be crucial to limiting communication overhead. This is why we decided to introduce a DataChannel interface: users will be able to pick one of the provided implementations (initially MPI and raw TCP sockets, later RDMA etc.) or add custom ones, so they can easily achieve the lowest overhead possible in their setup.
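One way to picture this decoupling is an abstract interface with interchangeable implementations selected by the backend string. The sketch below is an assumption about the general shape only: the method names, the `BACKENDS` registry, `init_channel` and the in-memory implementation are all hypothetical, and the real DataChannel may look quite different.

```python
from abc import ABC, abstractmethod
from collections import defaultdict
from queue import Queue


class DataChannel(ABC):
    """Hypothetical transport interface; collectives would be built on top of it."""

    @abstractmethod
    def send(self, data, src, dst):
        """Deliver `data` from rank `src` to rank `dst`."""

    @abstractmethod
    def recv(self, src, dst):
        """Return the next item that rank `src` sent to rank `dst`."""


class InMemoryChannel(DataChannel):
    """Toy implementation for a single process: one FIFO mailbox per (src, dst)."""

    def __init__(self):
        self._mailboxes = defaultdict(Queue)

    def send(self, data, src, dst):
        self._mailboxes[(src, dst)].put(data)

    def recv(self, src, dst):
        return self._mailboxes[(src, dst)].get()


# A custom transport would be added by registering another subclass here.
BACKENDS = {"memory": InMemoryChannel}


def init_channel(backend):
    """Pick a transport by name, mirroring the `backend` argument above."""
    return BACKENDS[backend]()
```

With this shape, code that uses `send` and `recv` does not need to change when a different transport is registered under another backend name.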
Please let us know what you think! Thanks!