Fix return typehint for decoder and annotate inv_freq #39610
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.inv_freq: torch.Tensor
```
What is the purpose of this typehint? It looks like a weird pattern to annotate it alone like this without assigning a value, no?
@Cyrilvallez, later, when we use this attribute in the forward method, it is an unknown member of the class: we do not assign it as usual with `self.inv_freq = ...`, but instead register it with torch via `self.register_buffer(...)`. Annotating it fixes this linting issue.
This change is not strictly necessary, it's just a nit, so I can remove it if you think it's redundant.
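A minimal sketch of the situation (the class and buffer names follow the PR; the simplified constructor and dummy values are illustrative):

```python
import torch
from torch import nn

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.arange(4, dtype=torch.float32)
        # register_buffer creates the attribute dynamically, so static
        # analyzers do not know that self.inv_freq exists
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # a bare annotation declares the attribute's type for the linter;
        # at runtime it assigns nothing and leaves the buffer untouched
        self.inv_freq: torch.Tensor

    def forward(self) -> torch.Tensor:
        # without the annotation, linters flag inv_freq as an unknown member
        return self.inv_freq * 2.0
```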
Alternatively, the annotation can be moved to the class level:
```python
class LlamaRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor

    def __init__(self, config: LlamaConfig, device=None):
        ...
```
@Cyrilvallez, friendly ping 😄
Oups sorry! Yes, I would rather have it at the class level, or inline such as
`self.register_buffer("inv_freq", inv_freq, persistent=False); self.inv_freq: torch.Tensor`
because having it as 2 successive lines feels weird IMO.
But probably the class-level annotation, with a comment explaining that the linter is not able to follow `register_buffer`, is better, wdyt?
Thanks, sounds good to me; moved it to the class level.
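A runnable sketch of the resulting pattern (the comment wording is illustrative, and a simplified constructor stands in for the config-driven rope_init_fn):

```python
import torch
from torch import nn

class LlamaRotaryEmbedding(nn.Module):
    # Declared at class level (without assignment) because the buffer is
    # created via register_buffer, which static analyzers cannot follow.
    inv_freq: torch.Tensor

    def __init__(self, dim: int = 8, base: float = 10000.0, device=None):
        super().__init__()
        # standard RoPE inverse frequencies; stands in for rope_init_fn
        exponent = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        self.register_buffer("inv_freq", 1.0 / (base**exponent), persistent=False)

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        # linters now resolve self.inv_freq via the class-level annotation
        return position_ids[..., None].float() * self.inv_freq
```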
[For maintainers] Suggested jobs to run (before merge): run-slow: arcee, aria, bamba, bitnet, chameleon, cohere, cohere2, csm, dbrx, deepseek_v2, deepseek_v3, dia, diffllama, doge, dots1, efficientloftr
Yes, makes a lot of sense! Thanks for cleaning up the types, it's not the most glamorous but super useful!!
Thanks for the review!


What does this PR do?
Fix return type hint for decoder layer and annotate inv_freq