[fsdp2] based on device, use stream and Event #136843
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136843
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 7750cda with merge base f54e142. This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
dist.all_reduce(
    reduce_output,
    group=all_reduce_group,
    op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
    # hpu need to add support fot AVG, just change it to proceed, will see accuracy issue
```
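The hunk above falls back to `ReduceOp.SUM` when a predivide factor is set. For context, a backend without a native AVG reduce can emulate it with SUM plus a local divide, which is the arithmetic the predivide path relies on. Below is a plain-Python sketch of that equivalence; `all_reduce_sum` and `all_reduce_avg_via_sum` are hypothetical stand-ins, not PyTorch APIs.

```python
# Sketch (plain Python, no torch): emulating an AVG all-reduce with SUM plus
# a divide. All names here are illustrative stand-ins for collective ops.

def all_reduce_sum(values):
    # Stand-in for dist.all_reduce(..., op=ReduceOp.SUM): every rank ends up
    # with the elementwise sum of all ranks' values.
    total = [sum(col) for col in zip(*values)]
    return [list(total) for _ in values]

def all_reduce_avg_via_sum(values, predivide_factor=None):
    world_size = len(values)
    factor = predivide_factor if predivide_factor is not None else world_size
    # Pre-divide locally, then SUM; with factor == world_size this equals AVG.
    predivided = [[v / factor for v in rank] for rank in values]
    return all_reduce_sum(predivided)

ranks = [[2.0, 4.0], [4.0, 8.0]]  # two ranks, two gradient elements each
print(all_reduce_avg_via_sum(ranks))  # [[3.0, 6.0], [3.0, 6.0]]
```

Pre-dividing (rather than dividing the summed result) also helps avoid overflow in low-precision gradients, which is why FSDP exposes a predivide factor at all.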
nit: maybe you could not include this comment unless you are confident that you are going to fix it and remove the comment later?
Otherwise, I would prefer you track the HPU to-do separately from the main code.
@awgu sorry, that was a temp fix for verifying it on HPU; I forgot to remove it. Fixed in the latest patch.
Force-pushed bb94015 to 949e2ce (Compare)
Change looks good to me overall; let me let CI run first, and I will do a second pass.

@jeejakp12 sorry, could you also fix the failing unit tests? Fortunately, I think it should not be a complicated fix.
Force-pushed 949e2ce to ea62d93 (Compare)
@pytorchbot drci
Can we please try to minimize the code surface affected? For example, let us not add default args if not needed; let us not add None paths if not needed, etc.
I left some comments regarding these inline.
```diff
     all_reduce_grads: bool,
     partial_reduce_output: Optional[torch.Tensor],  # only used for HSDP
-) -> Tuple[torch.Tensor, torch.cuda.Event, torch.cuda.Event, Optional[torch.Tensor]]:
+) -> Tuple[torch.Tensor, torch.Event, torch.Event, Optional[torch.Tensor]]:
+    device_handle = _get_device_handle_from_device(device)
```
nit: let us not put code before the block comment 🤔
will fix it
```diff
@@ -150,3 +151,16 @@ def _cast_fp_tensor(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor:
     ):
         return x
     return x.to(dtype)
+
+
+def _get_device_handle_from_device(device: Optional[torch.device] = None):
```
let us not use a default arg if we never expect to call this without passing a device
```python
    if device is None:
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
```
When do we expect to pass device=None?
If never, let us get rid of this branch.
I would even vote to just use _get_device_handle(device.type) directly instead of this function, because we do not really need this error checking on every call at runtime.
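To make the suggestion concrete, the per-device-type handle lookup amounts to a dispatch from `device.type` to a backend module exposing `Stream`/`Event` factories. Here is a minimal plain-Python sketch of that pattern; the registry and module stand-ins are hypothetical, not PyTorch's actual `_get_device_handle` implementation.

```python
# Minimal sketch of device-handle dispatch: map a device type string to a
# backend-module-like object. Names and contents are illustrative stand-ins
# for torch.cuda / torch.hpu style modules.
from types import SimpleNamespace

_BACKEND_MODULES = {
    "cuda": SimpleNamespace(Stream=lambda: "cuda-stream", Event=lambda: "cuda-event"),
    "hpu": SimpleNamespace(Stream=lambda: "hpu-stream", Event=lambda: "hpu-event"),
}

def get_device_handle(device_type: str):
    # Callers pass device.type directly; there is no None fallback path.
    handle = _BACKEND_MODULES.get(device_type)
    if handle is None:
        raise RuntimeError(f"Invalid device handle for device type: {device_type}")
    return handle

handle = get_device_handle("hpu")
print(handle.Stream())  # hpu-stream
```

With this shape, callers that already hold a `torch.device` just write `get_device_handle(device.type)`, which is the direct-call style the review asks for.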
Below is the test that fails when FSDPParamGroup is created with device=None:

```python
def test_dynamo_trace_use_training_state(self):
    torch._dynamo.reset()
    # Construct a dummy FSDPParamGroup, since we just want to test the use_training_state ctx manager.
    param_group = FSDPParamGroup(
        [],  # params: List[nn.Parameter],
        (torch.nn.Linear(1, 1),),  # module: Tuple[nn.Module, ...],
        None,  # mesh_info: FSDPMeshInfo,
        None,  # post_forward_mesh_info: Optional[FSDPMeshInfo],
        None,  # device: torch.device,
        None,  # mp_policy: MixedPrecisionPolicy,
        None,  # offload_policy: OffloadPolicy,
    )
```
Thanks! Can we try to pass a real device here?
```python
    device_handle = _get_device_handle(device_type)

    if device_handle is None:
        raise RuntimeError("InValid device handle for device type:", device)
```
if you decide to keep this
```diff
-        raise RuntimeError("InValid device handle for device type:", device)
+        raise RuntimeError("Invalid device handle for device type:", device)
```
will fix this
```diff
-        if not torch.cuda.is_available():
-            raise RuntimeError("FSDP requires CUDA for streams")
     def lazy_init(self, device: torch.device):
         if device is None:
```
when do we hit this branch? it mismatches the type annotation of device: torch.device
if we do not hit it, we should get rid of it
If FSDPParamGroup is created with device as None, then device is None here. There were some tests where the device passed to FSDPParamGroup was None. If we allow None for FSDPParamGroup's torch.device, then we need the check. Do you think we should not allow None here?
ah makes sense! can we try to pass a proper device in those cases and disallow None here? sorry about this 😢
```diff
         self, async_op: bool, training_state: TrainingState, device: torch.device
-    ) -> Tuple[torch.cuda.Stream, torch.cuda.Stream]:
+    ) -> Tuple[torch.Stream, torch.Stream]:
+        device_handle = _get_device_handle_from_device(device)
```
can we use self.device_handle?
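The reviewer's point is to resolve the handle once and reuse a cached attribute instead of re-deriving it on every call. A small sketch of that caching shape, with a hypothetical `ParamGroup` and lookup table (not the real FSDPParamGroup):

```python
# Sketch: resolve the device handle once in __init__ and reuse
# self.device_handle, instead of calling a lookup helper per method call.
# ParamGroup and _lookup_handle are illustrative stand-ins.

class ParamGroup:
    def __init__(self, device_type: str):
        self.device_type = device_type
        # Resolved once; every later stream/event creation reuses it.
        self.device_handle = self._lookup_handle(device_type)

    @staticmethod
    def _lookup_handle(device_type: str):
        return {"cuda": "cuda-handle", "hpu": "hpu-handle"}[device_type]

    def get_streams(self):
        # No repeated per-call handle lookup here.
        return (self.device_handle, self.device_handle)

pg = ParamGroup("cuda")
print(pg.get_streams())  # ('cuda-handle', 'cuda-handle')
```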
I will check and fix this.
I will try to rework it and use _get_device_handle() directly.
Force-pushed 0ab1669 to 1e00f21 (Compare)
@awgu reworked and used _get_device_handle() directly. Can you please help review?
LGTM! Thanks for working through this.
Force-pushed 1e00f21 to 9afa427 (Compare)
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
Test failure looks real.
This is weird. I patched your PR but cannot repro locally: the above passes for me. Maybe there was a regression in trunk.
Edit: ah, indeed there was. BROKEN TRUNK: the following job failed but was also present on the merge base:
For pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 2, 3, lf.linux.8xlarge.nvidia.gpu) (gh), the issue is that we did not previously assume that lazy init had to be called before running
@jeejakp12 could you just do this for now to unblock:

```diff
 def lazy_init(self):
     # Lazy init should be idempotent
+    if not hasattr(self.comm_ctx, "device_handle"):
+        self.comm_ctx.device_handle = _get_device_handle(self.device.type)
```

in
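The `hasattr` guard makes the lazy init idempotent: only the first call resolves the handle, so repeated calls (or calls arriving earlier than other setup) are harmless. A plain-Python sketch of that behavior, with hypothetical `Group`/`CommContext` stand-ins rather than the real FSDP classes:

```python
# Sketch of idempotent lazy init: the device handle is resolved exactly once,
# no matter how many times lazy_init() runs. All names are illustrative.

class CommContext:
    pass

calls = {"lookups": 0}

def get_device_handle(device_type):
    # Counts lookups so we can observe idempotence below.
    calls["lookups"] += 1
    return f"{device_type}-handle"

class Group:
    def __init__(self, device_type):
        self.device_type = device_type
        self.comm_ctx = CommContext()

    def lazy_init(self):
        # Lazy init should be idempotent: only resolve the handle once.
        if not hasattr(self.comm_ctx, "device_handle"):
            self.comm_ctx.device_handle = get_device_handle(self.device_type)

g = Group("cuda")
g.lazy_init()
g.lazy_init()
print(g.comm_ctx.device_handle, calls["lookups"])  # cuda-handle 1
```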
Currently FSDP2 supports only CUDA; for other backends that need to use FSDP2, it won't work, as streams and events are CUDA-based. To support other backends, use _get_device_handle by device type to get the backend class, and use it for streams and events. Signed-off-by: Jeeja <jeejakp@habana.ai>
Force-pushed 9afa427 to 7750cda (Compare)
@awgu I have pushed the above change. Thanks :-)
@pytorchbot merge

Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Currently FSDP2 supports only CUDA; for other backends that need to use FSDP2, it won't work, as streams and events are CUDA-based. To support other backends, use _get_device_handle by device type to get the backend class, and use it for streams and events.

Fixes #ISSUE_NUMBER
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o