[WIP] Adding GPT2 with Multi Query Attention by lvwerra · Pull Request #21253 · huggingface/transformers

Conversation

@lvwerra (Member) commented on Jan 23, 2023

Adding GPT2 with Multi Query Attention

This PR adds a GPT2 architecture with Multi Query Attention (MQA). With MQA, the key and value projections are shared across attention heads and only the queries are head-specific, which shrinks the cached keys/values and makes it possible to run the model with very large batches.

This is the architecture used in BigCode's SantaCoder.
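
For intuition, a minimal sketch of the attention shapes under MQA (hypothetical tensor names and sizes, not the code in this PR): queries keep a per-head dimension, while a single shared key/value head is broadcast across all query heads, so the cached keys/values are num_heads times smaller.

import torch

batch, n_heads, head_dim, q_len, kv_len = 4, 12, 64, 128, 128

# Per-head queries, but a single shared key/value head (no head dimension).
query = torch.randn(batch, n_heads, q_len, head_dim)
key = torch.randn(batch, kv_len, head_dim)
value = torch.randn(batch, kv_len, head_dim)

# The shared K/V head broadcasts across all query heads.
attn_weights = query @ key.unsqueeze(1).transpose(-1, -2)   # (batch, n_heads, q_len, kv_len)
attn_weights = torch.softmax(attn_weights / head_dim ** 0.5, dim=-1)
attn_output = attn_weights @ value.unsqueeze(1)             # (batch, n_heads, q_len, head_dim)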

There are a few things to do before we can merge the PR:

  • add performance improvements suggested by @jlamypoirier
  • fix tests:
    • there is an issue with past (the cached key/values)
    • there is an issue with loading the tokenizer (I guess a vocab file is missing from the repo?)
    • fix the generation examples

You can run the tests with:

RUN_SLOW=1 python -m pytest -s -v ./tests/models/gpt2mqa/

cc @bigximik @jlamypoirier @RaymondLi0

For review once this is ready, tagging @ArthurZucker and @younesbelkada.

@bigximik

Regarding the tests test_batch_generation and test_batch_generation_2heads: if the tokenizer initialisation class is changed from GPT2Tokenizer to GPT2TokenizerFast, the test passes through until the generated-tokens assertion. Is this the intended behaviour, or should the loading functionality have rerouted from the default class?
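
For reference, the change described above is only the class used to load the tokenizer in the test setup (the checkpoint path below is a placeholder):

from transformers import GPT2Tokenizer, GPT2TokenizerFast

# Slow tokenizer: needs vocab.json and merges.txt in the repo.
# tokenizer = GPT2Tokenizer.from_pretrained("path/to/gpt2mqa-checkpoint")

# Fast tokenizer: can also load from tokenizer.json, so it may succeed
# even when the slow vocab files are missing.
tokenizer = GPT2TokenizerFast.from_pretrained("path/to/gpt2mqa-checkpoint")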

attn_weights = attn_weights.view(batch_size, self.num_heads, query_length, key_length)

if self.scale_attn_weights:
    attn_weights = attn_weights / torch.tensor(
Contributor

Synchronisation issue (see #20061).
Should be attn_weights = attn_weights / value.size(-1) ** 0.5 (that PR isn't great; we don't want to create a tensor here).
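
A minimal sketch of that suggestion in context (not the final code):

if self.scale_attn_weights:
    # Divide by a plain Python float; creating a tensor here means an extra
    # host-to-device transfer (the synchronisation issue referenced above).
    attn_weights = attn_weights / (value.size(-1) ** 0.5)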

mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
Contributor

Another synchronisation issue: mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device) (this one does need to be a tensor, per the comment above).
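
A minimal sketch of that suggestion (torch.full with an empty size gives a 0-dim tensor, and the device can be passed directly instead of a separate .to call):

mask_value = torch.finfo(attn_weights.dtype).min
# Keep it a tensor so torch.where does not mix float and double scalars.
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)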

if layer_past is not None:
    past_key, past_value = layer_past
    # Concatenate on sequence dimension
    key = torch.cat((past_key, key), dim=-1)
Contributor

This is the slow op. Note that avoiding this will probably change the return type of the function, since we need to return the buffer cache and some extra information.
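
For illustration only, the usual way to avoid the per-step torch.cat is to write into a preallocated buffer. A rough sketch under assumed names and a (batch, seq, head_dim) layout (the layout in this PR differs, and as noted the layer would then have to return the buffer plus the current length):

# Allocated once before generation starts (hypothetical names and layout).
key_cache = hidden_states.new_zeros(batch_size, max_length, head_dim)
value_cache = hidden_states.new_zeros(batch_size, max_length, head_dim)
cache_length = 0

# Inside the forward pass: in-place copy instead of torch.cat.
new_length = cache_length + key.size(1)
key_cache[:, cache_length:new_length] = key
value_cache[:, cache_length:new_length] = value
key = key_cache[:, :new_length]
value = value_cache[:, :new_length]
cache_length = new_length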

# self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
# Keys and values are shared across heads
self.kv_attn = Conv1D(2 * self.head_dim, self.embed_dim)
Contributor

That's MQA_2, with q and kv projected separately, right? That is likely a bit slower than keeping them together.
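
A sketch of the fused variant being suggested, where a single Conv1D produces the queries plus the shared key/value head and the output is split afterwards (assuming the Conv1D(out_features, in_features) convention used in this file):

# In __init__: one projection for Q and the shared K/V head.
self.c_attn = Conv1D(self.embed_dim + 2 * self.head_dim, self.embed_dim)

# In forward(): a single matmul, then split into Q, K, V.
query, key, value = self.c_attn(hidden_states).split(
    [self.embed_dim, self.head_dim, self.head_dim], dim=-1
)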

>>> samples_img = [
... np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [height, width, 3]).astype(np.uint8) for s in samples
... ] # convert color cluster tokens back to pixels
>>> ] # convert color cluster tokens back to pixels
Contributor

Are these intended?

Member Author

No - I thought I removed those. Not sure why they came back 😂

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@lvwerra (Member Author) commented on Apr 21, 2023

Closing in favour of #22575

@lvwerra closed this on Apr 21, 2023