Fix bnb fsdp loading for pre-quantized checkpoint by SunMarc · Pull Request #41415 · huggingface/transformers · GitHub

Conversation

@SunMarc
Member

@SunMarc SunMarc commented Oct 7, 2025

What does this PR do?

This PR fixes bnb loading of pre-quantized checkpoints when using FSDP. The regression was introduced when we changed how quantized checkpoints are loaded: we now need to cache all the quantization stats before creating the quantized weight.
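For context, here is a minimal sketch of the loading pattern the fix relies on. This is illustrative, not the PR's code: the cache dictionary, the completeness check, and the helper name are all hypothetical; `bnb.nn.Params4bit.from_prequantized` is the bitsandbytes API for rebuilding pre-quantized 4-bit weights, and the meta-vs-cpu device choice mirrors the FSDP rank-0 convention discussed in the review below.

```python
from typing import Optional

import torch
import bitsandbytes as bnb

# Hypothetical cache: in a pre-quantized checkpoint the quantization stats for a
# 4-bit weight (absmax, quant_map, nested stats, ...) are stored as separate
# entries, so they all have to be collected (as shards stream in elsewhere)
# before the quantized weight itself can be rebuilt.
quant_state_cache: dict[str, dict[str, torch.Tensor]] = {}


def maybe_create_quantized_param(
    weight_name: str,
    packed_weight: torch.Tensor,
    expected_stat_keys: set[str],
    is_local_rank0: bool,
) -> Optional[bnb.nn.Params4bit]:
    stats = quant_state_cache.get(weight_name, {})
    if not expected_stat_keys.issubset(stats):
        # Not every quant-state entry has been seen yet; keep waiting.
        return None
    # Under FSDP, only local rank 0 materializes real data; the other ranks
    # keep their copy on the meta device (this is what the PR restores).
    device = "cpu" if is_local_rank0 else "meta"
    return bnb.nn.Params4bit.from_prequantized(
        data=packed_weight,
        quantized_stats=stats,
        requires_grad=False,
        device=device,
    )
```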

@SunMarc SunMarc changed the title Fix bnb fsdp loading Fix bnb fsdp loading for pre-quantized checkpoint Oct 7, 2025
@SunMarc SunMarc requested a review from Cyrilvallez October 7, 2025 15:31
@SunMarc
Member Author

SunMarc commented Oct 7, 2025

cc @winglian

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc SunMarc added the for patch label Oct 7, 2025
Member

@Cyrilvallez Cyrilvallez left a comment


Left a few comments about naming for clarity, otherwise LGTM!

Comment on lines 772 to 773
val_kwargs = value.__dict__
if value.dtype in [torch.uint8, torch.int8]:
Member


Maybe just value.is_floating_point() if that works?

Member Author


That should work! I think at some point it should even be fine to remove this check entirely, once the modules are correctly initialized.
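A quick illustration of the suggested equivalence (hedged: this assumes the check is ultimately negated to `not value.is_floating_point()`, since the branch targets packed integer storage):

```python
import torch

packed = torch.zeros(4, dtype=torch.uint8)   # packed quantized storage
fp = torch.zeros(4, dtype=torch.bfloat16)    # regular floating-point weight

# The dtype allowlist and a (negated) floating-point check agree on both cases:
assert packed.dtype in [torch.uint8, torch.int8] and not packed.is_floating_point()
assert fp.dtype not in [torch.uint8, torch.int8] and fp.is_floating_point()
```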

if value.dtype in [torch.uint8, torch.int8]:
    val_kwargs["requires_grad"] = False
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
param_to = "meta" if is_fsdp_enabled() and not is_local_dist_rank_0() else "cpu"
Member


Let's just call it device IMO, param_to is a bit weird

Member Author


done

Comment on lines 157 to 158
def update_param_name(self, param_name: str) -> str:
"""
Member


Let's maybe call it get_param_name instead as it does not update it

Member Author


done
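For context, a hedged sketch of what such a read-only name-mapping hook might look like. The suffix-stripping logic and the class name below are purely illustrative; only the `get_param_name(self, param_name: str) -> str` signature comes from the diff above.

```python
class Bnb4BitHfQuantizerSketch:
    """Illustrative stand-in for the real quantizer class, not the actual implementation."""

    def get_param_name(self, param_name: str) -> str:
        """Return the parameter name to use for lookups, without mutating it."""
        # Purely illustrative mapping: resolve a hypothetical quant-state key
        # back to its packed weight, e.g. "layer.weight.absmax" -> "layer.weight".
        for suffix in (".absmax", ".quant_map", ".quant_state"):
            if param_name.endswith(suffix):
                return param_name[: -len(suffix)]
        return param_name
```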

Comment on lines -769 to +770
- # special case for gpt_oss model, we wait for the param to leave the meta device before casting it to cpu
- if model.config.model_type == "gpt_oss" and value.device.type == "meta":
+ # We need to wait until the quantized value is created
+ if value.device.type == "meta":
Member


Still a bit weird to me that we have to do this, but I wanted to investigate further anyway to remove the gpt-oss special case. Already happy to see it more general and not gpt-oss-specific!
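For readers following the thread, a hedged sketch of how the generalized guard interacts with the FSDP device choice. `is_fsdp_enabled` and `is_local_dist_rank_0` are the helpers referenced in the quoted diff; the standalone function below takes them as plain booleans and is illustrative rather than the merged code.

```python
import torch


def maybe_offload(value: torch.Tensor, fsdp_enabled: bool, is_local_rank0: bool) -> torch.Tensor:
    # While the quant stats are still being cached, the placeholder parameter
    # lives on the meta device and has no data to copy; leave it alone until
    # the quantized value has actually been created (no longer gpt-oss-specific).
    if value.device.type == "meta":
        return value
    # Non-rank-0 FSDP workers keep only meta tensors; rank 0 (or non-FSDP runs)
    # move the materialized value to CPU.
    device = "meta" if fsdp_enabled and not is_local_rank0 else "cpu"
    return value.to(device)
```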

@github-actions
Contributor

github-actions bot commented Oct 9, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: mxfp4

@SunMarc SunMarc merged commit 823fab4 into main Oct 9, 2025
26 checks passed
@SunMarc SunMarc deleted the fix-fsdp-quant branch October 9, 2025 16:05
AhnJoonSung pushed a commit to AhnJoonSung/transformers that referenced this pull request Oct 12, 2025
* fix

* fix

* get_param_name

* fix device name
Cyrilvallez pushed a commit that referenced this pull request Oct 14, 2025
* fix

* fix

* get_param_name

* fix device name
