Fix bnb fsdp loading for pre-quantized checkpoint #41415
Conversation
cc @winglian

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Left a few comments about naming for clarity, otherwise LGTM!
src/transformers/modeling_utils.py
Outdated
val_kwargs = value.__dict__
if value.dtype in [torch.uint8, torch.int8]:
Maybe just value.is_floating_point() if that works?
That should work! I think at some point it should even be fine to remove that check entirely, if the modules are correctly initialized.
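A minimal illustration of the check being discussed, assuming `value` is a plain `torch.Tensor` (in the real code it is a bnb parameter subclass):

```python
import torch

# Packed quantized storage is typically uint8/int8; regular weights are floating point.
quantized = torch.empty(4, dtype=torch.uint8)
regular = torch.empty(4, dtype=torch.float16)

# Original-style check: enumerate the integer dtypes used for quantized storage.
print(quantized.dtype in [torch.uint8, torch.int8])  # True
print(regular.dtype in [torch.uint8, torch.int8])    # False

# Suggested simplification: quantized storage is simply "not floating point".
print(not quantized.is_floating_point())             # True
print(not regular.is_floating_point())               # False
```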
src/transformers/modeling_utils.py
Outdated
if value.dtype in [torch.uint8, torch.int8]:
val_kwargs["requires_grad"] = False
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
param_to = "meta" if is_fsdp_enabled() and not is_local_dist_rank_0() else "cpu"
Let's just call it device IMO, param_to is a bit weird
done
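For context, a rough sketch of that device-selection line after the rename, assuming the `is_fsdp_enabled` / `is_local_dist_rank_0` helpers from `transformers.modeling_utils`; the surrounding code in the PR may differ:

```python
import torch
from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0

# Under FSDP, only local rank 0 materializes pre-quantized params on CPU;
# the other ranks keep them on the meta device.
device = "meta" if is_fsdp_enabled() and not is_local_dist_rank_0() else "cpu"

# Stand-in for a packed pre-quantized weight loaded from the checkpoint.
value = torch.empty(8, dtype=torch.uint8)
value = value.to(device)
```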
def update_param_name(self, param_name: str) -> str:
    """
Let's maybe call it get_param_name instead as it does not update it
done
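A hypothetical sketch of what the renamed helper could look like; the suffix handling below is an assumption about how bitsandbytes serializes quantization state, not the PR's actual implementation:

```python
class _ExampleQuantizer:
    # Hypothetical rename of update_param_name: the method only *returns* a
    # (possibly rewritten) name, so "get" describes it better than "update".
    def get_param_name(self, param_name: str) -> str:
        """Return the model-side name for a checkpoint parameter key."""
        # Assumption for illustration: 4-bit checkpoints store quant stats under
        # keys like "...weight.quant_state.bitsandbytes__nf4"; map them back to
        # the owning "weight" parameter.
        if ".quant_state." in param_name:
            return param_name.split(".quant_state.")[0]
        return param_name


print(_ExampleQuantizer().get_param_name("model.layers.0.q_proj.weight.quant_state.bitsandbytes__nf4"))
# -> "model.layers.0.q_proj.weight"
```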
# special case for gpt_oss model, we wait for the param to be leave the meta device before casting it to cpu
if model.config.model_type == "gpt_oss" and value.device.type == "meta":
# We need to wait until the quantized value is created
if value.device.type == "meta":
Still a bit weird to me that we have to do this, but I wanted to investigate further anyway to remove the gpt-oss special exception - already happy to see it a bit more general and not gpt-oss-specific!
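A schematic sketch of the generalized check (not the PR's exact code, helper name invented): any parameter still on the meta device is skipped, since its quantized value has not been created yet:

```python
import torch

def maybe_move_to_cpu(value: torch.Tensor) -> torch.Tensor:
    # Illustration only: a meta tensor carries no data yet, so there is
    # nothing to cast; wait until the quantized value has been materialized.
    if value.device.type == "meta":
        return value
    return value.to("cpu")

print(maybe_move_to_cpu(torch.empty(2, device="meta")).device)  # meta
print(maybe_move_to_cpu(torch.ones(2)).device)                  # cpu
```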
[For maintainers] Suggested jobs to run (before merge): run-slow: mxfp4
* fix
* fix
* get_param_name
* fix device name
What does this PR do?
This PR fixes bnb loading when using FSDP for pre-quantized checkpoints. It broke because we changed how we load quantized checkpoints: we now need to cache all the quantization stats before creating the quantized weight.
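A rough sketch of that caching idea, with invented names (this is not the PR's actual helper): the packed weight and its quantization stats arrive as separate checkpoint entries, so they are buffered per weight until every expected piece is present and the quantized parameter can be built.

```python
import torch

# Illustration only; all names are invented for this sketch. Entries belonging
# to one quantized weight (the packed uint8 data plus stats such as absmax)
# are cached per weight; the quantized parameter is only created once every
# expected piece has been seen (the real code would build a bnb parameter).
_cache: dict[str, dict[str, torch.Tensor]] = {}

def add_entry(weight_name: str, part: str, tensor: torch.Tensor, expected_parts: set[str]):
    bucket = _cache.setdefault(weight_name, {})
    bucket[part] = tensor
    if expected_parts.issubset(bucket):
        # Everything is available: hand back the pieces for quantized-weight creation.
        return _cache.pop(weight_name)
    return None  # still waiting for more pieces

pieces = add_entry("layer.weight", "data", torch.empty(8, dtype=torch.uint8), {"data", "absmax"})
pieces = add_entry("layer.weight", "absmax", torch.ones(1), {"data", "absmax"})
print(sorted(pieces))  # ['absmax', 'data'] -> ready to build the quantized weight
```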