Fakeifying a non-leaf subclass where inner tensor is noncontiguous incorrectly produces contiguous tensor. · Issue #124090 · pytorch/pytorch · GitHub

@bdhirsh

Description

Minified repro from internal:

    def test_dtensor_tensor_is_not_autograd_leaf_but_local_is_noncontiguous(self):

        # Temporarily ignore setUp(), and use rank3 graphs during tracing
        dist.destroy_process_group()
        fake_store = FakeStore()
        dist.init_process_group(
            "fake", store=fake_store, rank=3, world_size=2
        )
        mesh = DeviceMesh(self.device_type, [1, 3])

        x = torch.randn(10, 257, 160, requires_grad=True)
        x_dt = DTensor.from_local(
            x, mesh, [_Partial()], run_check=False,
            shape=(10, 257, 160), stride=(41120, 160, 1),
        )
        tmp_dt = x_dt.redistribute(mesh, (Shard(1),))

        from torch._subclasses import FakeTensorMode
        m = FakeTensorMode()
        tmp_dt_fake = m.from_tensor(tmp_dt)
        self.assertEqual(tmp_dt.shape, tmp_dt_fake.shape)
        self.assertEqual(tmp_dt.stride(), tmp_dt_fake.stride())
        self.assertEqual(tmp_dt._local_tensor.shape, tmp_dt_fake._local_tensor.shape)
        # This assert **fails**
        # tmp_dt._local_tensor is not contiguous, but tmp_dt_fake._local_tensor advertises as contiguous
        self.assertEqual(tmp_dt._local_tensor.stride(), tmp_dt_fake._local_tensor.stride())
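For contrast, a minimal sketch (plain tensors only, no DTensor or distributed setup needed) showing the behavior that is expected but lost on the subclass path: `FakeTensorMode.from_tensor` preserves the strides of a noncontiguous plain tensor, while in the repro above the *inner* `_local_tensor` of the fakeified subclass comes back contiguous.

```python
import torch
from torch._subclasses import FakeTensorMode

# Build a noncontiguous tensor via transpose: strides become (1, 8).
x = torch.randn(4, 8).t()
assert not x.is_contiguous()

m = FakeTensorMode()
x_fake = m.from_tensor(x)

# For plain tensors, shape and strides survive fakeification,
# so the fake tensor also advertises as noncontiguous.
assert x_fake.shape == x.shape
assert x_fake.stride() == x.stride()
assert not x_fake.is_contiguous()
```

The bug reported here is that this stride preservation does not hold for the inner tensor of a non-leaf traceable tensor subclass.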

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @anijain2305 @chauhang

Metadata

Assignees

Labels

high priority
module: pt2-dispatcher — PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op)
oncall: pt2
triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
