Description
🐛 Describe the bug
Mini repro:
import os
import torch
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor, Shard, Replicate

torch.manual_seed(0)

def main(rank, world_size):
    device = torch.device("cuda:%d" % rank)
    torch.cuda.set_device(device)
    dist.init_process_group(
        backend="nccl", rank=rank, world_size=world_size, device_id=device,
    )
    mesh = dist.init_device_mesh("cuda", (world_size,))

    dim = 128
    x = torch.randn(8, dim, device=device)
    A = torch.randn(dim, dim, device=device)
    y = torch.matmul(x, A)

    # DTensor test
    dx = distribute_tensor(x, mesh, [Replicate()])
    dA = distribute_tensor(A, mesh, [Shard(0)])
    with torch.inference_mode():
        dy = torch.ops.aten.matmul.default(dx, dA)
    torch.testing.assert_close(y, dy.full_tensor())

    dist.destroy_process_group()
    print("clean exit")

if __name__ == "__main__":
    main(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))
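For reference, the repro reads RANK and WORLD_SIZE from the environment, so it is meant to be launched with a multi-process launcher such as torchrun (the exact launch command used originally is an assumption), e.g. torchrun --nproc_per_node=2 repro.py.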
Failure:
[rank0]: Traceback (most recent call last):
[rank0]: File "/data/users/kw2501/matmul/repro.py", line 49, in <module>
[rank0]: main(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))
[rank0]: File "/data/users/kw2501/matmul/repro.py", line 32, in main
[rank0]: dy = torch.ops.aten.matmul.default(dx, dA)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data/users/kw2501/pytorch/torch/_ops.py", line 723, in __call__
[rank0]: return self._op(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data/users/kw2501/pytorch/torch/_compile.py", line 32, in inner
[rank0]: return disable_fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data/users/kw2501/pytorch/torch/_dynamo/eval_frame.py", line 744, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data/users/kw2501/pytorch/torch/distributed/tensor/_api.py", line 343, in __torch_dispatch__
[rank0]: return DTensor._op_dispatcher.dispatch(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data/users/kw2501/pytorch/torch/distributed/tensor/_dispatch.py", line 169, in dispatch
[rank0]: self.sharding_propagator.propagate(op_info)
[rank0]: File "/data/users/kw2501/pytorch/torch/distributed/tensor/_sharding_prop.py", line 206, in propagate
[rank0]: OutputSharding, self.propagate_op_sharding(op_info.schema)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data/users/kw2501/pytorch/torch/distributed/tensor/_sharding_prop.py", line 46, in __call__
[rank0]: return self.cache(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/data/users/kw2501/pytorch/torch/distributed/tensor/_sharding_prop.py", line 455, in propagate_op_sharding_non_cached
[rank0]: raise NotImplementedError(
[rank0]: NotImplementedError: Operator aten.matmul.default does not have a sharding strategy registered.
That is, the DTensor dispatcher complains that there is no sharding strategy registered for aten.matmul.
Per @bdhirsh @fduwjj, the reason for the failure is that under inference mode, the torch dispatcher chooses not to decompose certain ops, aten.matmul being one of them.
A solution would be for the DTensor dispatcher to actively decompose aten.matmul, which would lower it to aten.mm; the latter does have sharding strategy support in DTensor.
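As a user-side stopgap (not the fix proposed above, just an illustration, and assuming the dx, dA, and y tensors from the repro), the 2D matmul can be expressed directly with aten.mm, which should dispatch under inference mode since it already has a registered DTensor sharding strategy:

    # Workaround sketch: for 2D inputs, matmul(x, A) is exactly mm(x, A), and
    # aten.mm already has a DTensor sharding strategy registered.
    with torch.inference_mode():
        dy = torch.ops.aten.mm.default(dx, dA)
    torch.testing.assert_close(y, dy.full_tensor())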
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
Versions
nightly as of 12042024