Labels: module: dtensor, module: fsdp, triaged
Description
🐛 Describe the bug
We have been training DTensor-based models off torch nightly (in anticipation of 2.2), and we very often see the loss flatline. We do not see this at all on the current nightly (as of 4 days ago), so at this point we are very confident there is a regression/bug in the current 2.2 release candidate that breaks FSDP training (at least with DTensor).
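For reference, here is a minimal sketch of the kind of FSDP training loop where the flatline shows up. This is a hypothetical repro, not the actual training job: it assumes a torchrun launch with the NCCL backend and a toy model, and it omits the DTensor/2D-parallel sharding used in the real setup.

```python
# Hypothetical minimal sketch (not the actual training job): an FSDP loop that
# logs the loss so a flatline is easy to spot. Launch with torchrun, e.g.
#   torchrun --nproc_per_node=8 repro.py
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Toy model standing in for the real network; in the real 2D-parallel
    # setup the sharded parameters are DTensor-backed.
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.ReLU(),
        torch.nn.Linear(4096, 1024),
    ).cuda()
    model = FSDP(model, use_orig_params=True)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

    for step in range(200):
        x = torch.randn(8, 1024, device="cuda")
        loss = model(x).pow(2).mean()
        loss.backward()
        optim.step()
        optim.zero_grad()
        if rank == 0 and step % 10 == 0:
            # On the affected 2.2 RC the reported symptom is that this value
            # stops improving; on a recent nightly it keeps decreasing.
            print(f"step {step}: loss {loss.item():.4f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```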
Our best guess is that one of the two PRs linked below fixes it:
- [2d] unflatten_tensor on compute stream for DTensorExtension #116559
- [reland] unflatten_tensor on compute stream for DTensorExtension #117020
To be safe, I would personally also want to include the no-grad bug fix:
- [FSDP] enable autograd in forward prefetching #116792
Versions
Torch 2.2 branch
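As a hedged aside (not from the original report), the exact build matters here since "Torch 2.2 branch" can correspond to several different RC commits; the standard torch attributes below pin it down:

```python
# Capture the exact build under test; these attributes exist on stock PyTorch.
import torch

print(torch.__version__)          # e.g. "2.2.0a0+git<sha>" for a source build
print(torch.version.git_version)  # exact commit the build was produced from
print(torch.version.cuda)         # CUDA toolkit version the build targets
```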
cc @zhaojuanmao @mrshenli @rohan-varma @awgu @fegin @penguinwu @kwen2501 @wanchaol @XilunWu @tianyu-l