[WIP] Adding GPT2 with Multi Query Attention #21253
Conversation
Regarding tests

    attn_weights = attn_weights.view(batch_size, self.num_heads, query_length, key_length)

    if self.scale_attn_weights:
        attn_weights = attn_weights / torch.tensor(
Synchronisation issue #20061
Should be attn_weights = attn_weights / value.size(-1) ** 0.5 (that PR isn't great, we don't want to create a tensor here)
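For reference, a self-contained sketch of the suggested scaling (the random tensors below are placeholders for the real activations, not code from this PR):

```python
import torch

# Placeholder activations with GPT2-like shapes, just to make the snippet runnable.
attn_weights = torch.randn(2, 12, 16, 16)   # (batch, heads, query_len, key_len)
value = torch.randn(2, 12, 16, 64)          # (batch, heads, key_len, head_dim)

# Divide by a Python float instead of a freshly created tensor: the scalar is
# passed straight to the kernel and no extra host tensor has to be built and moved.
attn_weights = attn_weights / value.size(-1) ** 0.5
```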
    mask_value = torch.finfo(attn_weights.dtype).min
    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
    mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
Other synchronization issue, mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device) (this one should be a tensor according to the comment)
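A small illustration of the proposed replacement, using a placeholder `attn_weights` tensor rather than the PR's code:

```python
import torch

attn_weights = torch.randn(2, 12, 16, 16)  # placeholder attention scores

# torch.full([], ...) builds the 0-dim tensor directly in the target dtype, so it
# still satisfies the "needs to be a tensor" comments above while avoiding the
# torch.tensor(python_float) construction.
mask_value = torch.full([], torch.finfo(attn_weights.dtype).min,
                        dtype=attn_weights.dtype).to(attn_weights.device)

# Typical use: fill masked positions before the softmax.
causal_mask = torch.ones(16, 16, dtype=torch.bool).tril()
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
```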
    if layer_past is not None:
        past_key, past_value = layer_past
        # Concatenate on sequence dimension
        key = torch.cat((past_key, key), dim=-1)
This is the slow op. Note that avoiding this will probably change the return type of the function, since we need to return the buffer cache and some extra information.
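One way to avoid the per-step torch.cat is to pre-allocate the cache and write new positions into it in place. The sketch below is a hypothetical illustration of that buffer-cache idea (class name, shapes, and the update API are assumptions, not this PR's code); as noted above, it returns the buffer plus the current length instead of a freshly concatenated tensor.

```python
import torch

class KVCache:
    """Hypothetical pre-allocated MQA cache: keys/values carry no head dimension."""

    def __init__(self, batch_size, max_len, head_dim, dtype, device):
        # Keys are stored transposed (batch, head_dim, seq) to match the matmul layout above.
        self.key = torch.empty(batch_size, head_dim, max_len, dtype=dtype, device=device)
        self.value = torch.empty(batch_size, max_len, head_dim, dtype=dtype, device=device)
        self.length = 0

    def update(self, key, value):
        # key: (batch, head_dim, new_len), value: (batch, new_len, head_dim)
        new_len = key.size(-1)
        self.key[:, :, self.length:self.length + new_len].copy_(key)
        self.value[:, self.length:self.length + new_len].copy_(value)
        self.length += new_len
        # Views over the filled part, plus the extra length information.
        return self.key[:, :, :self.length], self.value[:, :self.length], self.length

cache = KVCache(batch_size=2, max_len=128, head_dim=64, dtype=torch.float32, device="cpu")
key, value, cur_len = cache.update(torch.randn(2, 64, 1), torch.randn(2, 1, 64))
```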
    # self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
    self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
    # Keys and values are shared across heads
    self.kv_attn = Conv1D(2 * self.head_dim, self.embed_dim)
That's MQA_2 with q and kv separate right? That is likely a bit slower than keeping them together.
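For comparison, a hedged sketch of the fused variant: one projection producing the per-head queries plus the single shared key/value, split afterwards. nn.Linear stands in for transformers' Conv1D here and the class name is made up for illustration.

```python
import torch
import torch.nn as nn

class FusedMQAProjection(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        # Single matmul: per-head queries plus one shared key and one shared value head.
        self.c_attn = nn.Linear(embed_dim, embed_dim + 2 * self.head_dim)

    def forward(self, hidden_states):
        query, key, value = self.c_attn(hidden_states).split(
            [self.embed_dim, self.head_dim, self.head_dim], dim=-1
        )
        return query, key, value

proj = FusedMQAProjection(embed_dim=768, num_heads=12)
q, k, v = proj(torch.randn(2, 16, 768))  # q: (2, 16, 768), k and v: (2, 16, 64)
```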
    >>> samples_img = [
    ...     np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [height, width, 3]).astype(np.uint8) for s in samples
    ... ]  # convert color cluster tokens back to pixels
    >>> ]  # convert color cluster tokens back to pixels
Are these intended?
No - I thought I removed those. Not sure why they came back 😂
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Closing in favour of #22575
Adding GPT2 with Multi Query Attention
This PR adds a GPT2 architecture with Multi Query Attention (MQA). With MQA, the key and value projections are shared across all heads and only the queries are per-head, which shrinks the KV cache and makes it possible to run the model with very large batches.
This is the architecture used in BigCode's SantaCoder.
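To make the sharing concrete, here is a minimal shape-level sketch of MQA attention (illustrative only, not the implementation in this PR):

```python
import torch

batch, num_heads, seq_len, head_dim = 2, 12, 16, 64

# Per-head queries, but a single key/value "head" broadcast across all heads.
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, 1, seq_len, head_dim)
value = torch.randn(batch, 1, seq_len, head_dim)

attn_weights = torch.softmax(
    query @ key.transpose(-1, -2) / head_dim ** 0.5, dim=-1
)  # (batch, num_heads, seq_len, seq_len)
attn_output = attn_weights @ value  # (batch, num_heads, seq_len, head_dim)
```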
There are a few things to do before we can merge the PR.
You can run the tests with:
cc @bigximik @jlamypoirier @RaymondLi0
To review when ready, I tag @ArthurZucker and @younesbelkada.