
[DTensor][FSDP] dynamo internal error for fully_shard(reshard_after_forward=2) #122459


🐛 Describe the bug

Repro: pytest test/distributed/_composable/fsdp/test_fully_shard_training_compile.py. This might share the same root cause as #122447.

test_fully_shard_training_compile.py is defined as follows:

import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests


class TestFullyShard1DTrainingCore(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_train_parity_multi_group(self):
        torch.manual_seed(42)
        lin_dim = 32
        # An int (rather than a bool): reshard parameters to a world size
        # of 2 after forward
        reshard_after_forward = 2
        model = MLP(lin_dim, torch.device("cuda"))
        # Shard each linear submodule and the root, giving multiple FSDP groups
        fully_shard(model.in_proj, reshard_after_forward=reshard_after_forward)
        fully_shard(model.out_proj, reshard_after_forward=reshard_after_forward)
        fully_shard(model, reshard_after_forward=reshard_after_forward)
        model = torch.compile(model)
        inp = torch.randn((8, lin_dim), device=torch.device("cuda"))
        model(inp)  # fails in the post-forward reshard hook (see below)


if __name__ == "__main__":
    run_tests()
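
For context, reshard_after_forward=2 (an int rather than a bool) asks FSDP to reshard parameters to a smaller world size after forward, rather than keeping them unsharded or resharding fully. A rough sketch of the per-rank bookkeeping this implies (the sizes here are illustrative assumptions, not values taken from the test):

# Illustrative only: with world_size=4 and reshard_after_forward=2, each rank
# holds 1/4 of a parameter while sharded, all-gathers the full parameter for
# forward, then resplits into 1/2-sized "post-forward" shards. Building that
# post-forward shard as a DTensor is the path seen in the traceback below
# (to_sharded_post_forward_dtensor -> _from_local_no_grad -> DTensor.__new__).
world_size = 4
post_forward_world_size = 2            # the reshard_after_forward int
param_numel = 32 * 32                  # e.g. one lin_dim x lin_dim MLP weight
sharded_numel = param_numel // world_size                    # 256 per rank
post_forward_numel = param_numel // post_forward_world_size  # 512 per rank
print(sharded_numel, post_forward_numel)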

Error message:

=================================== FAILURES ===================================
__________ TestFullyShard1DTrainingCore.test_train_parity_multi_group __________
Traceback (most recent call last):
  File "/home/weif/local/miniconda3/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 59, in testPartExecutor
    yield
  File "/home/weif/local/miniconda3/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/home/weif/local/miniconda3/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/data/users/weif/pytorch/torch/testing/_internal/common_distributed.py", line 540, in wrapper
    self._join_processes(fn)
  File "/data/users/weif/pytorch/torch/testing/_internal/common_distributed.py", line 759, in _join_processes
    self._check_return_codes(elapsed_time)
  File "/data/users/weif/pytorch/torch/testing/_internal/common_distributed.py", line 809, in _check_return_codes
    raise RuntimeError(error)
RuntimeError: Process 0 exited with error code 10 and exception:
Traceback (most recent call last):
  File "/data/users/weif/pytorch/torch/testing/_internal/common_distributed.py", line 656, in run_test
    getattr(self, test_name)()
  File "/data/users/weif/pytorch/torch/testing/_internal/common_distributed.py", line 542, in wrapper
    fn()
  File "/data/users/weif/pytorch/torch/testing/_internal/common_utils.py", line 2739, in wrapper
    method(*args, **kwargs)
  File "/data/users/weif/pytorch/torch/testing/_internal/common_distributed.py", line 181, in wrapper
    return func(*args, **kwargs)
  File "/data/users/weif/pytorch/test/distributed/_composable/fsdp/test_fully_shard_training_compile.py", line 27, in test_train_parity_multi_group
    model(inp)
  File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1545, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1554, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/weif/pytorch/torch/_dynamo/eval_frame.py", line 390, in _fn
    return fn(*args, **kwargs)
  File "/data/users/weif/pytorch/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1545, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1595, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/data/users/weif/pytorch/torch/testing/_internal/common_fsdp.py", line 857, in forward
    z = self.in_proj(x)
  File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1545, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1608, in _call_impl
    hook_result = hook(self, args, result)
  File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_state.py", line 172, in _post_forward
    output = self._fsdp_param_group.post_forward(module, input, output)
  File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 285, in post_forward
    self.reshard()
  File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 267, in reshard
    self._to_sharded_post_forward()
  File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 379, in _to_sharded_post_forward
    fsdp_param.to_sharded_post_forward()
  File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_param.py", line 305, in to_sharded_post_forward
    self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor)
  File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_param.py", line 358, in to_sharded_post_forward_dtensor
    return _from_local_no_grad(
  File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_common.py", line 123, in _from_local_no_grad
    return DTensor(
  File "/data/users/weif/pytorch/torch/distributed/_tensor/api.py", line 229, in __new__
    r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
  File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 939, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 802, in _convert_frame
    result = inner_convert(
  File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/home/weif/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 713, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 686, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/data/users/weif/pytorch/torch/_dynamo/utils.py", line 264, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 541, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/data/users/weif/pytorch/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 503, in transform
    tracer.run()
  File "/data/users/weif/pytorch/torch/_dynamo/symbolic_convert.py", line 2214, in run
    super().run()
  File "/data/users/weif/pytorch/torch/_dynamo/symbolic_convert.py", line 850, in run
    and self.step()
  File "/data/users/weif/pytorch/torch/_dynamo/symbolic_convert.py", line 765, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/data/users/weif/pytorch/torch/_dynamo/symbolic_convert.py", line 921, in STORE_FAST
    loaded_vt.set_name_hint(name)
  File "/data/users/weif/pytorch/torch/_dynamo/variables/lazy.py", line 94, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/data/users/weif/pytorch/torch/_dynamo/variables/lazy.py", line 58, in realize
    self._cache.realize(self.parents_tracker)
  File "/data/users/weif/pytorch/torch/_dynamo/variables/lazy.py", line 24, in realize
    self.vt = VariableBuilder(tx, self.source)(self.value)
  File "/data/users/weif/pytorch/torch/_dynamo/variables/builder.py", line 274, in __call__
    vt = self._wrap(value)
  File "/data/users/weif/pytorch/torch/_dynamo/variables/builder.py", line 424, in _wrap
    return self.wrap_tensor(value)
  File "/data/users/weif/pytorch/torch/_dynamo/variables/builder.py", line 1047, in wrap_tensor
    self.assert_not_wrapped_by_this_graph(value)
  File "/data/users/weif/pytorch/torch/_dynamo/variables/builder.py", line 978, in assert_not_wrapped_by_this_graph
    if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
  File "/data/users/weif/pytorch/torch/_subclasses/fake_tensor.py", line 123, in is_fake
    attrs, _ = type(x).__tensor_flatten__(x)
  File "/data/users/weif/pytorch/torch/distributed/_tensor/api.py", line 256, in __tensor_flatten__
    return ["_local_tensor"], (self._spec, self.requires_grad)
torch._dynamo.exc.InternalTorchDynamoError: 'DTensor' object has no attribute '_spec'

from user code:
   File "/data/users/weif/pytorch/torch/distributed/_tensor/api.py", line 229, in torch_dynamo_resume_in___new___at_229
    r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
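
The failure mode, as far as the traceback shows: dynamo resumes tracing inside DTensor.__new__ right after _make_wrapper_subclass returns (the torch_dynamo_resume_in___new___at_229 frame above), and while wrapping the half-built tensor it calls is_fake, which invokes type(x).__tensor_flatten__(x) before __new__ has assigned _spec. A minimal standalone sketch of that ordering hazard, using a mock subclass rather than the real DTensor:

import torch


class MockDTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, local_tensor, spec):
        # The wrapper exists from this point on, but has none of the
        # subclass attributes yet
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            local_tensor.shape,
            dtype=local_tensor.dtype,
            device=local_tensor.device,
            requires_grad=local_tensor.requires_grad,
        )
        # Anything that calls __tensor_flatten__(r) at this point (e.g.
        # dynamo's is_fake probe when it resumes tracing mid-__new__)
        # raises AttributeError, because _spec is only assigned below
        r._local_tensor = local_tensor
        r._spec = spec
        return r

    def __tensor_flatten__(self):
        return ["_local_tensor"], (self._spec, self.requires_grad)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError  # not needed for this sketch


# Trigger the same AttributeError directly, without dynamo, by flattening
# a wrapper whose __new__ body has not run:
t = torch.Tensor._make_wrapper_subclass(MockDTensor, (2, 2))
try:
    type(t).__tensor_flatten__(t)
except AttributeError as e:
    print(e)  # 'MockDTensor' object has no attribute '_spec'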

Versions

2ab8b34

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @ezyang @msaroufim @bdhirsh @anijain2305 @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng


Labels: module: dtensor, module: dynamo, oncall: distributed, oncall: pt2, triaged
