🐛 Describe the bug
Repro: `pytest test/distributed/_composable/fsdp/test_fully_shard_training_compile.py`. This might share the same root cause as #122447.

`test_fully_shard_training_compile.py` is defined as follows:
```python
import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests


class TestFullyShard1DTrainingCore(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_train_parity_multi_group(self):
        torch.manual_seed(42)
        lin_dim = 32
        reshard_after_forward = 2
        model = MLP(lin_dim, torch.device("cuda"))
        fully_shard(model.in_proj, reshard_after_forward=reshard_after_forward)
        fully_shard(model.out_proj, reshard_after_forward=reshard_after_forward)
        fully_shard(model, reshard_after_forward=reshard_after_forward)
        model = torch.compile(model)
        inp = torch.randn((8, lin_dim), device=torch.device("cuda"))
        model(inp)


if __name__ == "__main__":
    run_tests()
```
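For reference, a roughly equivalent standalone version of the repro without the `FSDPTest` multiprocess harness might look like the sketch below (untested; assumes launching with `torchrun --nproc_per_node=2` on a machine with at least 2 GPUs, matching the `skip_if_lt_x_gpu(2)` requirement above):

```python
import os

import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.common_fsdp import MLP

# torchrun sets LOCAL_RANK; bind each rank to its own GPU.
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

torch.manual_seed(42)
lin_dim = 32
reshard_after_forward = 2
model = MLP(lin_dim, torch.device("cuda"))
fully_shard(model.in_proj, reshard_after_forward=reshard_after_forward)
fully_shard(model.out_proj, reshard_after_forward=reshard_after_forward)
fully_shard(model, reshard_after_forward=reshard_after_forward)
model = torch.compile(model)

inp = torch.randn((8, lin_dim), device="cuda")
model(inp)  # the crash below happens in the post-forward reshard hook

dist.destroy_process_group()
```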
Error message:
=================================== FAILURES ===================================
__________ TestFullyShard1DTrainingCore.test_train_parity_multi_group __________
Traceback (most recent call last):
File "/home/weif/local/miniconda3/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 59, in testPartExecutor
yield
File "/home/weif/local/miniconda3/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 591, in run
self._callTestMethod(testMethod)
File "/home/weif/local/miniconda3/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
method()
File "/data/users/weif/pytorch/torch/testing/_internal/common_distributed.py", line 540, in wrapper
self._join_processes(fn)
File "/data/users/weif/pytorch/torch/testing/_internal/common_distributed.py", line 759, in _join_processes
self._check_return_codes(elapsed_time)
File "/data/users/weif/pytorch/torch/testing/_internal/common_distributed.py", line 809, in _check_return_codes
raise RuntimeError(error)
RuntimeError: Process 0 exited with error code 10 and exception:
Traceback (most recent call last):
File "/data/users/weif/pytorch/torch/testing/_internal/common_distributed.py", line 656, in run_test
getattr(self, test_name)()
File "/data/users/weif/pytorch/torch/testing/_internal/common_distributed.py", line 542, in wrapper
fn()
File "/data/users/weif/pytorch/torch/testing/_internal/common_utils.py", line 2739, in wrapper
method(*args, **kwargs)
File "/data/users/weif/pytorch/torch/testing/_internal/common_distributed.py", line 181, in wrapper
return func(*args, **kwargs)
File "/data/users/weif/pytorch/test/distributed/_composable/fsdp/test_fully_shard_training_compile.py", line 27, in test_train_parity_multi_group
model(inp)
File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1545, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1554, in _call_impl
return forward_call(*args, **kwargs)
File "/data/users/weif/pytorch/torch/_dynamo/eval_frame.py", line 390, in _fn
return fn(*args, **kwargs)
File "/data/users/weif/pytorch/torch/_dynamo/external_utils.py", line 36, in inner
return fn(*args, **kwargs)
File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1545, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1595, in _call_impl
result = forward_call(*args, **kwargs)
File "/data/users/weif/pytorch/torch/testing/_internal/common_fsdp.py", line 857, in forward
z = self.in_proj(x)
File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1545, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/users/weif/pytorch/torch/nn/modules/module.py", line 1608, in _call_impl
hook_result = hook(self, args, result)
File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_state.py", line 172, in _post_forward
output = self._fsdp_param_group.post_forward(module, input, output)
File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 285, in post_forward
self.reshard()
File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 267, in reshard
self._to_sharded_post_forward()
File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 379, in _to_sharded_post_forward
fsdp_param.to_sharded_post_forward()
File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_param.py", line 305, in to_sharded_post_forward
self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor)
File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_param.py", line 358, in to_sharded_post_forward_dtensor
return _from_local_no_grad(
File "/data/users/weif/pytorch/torch/distributed/_composable/fsdp/_fsdp_common.py", line 123, in _from_local_no_grad
return DTensor(
File "/data/users/weif/pytorch/torch/distributed/_tensor/api.py", line 229, in __new__
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 939, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 802, in _convert_frame
result = inner_convert(
File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
return _compile(
File "/home/weif/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 713, in _compile
raise InternalTorchDynamoError(str(e)).with_traceback(
File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 686, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/data/users/weif/pytorch/torch/_dynamo/utils.py", line 264, in time_wrapper
r = func(*args, **kwargs)
File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 541, in compile_inner
out_code = transform_code_object(code, transform)
File "/data/users/weif/pytorch/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
transformations(instructions, code_options)
File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 165, in _fn
return fn(*args, **kwargs)
File "/data/users/weif/pytorch/torch/_dynamo/convert_frame.py", line 503, in transform
tracer.run()
File "/data/users/weif/pytorch/torch/_dynamo/symbolic_convert.py", line 2214, in run
super().run()
File "/data/users/weif/pytorch/torch/_dynamo/symbolic_convert.py", line 850, in run
and self.step()
File "/data/users/weif/pytorch/torch/_dynamo/symbolic_convert.py", line 765, in step
self.dispatch_table[inst.opcode](self, inst)
File "/data/users/weif/pytorch/torch/_dynamo/symbolic_convert.py", line 921, in STORE_FAST
loaded_vt.set_name_hint(name)
File "/data/users/weif/pytorch/torch/_dynamo/variables/lazy.py", line 94, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
File "/data/users/weif/pytorch/torch/_dynamo/variables/lazy.py", line 58, in realize
self._cache.realize(self.parents_tracker)
File "/data/users/weif/pytorch/torch/_dynamo/variables/lazy.py", line 24, in realize
self.vt = VariableBuilder(tx, self.source)(self.value)
File "/data/users/weif/pytorch/torch/_dynamo/variables/builder.py", line 274, in __call__
vt = self._wrap(value)
File "/data/users/weif/pytorch/torch/_dynamo/variables/builder.py", line 424, in _wrap
return self.wrap_tensor(value)
File "/data/users/weif/pytorch/torch/_dynamo/variables/builder.py", line 1047, in wrap_tensor
self.assert_not_wrapped_by_this_graph(value)
File "/data/users/weif/pytorch/torch/_dynamo/variables/builder.py", line 978, in assert_not_wrapped_by_this_graph
if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
File "/data/users/weif/pytorch/torch/_subclasses/fake_tensor.py", line 123, in is_fake
attrs, _ = type(x).__tensor_flatten__(x)
File "/data/users/weif/pytorch/torch/distributed/_tensor/api.py", line 256, in __tensor_flatten__
return ["_local_tensor"], (self._spec, self.requires_grad)
torch._dynamo.exc.InternalTorchDynamoError: 'DTensor' object has no attribute '_spec'
from user code:
File "/data/users/weif/pytorch/torch/distributed/_tensor/api.py", line 229, in torch_dynamo_resume_in___new___at_229
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
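My reading of the traceback (a guess, not a confirmed root cause): Dynamo begins tracing the resume frame inside `DTensor.__new__` right after `_make_wrapper_subclass` returns (`torch_dynamo_resume_in___new___at_229`), and while wrapping the freshly created tensor it calls `is_fake`, which invokes `type(x).__tensor_flatten__(x)`. At that point `__new__` has not yet assigned `_spec`, so the attribute lookup fails. A minimal, non-distributed sketch of that ordering issue (`WrapperTensorSketch` is a hypothetical stand-in, not DTensor):

```python
import torch


class WrapperTensorSketch(torch.Tensor):
    """Hypothetical stand-in for DTensor, only to illustrate the ordering issue."""

    @staticmethod
    def __new__(cls, local_tensor, spec):
        r = torch.Tensor._make_wrapper_subclass(  # same call DTensor.__new__ makes
            cls,
            local_tensor.shape,
            dtype=local_tensor.dtype,
            requires_grad=local_tensor.requires_grad,
        )
        # Dynamo resumes tracing here; r exists but r._spec is not assigned yet.
        r._local_tensor = local_tensor
        r._spec = spec
        return r

    def __tensor_flatten__(self):
        # Mirrors DTensor.__tensor_flatten__, which assumes _spec is already set.
        return ["_local_tensor"], (self._spec, self.requires_grad)


# Build an instance that looks the way it does mid-__new__, before _spec is set.
t = torch.randn(4)
half_built = torch.Tensor._make_wrapper_subclass(
    WrapperTensorSketch, t.shape, dtype=t.dtype
)
try:
    # is_fake() does exactly this: type(x).__tensor_flatten__(x)
    type(half_built).__tensor_flatten__(half_built)
except AttributeError as e:
    print(e)  # 'WrapperTensorSketch' object has no attribute '_spec'
```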
Versions
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @ezyang @msaroufim @bdhirsh @anijain2305 @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng