Fix modular for modernbert-decoder #40431
Conversation
[For maintainers] Suggested jobs to run (before merge): run-slow: modernbert_decoder

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Nice catch
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from ..modernbert.modeling_modernbert import (
Interesting, I thought models.llama would work, but no
Same, will check and probably upstream it to the converter to avoid it in the future
Thanks! Just nits on my side
elif module.__class__.__name__ == "ModernBertDecoderForSequenceClassification":
    init_weight(module.classifier, stds["final_out"])
elif isinstance(module, ModernBertDecoderForCausalLM):
elif module.__class__.__name__ == "ModernBertDecoderForCausalLM":
Might be a dumb question but why can't we check for the instance here?
This one could be, but the other one, ModernBertDecoderForSequenceClassification, would be matched by modular and thus wrongly imported. For consistency I made the check on the name for the 2 "higher-level" classes.
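A minimal sketch of the trade-off being discussed, using stand-in class names (not the real ModernBert classes): comparing `__class__.__name__` as a string needs no import of the concrete class, whereas `isinstance` would require one, and in the real code that import is what the modular converter would pick up and wrongly copy into the generated modeling file.

```python
# Illustrative stand-ins for the real model classes. In the actual code,
# importing them just for isinstance() would make the modular converter
# treat them as dependencies of the generated modeling file.
class StandInForSequenceClassification:
    pass

class StandInForCausalLM:
    pass

def init_weights_by_name(module):
    # String comparison on the class name: no import of the class needed,
    # so the modular converter sees no new dependency.
    name = module.__class__.__name__
    if name == "StandInForSequenceClassification":
        return "init classifier head"
    elif name == "StandInForCausalLM":
        return "init lm head"
    return "default init"

print(init_weights_by_name(StandInForCausalLM()))
```

The cost is that a string check silently misses renamed classes, which is why `isinstance` is preferred wherever the import is harmless.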
@auto_docstring
class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
    config: ModernBertDecoderConfig
Should be kept, no?
No need to redefine, it's inherited
config: ModernBertDecoderConfig
_skip_keys_device_placement = ["past_key_values"]
_no_split_modules = ["ModernBertDecoderLayer"]
_can_compile_fullgraph = False
Same here? Or change the flag maybe?
It's false by default, no need to add it
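Both of these review answers rely on the same plain-Python behavior, sketched below with generic class names (not the actual transformers classes): a class attribute set on a base class is inherited by every subclass, so restating it with the same value in the child is redundant.

```python
# Generic example: a flag declared on the base class is visible on every
# subclass without redeclaration.
class BasePreTrainedSketch:
    _can_compile_fullgraph = False  # the default the comment refers to

class DecoderSketch(BasePreTrainedSketch):
    pass  # no need to redefine the attribute here

print(DecoderSketch._can_compile_fullgraph)  # found via attribute lookup on the base
```

Note the attribute lives only in `BasePreTrainedSketch.__dict__`; the subclass finds it through the MRO, which is exactly why the redefinition can be dropped.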
What does this PR do?
The modular was ill-formed, resulting in mostly skipping all the rules and instead importing the classes from modernbert in the modeling (which is illegal, as we want 1 model -> 1 file).
This fixes it.
Note that, as it was mistakenly inheriting from ModernBertPreTrainedModel instead of using modular rules and correctly rewriting the code, the model was using FA2 by default instead of sdpa before #40350 (review). As that was a mistake, and all models should use sdpa by default barring extreme exceptions (as far as I know, ModernBert is the only one), I did not revert the new sdpa default (it follows naturally now that modular is fixed).