
[DTensor] aten.matmul NotImplementedError in inference_mode #142190

@kwen2501

🐛 Describe the bug

Mini repro:

import os
import torch
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor, Shard, Replicate


torch.manual_seed(0)

def main(rank, world_size):
    device = torch.device("cuda:%d" % rank)
    torch.cuda.set_device(device)
    dist.init_process_group(
        backend="nccl", rank=rank, world_size=world_size, device_id=device,
    )
    mesh = dist.init_device_mesh("cuda", (world_size,))

    dim = 128

    # Single-device reference result
    x = torch.randn(8, dim, device=device)
    A = torch.randn(dim, dim, device=device)
    y = torch.matmul(x, A)

    # DTensor test
    dx = distribute_tensor(x, mesh, [Replicate()])
    dA = distribute_tensor(A, mesh, [Shard(0)])
    with torch.inference_mode():
        # Fails here: in inference mode, aten.matmul is not decomposed,
        # and DTensor has no sharding strategy registered for it.
        dy = torch.ops.aten.matmul.default(dx, dA)

    torch.testing.assert_close(y, dy.full_tensor())

    dist.destroy_process_group()
    print("clean exit")

if __name__ == "__main__":
    main(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))
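
The repro reads RANK and WORLD_SIZE from the environment, so it can be launched with torchrun, e.g. on a single node:

torchrun --nproc-per-node=2 repro.py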

Failure:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/users/kw2501/matmul/repro.py", line 49, in <module>
[rank0]:     main(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))
[rank0]:   File "/data/users/kw2501/matmul/repro.py", line 32, in main
[rank0]:     dy = torch.ops.aten.matmul.default(dx, dA)
[rank0]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/kw2501/pytorch/torch/_ops.py", line 723, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/kw2501/pytorch/torch/_compile.py", line 32, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/kw2501/pytorch/torch/_dynamo/eval_frame.py", line 744, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/kw2501/pytorch/torch/distributed/tensor/_api.py", line 343, in __torch_dispatch__
[rank0]:     return DTensor._op_dispatcher.dispatch(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/kw2501/pytorch/torch/distributed/tensor/_dispatch.py", line 169, in dispatch
[rank0]:     self.sharding_propagator.propagate(op_info)
[rank0]:   File "/data/users/kw2501/pytorch/torch/distributed/tensor/_sharding_prop.py", line 206, in propagate
[rank0]:     OutputSharding, self.propagate_op_sharding(op_info.schema)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/kw2501/pytorch/torch/distributed/tensor/_sharding_prop.py", line 46, in __call__
[rank0]:     return self.cache(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/users/kw2501/pytorch/torch/distributed/tensor/_sharding_prop.py", line 455, in propagate_op_sharding_non_cached
[rank0]:     raise NotImplementedError(
[rank0]: NotImplementedError: Operator aten.matmul.default does not have a sharding strategy registered.

That is, the DTensor dispatcher complains that no sharding strategy is registered for aten.matmul.

Per @bdhirsh and @fduwjj, the reason for the failure is that in inference mode the torch dispatcher chooses not to decompose certain composite ops, aten.matmul being one of them. (Under normal autograd dispatch, aten.matmul decomposes into aten.mm/aten.bmm before DTensor's __torch_dispatch__ ever sees it.)
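
As a user-side workaround for the 2-D case in this repro, one can call the decomposed op directly: torch.mm is equivalent to torch.matmul for 2-D operands, and aten.mm does have a sharding strategy, so propagation succeeds. A minimal sketch:

    with torch.inference_mode():
        # torch.mm dispatches straight to aten.mm, which DTensor supports,
        # so sharding propagation works even in inference mode.
        dy = torch.mm(dx, dA)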

A solution would be for the DTensor dispatcher to actively decompose aten.matmul, which for 2-D inputs results in aten.mm; aten.mm already has sharding-strategy support in DTensor.
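
Until such a decomposition lands, one stopgap sketch, assuming the experimental torch.distributed.tensor.experimental.register_sharding API and using a deliberately naive all-replicate strategy, is to register a sharding strategy for aten.matmul directly so that propagation no longer raises:

import torch
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.experimental import register_sharding

aten = torch.ops.aten

# Stopgap sketch: accept only the fully replicated sharding for
# aten.matmul. DTensor will redistribute (allgather) any sharded
# input to Replicate before computing, so results stay correct,
# but no computation is actually sharded.
@register_sharding(aten.matmul.default)
def replicate_matmul_strategy(x, A):  # hypothetical name
    # One acceptable sharding: replicated output, replicated inputs.
    return [([Replicate()], [Replicate(), Replicate()])]

This is correct but communication-heavy; it gives up the savings that a real mm/bmm decomposition inside the dispatcher would keep.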

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

Versions

PyTorch nightly as of 2024-12-04

Labels

oncall: distributed