[RFC] XLA Lazy Backend Support In DistributedTensor API · Issue #92909 · pytorch/pytorch · GitHub

[RFC] XLA Lazy Backend Support In DistributedTensor API #92909

@yeounoh


🚀 The feature, motivation and pitch

TL;DR

The proposed DistributedTensor provides a new abstraction to express tensor distributions with both sharding and replication parallelism strategies in eager mode and on non-lazy backends, like cuda. We propose to integrate XLAShardedTensor and its mark_sharding API into the DistributedTensor API to support the xla lazy backend. Our goal is to allow PyTorch users to shard a big tensor across xla devices with just a few lines of code:

import torch
from torch.distributed import DeviceMesh, Shard, distribute_tensor

mesh = DeviceMesh("xla", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])

This example is from the DistributedTensor RFC, the main difference being the device type, xla.

Motivation

The proposed DistributedTensor APIs (distribute_tensor, distribute_module) allow the user to express various types of tensor distributions with just a few lines of code. While the APIs are simple and generic enough to express many common parallelism paradigms, their current backend device support does not include lazy backends, like xla. PyTorch/XLA offers a set of lower-level XLAShardedTensor APIs that expose sharding annotations for tensors residing on xla devices. Both DistributedTensor and XLAShardedTensor support sharding and replication parallelism strategies, defined by a logical device mesh and a sharding placement spec. Here, we propose to integrate the low-level XLAShardedTensor APIs into the high-level DistributedTensor APIs, so that a user can use the same set of DistributedTensor APIs to express tensor distributions with both sharding and replication parallelism strategies.

Pitch

We integrate the xla backend specific XLAShardedTensor APIs into the high-level DistributedTensor APIs (distribute_tensor, distribute_module) so the user can use the same DistributedTensor APIs to express tensor distributions (sharding or replication) on CPU, GPU and xla backend devices, like TPU. Some restrictions apply to tensor distributions on the xla backend: partial tensor distribution is only available in DistributedTensor native backends, as the strategy is “forbidden from constructor” and only used for intermediary results (tensors); XLAShardedTensor APIs may propagate sharding and currently assume a fixed device assignment to the logical mesh; the output tensors are replicated unless sharded explicitly by the user.
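
The first restriction above could be checked before any xla call is made. The following is a minimal sketch, not the proposed implementation: Shard, Replicate and Partial are local stand-ins for the DistributedTensor placement types, and check_xla_placements is a hypothetical helper name that does not appear in this RFC.

```python
from dataclasses import dataclass
from typing import List, Union


# Stand-ins for the DistributedTensor placement types (assumption: the real
# classes live in the DistributedTensor package and carry more behavior).
@dataclass(frozen=True)
class Shard:
    dim: int


@dataclass(frozen=True)
class Replicate:
    pass


@dataclass(frozen=True)
class Partial:
    pass


Placement = Union[Shard, Replicate, Partial]


def check_xla_placements(device_type: str, placements: List[Placement]) -> None:
    """Raise ValueError if a placement cannot be expressed on the xla backend.

    Hypothetical helper: only the "no partial placement on xla" restriction
    is checked here.
    """
    if device_type != "xla":
        return
    for mesh_dim, placement in enumerate(placements):
        if isinstance(placement, Partial):
            raise ValueError(
                f"mesh dim {mesh_dim}: partial placement is not supported on xla"
            )


# Sharding and replication pass; a partial placement on xla is rejected.
check_xla_placements("xla", [Replicate(), Shard(0)])
```

Rejecting the placement up front keeps the error close to the user's call instead of surfacing later during spec conversion.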

The call to the high-level DistributedTensor API can easily be translated into the low-level XLAShardedTensor API based on the following conversions:

Conversions

DeviceMesh <> Mesh

DistributedTensor API (e.g., distribute_tensor(...)) works with a device mesh, declared by a DeviceMesh instance. It is a subclass of torch.device, and describes the sharded tensor or module placements. For instance, the following mesh defines a 1-by-4 logical device mesh:

# DistributedTensor API
dt_mesh = DeviceMesh("xla", [0, 1, 2, 3])

The first argument is the device type, “xla”, and the mesh is described by the list of logical device IDs (global rank), [0, 1, 2, 3], which implies a single host (per row) with 4 devices. If the mesh is defined with “xla”, then the DistributedTensor API can call the XLAShardedTensor API with the same mesh topology with a shape (1, 4):

from torch_xla.experimental.xla_sharding import Mesh

# XLAShardedTensor API
mesh_shape = (1, 4)
device_ids = [0, 1, 2, 3]  # re-arranged per mesh_shape
axes_name = ('x', 'y')  # optional
xla_mesh = Mesh(device_ids, mesh_shape, axes_name)

The conversion from DistributedTensor DeviceMesh to XLAShardedTensor Mesh is straightforward:

import torch_xla.core.xla_model as xm

dt_mesh = DeviceMesh("xla", [0, 1, 2, 3])
dt_mesh.mesh.shape
>> (1, 4)

def convert_to_xla_mesh(dt_mesh: MeshBase) -> Mesh:
  # The mesh must have one entry per participating xla device.
  assert dt_mesh.mesh.numel() == xm.xrt_world_size()
  return Mesh(dt_mesh.mesh.flatten().tolist(), tuple(dt_mesh.shape))


xla_mesh = convert_to_xla_mesh(dt_mesh)
xla_mesh.mesh_shape
>> (1, 4)

We can also define MeshBase to hold the common properties and interface shared by DeviceMesh and Mesh:

import abc
import torch

class MeshBase(abc.ABC):

  device_type: str
  mesh: torch.Tensor

  def __init__(self, device_type, mesh):
    self.device_type = device_type
    self.mesh = mesh

  @property
  def shape(self):
    return self.mesh.shape

  @property
  def ndim(self):
    return self.mesh.ndim



# torch.distributed._tensor.device_mesh
class DeviceMesh(MeshBase):

  def __init__(self, device_type: str, mesh: MeshExprT, dim_groups: Optional[List[ProcessGroup]] = None) -> None:
    mesh = (mesh.detach() if isinstance(mesh, torch.Tensor) else torch.tensor(mesh, dtype=torch.int))
    super().__init__(device_type, mesh)
    ...



# torch_xla.experimental.xla_sharding
class Mesh(MeshBase):

  def __init__(self, device_ids: Union[np.ndarray, List], mesh_shape: Tuple[int, ...], axis_names: Optional[Tuple[str, ...]] = None):
    if not isinstance(device_ids, np.ndarray):
      device_ids = np.array(device_ids)
    mesh = torch.tensor(device_ids.reshape(mesh_shape))
    super().__init__("xla", mesh)
    ...
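
The core of the DeviceMesh to Mesh conversion in this section is splitting a logical mesh into a flat device-ID list and a mesh shape. The sketch below illustrates that step in plain Python so it runs without torch or torch_xla. The helper names (mesh_shape_of, flatten_ids, to_xla_mesh_args) are hypothetical, and the device count is passed in explicitly as a stand-in for the value a runtime would report.

```python
from typing import List, Tuple


def mesh_shape_of(mesh) -> Tuple[int, ...]:
    """Return the shape of a nested list of device IDs, rejecting ragged input."""
    if not isinstance(mesh, (list, tuple)):
        return ()
    child_shapes = {mesh_shape_of(child) for child in mesh}
    if len(child_shapes) > 1:
        raise ValueError("device mesh is not rectangular")
    child = child_shapes.pop() if child_shapes else ()
    return (len(mesh),) + child


def flatten_ids(mesh) -> List[int]:
    """Flatten a nested list of device IDs in row-major order."""
    if isinstance(mesh, (list, tuple)):
        flat: List[int] = []
        for child in mesh:
            flat.extend(flatten_ids(child))
        return flat
    return [mesh]


def to_xla_mesh_args(mesh, num_devices: int) -> Tuple[List[int], Tuple[int, ...]]:
    """Produce (device_ids, mesh_shape), the two positional Mesh arguments.

    num_devices stands in for the device count reported by the runtime.
    """
    shape = mesh_shape_of(mesh)
    device_ids = flatten_ids(mesh)
    if len(device_ids) != num_devices:
        raise ValueError(
            f"mesh has {len(device_ids)} entries but {num_devices} devices exist"
        )
    return device_ids, shape


device_ids, mesh_shape = to_xla_mesh_args([[0, 1, 2, 3]], num_devices=4)
print(device_ids, mesh_shape)  # [0, 1, 2, 3] (1, 4)
```

Here the nested list `[[0, 1, 2, 3]]` spells out the 1-by-4 layout explicitly, and a 2-by-4 mesh such as `[[0, 1, 2, 3], [4, 5, 6, 7]]` yields the shape `(2, 4)`.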

List[Placement] <> Tuple[int, None]

One can convert the DistributedTensor placement specs into the XLAShardedTensor partitioning specs by mapping the “per mesh dimension sharding” (DistributedTensor) to the “per tensor dimension sharding” (XLAShardedTensor). As an illustration, consider an input tensor of shape (4, 8, 8) and its sharding across a (2, 4) device mesh: the first tensor dimension is sharded 4-way across the second dimension of the device mesh, and the remaining tensor dimensions are replicated.

In DistributedTensor, this is expressed with a placement spec, [Replicate(), Shard(0)], where each spec element describes how the corresponding mesh dimension is used: replicated or sharded. Replicate() at index 0 replicates the tensor across the first mesh dimension, and Shard(0) at index 1 means that the first dimension of the input tensor (index 0) is sharded over the second dimension of the mesh.

import torch
from torch.distributed.tensor import distribute_tensor, DeviceMesh, Shard, Replicate

m1 = torch.randn(4, 8, 8)

# Mesh partitioning: each device holds 1/4 of the input, and each shard is
# replicated across the first mesh dimension. The first input tensor dimension is split 4-way.
dt_mesh = DeviceMesh("cuda", torch.arange(8).reshape(2, 4))
m1_sharded = distribute_tensor(m1, dt_mesh, [Replicate(), Shard(0)])

In XLAShardedTensor, the same sharding strategy is denoted by a partition spec, (1, None, None). Each spec element describes how the corresponding input tensor dimension will be mapped to the device mesh. For example, partition_spec[0] = 1 indicates that the first dimension of the input tensor will be mapped to the second dimension (index 1) of the device mesh, thus split 4-way. None means replication, and the rest of the input dimensions will be replicated.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

m1 = torch.randn(4, 8, 8).to(xm.xla_device())
xla_mesh = Mesh(list(range(8)), (2, 4))

# Mesh partitioning: each device holds 1/4 of the input, and each shard is
# replicated across the first mesh dimension. The first input tensor dimension is split 4-way.
partition_spec = (1, None, None)
m1_sharded = xs.mark_sharding(m1, xla_mesh, partition_spec)
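
To make the (1, None, None) spec concrete, the following sketch works out what each device holds for the example above. It is plain Python arithmetic, not a torch_xla call: local_shard_shape and replication_factor are hypothetical helper names, and the sketch assumes every sharded tensor dimension divides evenly by the size of the mesh dimension it maps to.

```python
from typing import Optional, Sequence, Tuple


def local_shard_shape(
    tensor_shape: Sequence[int],
    mesh_shape: Sequence[int],
    partition_spec: Sequence[Optional[int]],
) -> Tuple[int, ...]:
    """Shape of the shard held by one device under a per-tensor-dimension spec."""
    if len(partition_spec) != len(tensor_shape):
        raise ValueError("partition spec needs one entry per tensor dimension")
    local = []
    for size, mesh_dim in zip(tensor_shape, partition_spec):
        if mesh_dim is None:
            # Replicated dimension: every device keeps the full extent.
            local.append(size)
            continue
        ways = mesh_shape[mesh_dim]
        if size % ways != 0:
            # Simplifying assumption of this sketch: even splits only.
            raise ValueError(f"size {size} is not divisible by {ways}")
        local.append(size // ways)
    return tuple(local)


def replication_factor(
    mesh_shape: Sequence[int], partition_spec: Sequence[Optional[int]]
) -> int:
    """Number of devices that hold a copy of the same shard."""
    used = {mesh_dim for mesh_dim in partition_spec if mesh_dim is not None}
    factor = 1
    for mesh_dim, size in enumerate(mesh_shape):
        if mesh_dim not in used:
            factor *= size
    return factor


print(local_shard_shape((4, 8, 8), (2, 4), (1, None, None)))  # (1, 8, 8)
print(replication_factor((2, 4), (1, None, None)))  # 2
```

Each device holds a (1, 8, 8) shard, one quarter of the 256 input elements, and each shard lives on 2 devices because the first mesh dimension is unused by the spec.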

Note that the XLAShardedTensor uses a different sharding spec representation, where a sharding strategy is declared “per tensor dimension”. We can transform DT placement specs (Shard or Replicate) into partition specs as follows:

m1 = torch.randn(4, 8, 8)

def convert_to_xla_partition_spec(tensor: torch.Tensor, placement_spec: List[Placement]):
  # per tensor dimension sharding; built as a list because tuples are immutable
  sharding_spec = [None] * len(tensor.shape)
  for mesh_idx, spec in enumerate(placement_spec):
    if isinstance(spec, Shard):
      # mesh_idx to tensor_idx (spec.dim)
      sharding_spec[spec.dim] = mesh_idx
    # Replicate defaults to None
  return tuple(sharding_spec)

sharding_spec = convert_to_xla_partition_spec(m1, [Replicate(), Shard(0)])
print(sharding_spec)
>> (1, None, None)

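The mapping between the two spec styles also runs in the other direction. The sketch below shows both directions with local stand-ins for Shard and Replicate so that it runs without torch. The function names are hypothetical, and rejecting a tensor dimension sharded over more than one mesh dimension is an assumption that follows from the one-index-per-tensor-dimension form of the partition spec used in this RFC.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union


# Stand-ins for the DistributedTensor placement types (assumption).
@dataclass(frozen=True)
class Shard:
    dim: int


@dataclass(frozen=True)
class Replicate:
    pass


Placement = Union[Shard, Replicate]


def placements_to_partition_spec(
    tensor_ndim: int, placements: Sequence[Placement]
) -> Tuple[Optional[int], ...]:
    """Per-mesh-dimension placements -> per-tensor-dimension partition spec."""
    spec: List[Optional[int]] = [None] * tensor_ndim
    for mesh_dim, placement in enumerate(placements):
        if isinstance(placement, Shard):
            if not 0 <= placement.dim < tensor_ndim:
                raise ValueError(f"Shard({placement.dim}) is out of range")
            if spec[placement.dim] is not None:
                raise ValueError("a tensor dim cannot map to two mesh dims here")
            spec[placement.dim] = mesh_dim
        # Replicate leaves the entries as None.
    return tuple(spec)


def partition_spec_to_placements(
    mesh_ndim: int, spec: Sequence[Optional[int]]
) -> List[Placement]:
    """Per-tensor-dimension partition spec -> per-mesh-dimension placements."""
    placements: List[Placement] = [Replicate()] * mesh_ndim
    for tensor_dim, mesh_dim in enumerate(spec):
        if mesh_dim is None:
            continue
        if isinstance(placements[mesh_dim], Shard):
            raise ValueError("a mesh dim cannot shard two tensor dims here")
        placements[mesh_dim] = Shard(tensor_dim)
    return placements


spec = placements_to_partition_spec(3, [Replicate(), Shard(0)])
print(spec)  # (1, None, None)
print(partition_spec_to_placements(2, spec))  # [Replicate(), Shard(dim=0)]
```

Taking the tensor rank rather than the tensor itself keeps the conversion independent of where the data lives, which matters when the input is a "meta" tensor.
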
DistributedTensor <> XLAShardedTensor

Distributing a tensor on the xla backend triggers the XLA compiler to partition the computation and propagate the sharding. The final result is the same as if the computation were not sharded, and the result is replicated across the devices. This is a side effect of tensor distribution on the xla backend. One can avoid this side effect and simply apply torch ops to the sharded tensors by taking the returned XLAShardedTensor and converting it to a DistributedTensor. This conversion requires that the DistributedTensor reside on the CPU.

# distribute_tensor with xla backend returns XLAShardedTensor
t = torch.randn(4, 8)
mesh_shape = torch.arange(8).reshape(2, 4)
xla_mesh = DeviceMesh("xla", mesh_shape)
xt = distribute_tensor(t, xla_mesh, [Replicate(), Shard(0)])


# XLAShardedTensor is collected on the host for the conversion
cpu_mesh = DeviceMesh("cpu", [0, 1])
dt = DistributedTensor.from_local(xt.global_tensor.to("cpu"), cpu_mesh, [Replicate(), Shard(0)])

# DistributedTensor can be converted to XLAShardedTensor
xt = xs.mark_sharding(dt.global_tensor, convert_to_xla_mesh(xla_mesh), (1, None))

DistributedTensor API with xla device

distribute_tensor

Calling distribute_tensor with an xla device_mesh will trigger a mark_sharding API call with the transformed input arguments:

def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh=None, placements: List[Placement]=None) -> torch.Tensor:
    # distribute the tensor according to device_mesh and placements, tensor could be a "meta" tensor.
    ...
    
    # Annotates sharding and returns an XLAShardedTensor
    if device_mesh.device_type == 'xla':
      # import torch_xla.experimental.xla_sharding as xs
      xla_mesh = convert_to_xla_mesh(device_mesh)
      partition_spec = convert_to_xla_partition_spec(tensor, placements)
      xla_tensor = xs.mark_sharding(tensor, xla_mesh, partition_spec)
      return xla_tensor

The distribute_tensor API returns a torch.Tensor that can be either DistributedTensor or XLAShardedTensor.

distribute_module

This API is currently used mainly for manual sharding specification, unlike GSPMD-style automatic sharding propagation; i.e., it allows the user to specify sharding and treats the rest of the module parameters as replicated. We are in the process of deciding whether to use this API or a new API for GSPMD-style sharding propagation. We can revisit this with the XLA GSPMD integration later, once we settle on the API.

Alternatives

We want to make the DistributedTensor API device agnostic and have it also support the xla lazy backend. PyTorch/XLA provides a set of lower-level APIs that can be integrated into DT to support distributed tensor execution on the lazy backend, with some limitations. The goal is to promote a more consistent user experience across different backends and to use the same abstraction wherever possible. An alternative is to integrate with other distributed tensor abstractions and their APIs, which we may consider after integrating with DT first, if needed.

cc @bdhirsh @wanchaol @JackCaoG @steventk-g @fduwjj @alanwaketan @miladm

Labels: module: xla, triaged
