[fsdp2] based on device, use stream and Event by jeejakp12 · Pull Request #136843 · pytorch/pytorch · GitHub

Conversation

@jeejakp12
Contributor

@jeejakp12 jeejakp12 commented Sep 27, 2024

Currently FSDP2 supports only CUDA; for other backends that need to use FSDP2, it won't work because streams and events are CUDA-based. To support other backends, use _get_device_handle with the device type to get the device module, and use that for streams and events.
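
A minimal sketch of the idea (not the PR's exact code; the import path for _get_device_handle is an assumption):

    import torch
    from torch.distributed.device_mesh import _get_device_handle  # import path is an assumption

    device = torch.device("cuda")  # could be "hpu", "xpu", ... on other backends
    device_handle = _get_device_handle(device.type)  # resolves e.g. torch.cuda / torch.hpu

    # Streams and events come from the resolved handle instead of torch.cuda directly,
    # so the same code path can work for non-CUDA backends.
    stream = device_handle.Stream()
    event = device_handle.Event()
    with device_handle.stream(stream):
        event.record()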

Fixes #ISSUE_NUMBER

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Sep 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136843

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7750cda with merge base f54e142:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category labels Sep 27, 2024
@Skylion007 Skylion007 requested a review from awgu September 27, 2024 13:54
@albanD albanD added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Sep 27, 2024
dist.all_reduce(
    reduce_output,
    group=all_reduce_group,
    op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
    # hpu need to add support fot AVG, just change it to proceed, will see accuracy issue
Collaborator

nit: maybe you could leave this comment out unless you are confident that you are going to fix it and then remove the comment later?

otherwise, I would prefer that you track the HPU to-do separately from the main code

Contributor Author

@awgu sorry, that was a temporary fix for verifying it on HPU; I forgot to remove it. Fixed in the latest patch.

@jeejakp12 jeejakp12 force-pushed the origin/jeeja_use_device_handle_for_stream_event branch from bb94015 to 949e2ce Compare September 27, 2024 15:59
@awgu
Collaborator

awgu commented Sep 27, 2024

change looks good to me overall, let me let CI run first, and I will do a second pass

@awgu
Collaborator

awgu commented Sep 27, 2024

@jeejakp12 sorry, could you also fix the failing unit tests? Fortunately, I think it should not be a complicated fix.

@awgu awgu added release notes: distributed (fsdp2) release notes category and removed release notes: distributed (fsdp) release notes category labels Sep 27, 2024
@jeejakp12 jeejakp12 force-pushed the origin/jeeja_use_device_handle_for_stream_event branch from 949e2ce to ea62d93 Compare September 30, 2024 09:38
@jeejakp12
Contributor Author

@pytorchbot drci

Collaborator

@awgu awgu left a comment

Can we please try to minimize the code surface affected? For example, let us not add default args if not needed; let us not add None paths if not needed, etc.

I left some comments regarding these inline.

    all_reduce_grads: bool,
    partial_reduce_output: Optional[torch.Tensor],  # only used for HSDP
-) -> Tuple[torch.Tensor, torch.cuda.Event, torch.cuda.Event, Optional[torch.Tensor]]:
+) -> Tuple[torch.Tensor, torch.Event, torch.Event, Optional[torch.Tensor]]:
+    device_handle = _get_device_handle_from_device(device)
Collaborator

nit: let us not put code before the block comment 🤔

Contributor Author

will fix it

@@ -150,3 +151,16 @@ def _cast_fp_tensor(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor:
    ):
        return x
    return x.to(dtype)


def _get_device_handle_from_device(device: Optional[torch.device] = None):
Collaborator

let us not use a default arg if we never expect to call this without passing a device

Comment on lines 157 to 158
    if device is None:
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
Collaborator

when do we expect to pass device=None?
if never, let us get rid of this branch

I would even vote to just use _get_device_handle(device.type) directly instead of this function, because we do not really need this error checking on every call at runtime
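
A hedged sketch of that suggestion (call sites would then do something like):

    # Reviewer's suggestion, sketched: skip the wrapper and resolve the handle
    # directly from the device's type string at each call site.
    device_handle = _get_device_handle(device.type)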

Contributor Author

Below is the test that fails when FSDPParamGroup is created with device=None:

def test_dynamo_trace_use_training_state(self):
    torch._dynamo.reset()
    # Construct a dummy FSDPParamGroup, since we just want to test the use_training_state ctx manager.
    param_group = FSDPParamGroup(
        [],  # params: List[nn.Parameter],
        (torch.nn.Linear(1, 1),),  # module: Tuple[nn.Module, ...],
        None,  # mesh_info: FSDPMeshInfo,
        None,  # post_forward_mesh_info: Optional[FSDPMeshInfo],
        None,  # device: torch.device,
        None,  # mp_policy: MixedPrecisionPolicy,
        None,  # offload_policy: OffloadPolicy,
    )

Collaborator

Thanks! Can we try to pass a real device here?
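
A hedged sketch of that suggestion applied to the test above (the device choice is illustrative):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    param_group = FSDPParamGroup(
        [],                        # params: List[nn.Parameter]
        (torch.nn.Linear(1, 1),),  # module: Tuple[nn.Module, ...]
        None,                      # mesh_info: FSDPMeshInfo
        None,                      # post_forward_mesh_info: Optional[FSDPMeshInfo]
        device,                    # device: torch.device (was None)
        None,                      # mp_policy: MixedPrecisionPolicy
        None,                      # offload_policy: OffloadPolicy
    )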

    device_handle = _get_device_handle(device_type)

    if device_handle is None:
        raise RuntimeError("InValid device handle for device type:", device)
Collaborator

if you decide to keep this

Suggested change:
-        raise RuntimeError("InValid device handle for device type:", device)
+        raise RuntimeError("Invalid device handle for device type:", device)

Contributor Author

will fix this

-        if not torch.cuda.is_available():
-            raise RuntimeError("FSDP requires CUDA for streams")
+    def lazy_init(self, device: torch.device):
+        if device is None:
Collaborator

when do we hit this branch? it mismatches the type annotation of device: torch.device

if we do not hit it, we should get rid of it

Contributor Author

If FSDPParamGroup is created with device as None, then device here is None. There were some tests where the device passed to FSDPParamGroup was None. If we allow None for the FSDPParamGroup torch.device, then we need the check. Do you think we should not allow None here?

Collaborator

ah makes sense! can we try to pass a proper device in those cases and disallow None here? sorry about this 😢
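
A minimal sketch of that direction, assuming callers always pass a concrete device (the body is illustrative, not the PR's final code):

    def lazy_init(self, device: torch.device):
        # Require a concrete device so the signature matches its annotation;
        # callers (including tests) are expected to pass a real device.
        assert device is not None, "lazy_init expects a concrete torch.device"
        device_handle = _get_device_handle(device.type)
        ...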

-    ) -> Tuple[torch.cuda.Stream, torch.cuda.Stream]:
+        self, async_op: bool, training_state: TrainingState, device: torch.device
+    ) -> Tuple[torch.Stream, torch.Stream]:
+        device_handle = _get_device_handle_from_device(device)
Collaborator

can we use self.device_handle?

Contributor Author

I will check and fix this

@jeejakp12
Contributor Author

Can we please try to minimize the code surface affected? For example, let us not add default args if not needed; let us not add None paths if not needed, etc.

I left some comments regarding these inline.

I will try to rework this and use _get_device_handle() directly.

@jeejakp12 jeejakp12 force-pushed the origin/jeeja_use_device_handle_for_stream_event branch 2 times, most recently from 0ab1669 to 1e00f21 Compare October 4, 2024 14:31
@jeejakp12
Contributor Author

@awgu I reworked this and used _get_device_handle() directly. Can you please help review?

Collaborator

@awgu awgu left a comment

LGTM! Thanks for working through this.

@jeejakp12 jeejakp12 force-pushed the origin/jeeja_use_device_handle_for_stream_event branch from 1e00f21 to 9afa427 Compare October 4, 2024 17:53
@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 4, 2024
@awgu
Collaborator

awgu commented Oct 4, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@awgu
Collaborator

awgu commented Oct 4, 2024

test failure looks real
will need to debug

@awgu
Collaborator

awgu commented Oct 4, 2024

This is weird. I patched your PR but cannot repro locally:

 pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k  test_to_float64_after_init  -s

The above passes for me. Maybe there was a regression in trunk.

Edit: ah, indeed there was

BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the viable/strict branch to avoid these failures

@awgu
Collaborator

awgu commented Oct 4, 2024

For pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 2, 3, lf.linux.8xlarge.nvidia.gpu) (gh), the issue is that we did not previously assume that lazy init had to be called before running unshard. I will revisit this in a bit.

@awgu
Collaborator

awgu commented Oct 4, 2024

@jeejakp12 could you just do this for now to unblock:

def lazy_init(self):
        # Lazy init should be idempotent
+        if not hasattr(self.comm_ctx, "device_handle"):
+            self.comm_ctx.device_handle = _get_device_handle(self.device.type)

in _fsdp_param_group.py

currently FSDP2 support only CUDA, for other backends
that need to use FSDP2 it won’t work as stream and events
are based on CUDA. To support other backends, use
 _get_device_handle by device type to get the class and use this
for stream and events.

Signed-off-by: Jeeja <jeejakp@habana.ai>
@jeejakp12 jeejakp12 force-pushed the origin/jeeja_use_device_handle_for_stream_event branch from 9afa427 to 7750cda Compare October 5, 2024 15:18
@jeejakp12
Contributor Author

@jeejakp12 could you just do this for now to unblock:

def lazy_init(self):
        # Lazy init should be idempotent
+        if not hasattr(self.comm_ctx, "device_handle"):
+            self.comm_ctx.device_handle = _get_device_handle(self.device.type)

in _fsdp_param_group.py

@awgu I have pushed the above change. Thanks :-)

@cyyever
Collaborator

cyyever commented Oct 6, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here
