[ESM] support attention API #40370
Conversation
run-slow: esm
run-slow: esm
This comment contains run-slow, running the specified jobs: models: ['models/esm']
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
run-slow: esm, evolla
This comment contains run-slow, running the specified jobs: models: ['models/esm', 'models/evolla']
Some initial thoughts: it would be nice if we could align things with #38301, as we can then refactor more easily afterwards.
```diff
 if self.config._attn_implementation != "flash_attention_2":
     batch_size, seq_length = inputs_embeds.shape[:-1]
     if attention_mask is None:
         attention_mask = torch.ones(((batch_size, seq_length)), device=inputs_embeds.device)

 else:
     # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
     # ourselves in which case we just need to make it broadcastable to all heads.
-    extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+    attention_mask: torch.Tensor = self.get_extended_attention_mask(
+        attention_mask, input_shape=(batch_size, seq_length)
+    )

 # If a 2D or 3D attention mask is provided for the cross-attention
 # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
 if self.config.is_decoder and encoder_hidden_states is not None:
     encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
     encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
     if encoder_attention_mask is None:
-        encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+        encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
     encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
 else:
     encoder_extended_attention_mask = None
```
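For background on why the extended mask is gated on the attention implementation: the flash_attention_2 path consumes the raw 2D padding mask directly, while the eager and sdpa paths expect a broadcastable additive 4D mask. A small standalone illustration of the two shapes (not code from this PR):

```python
import torch

# 2D padding mask, passed through unchanged for flash_attention_2: [batch, seq_len]
padding_mask = torch.tensor([[1, 1, 1, 0]])

# 4D additive mask for eager/sdpa: [batch, 1, 1, seq_len], 0.0 where attention is
# allowed and a large negative value where it is masked out.
extended_mask = (1.0 - padding_mask[:, None, None, :].float()) * torch.finfo(torch.float32).min
print(padding_mask.shape, extended_mask.shape)  # torch.Size([1, 4]) torch.Size([1, 1, 1, 4])
```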
Can we align the mask creation with what is done in #38301? It will make refactoring easier, and those mask creations are "more proven".
Yeah, I also thought of that, but since the model supports a non-causal mask as well, I think we need a cleaner approach for it in the future: some kind of small function that decides which mask to construct and returns it.
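To make the idea concrete, here is a minimal, hypothetical sketch of such a helper; the name `build_attention_mask` and its signature are invented for illustration and are not part of this PR or of the transformers mask utilities:

```python
from typing import Optional

import torch


def build_attention_mask(
    attention_mask: Optional[torch.Tensor],
    inputs_embeds: torch.Tensor,
    is_causal: bool,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Hypothetical helper: decide which additive 4D mask to build and return it."""
    batch_size, seq_length = inputs_embeds.shape[:2]
    device = inputs_embeds.device
    if attention_mask is None:
        attention_mask = torch.ones(batch_size, seq_length, device=device)

    # [batch_size, 1, 1, key_len] boolean "may attend" mask from the 2D padding mask
    allowed = attention_mask[:, None, None, :].to(torch.bool)
    if is_causal:
        causal = torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device))
        allowed = allowed & causal[None, None, :, :]

    # Additive form used by eager/sdpa attention: 0 where allowed, large negative where masked
    mask = torch.zeros(allowed.shape, dtype=dtype, device=device)
    return mask.masked_fill(~allowed, torch.finfo(dtype).min)


# Example: a bidirectional (encoder-style) mask for two sequences, one of them padded.
embeds = torch.randn(2, 4, 8)
padding = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
print(build_attention_mask(padding, embeds, is_causal=False).shape)  # torch.Size([2, 1, 1, 4])
print(build_attention_mask(padding, embeds, is_causal=True).shape)   # torch.Size([2, 1, 4, 4])
```

The self-attention path would call such a helper with is_causal derived from config.is_decoder, and the cross-attention path with is_causal=False.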
I don't mean the attention mask interface yet, as that definitely needs an update for non-causal variants :D I'm thinking of transformers/src/transformers/models/bert/modeling_bert.py, lines 990 to 1034 at e0f1e83:
```python
if attention_mask is None:
    # required mask seq length can be calculated via length of past cache
    mask_seq_length = past_key_values_length + seq_length
    attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
if self.config.is_decoder and encoder_hidden_states is not None and encoder_attention_mask is None:
    encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=device)

if attention_mask.dim() == 2:
    if self.config.is_decoder:
        attention_mask = create_causal_mask(
            config=self.config,
            input_embeds=embedding_output,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )
    else:
        attention_mask = self._update_full_mask(
            attention_mask,
            embedding_output,
        )
elif attention_mask.dim() == 3:
    if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]:
        raise ValueError(
            "Passing attention mask with a 3D/4D shape does not work with type "
            f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
        )
    attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

if encoder_attention_mask is not None:
    if encoder_attention_mask.dim() == 2:
        encoder_attention_mask = self._update_cross_attn_mask(
            encoder_hidden_states,
            encoder_attention_mask,
            embedding_output.shape[:2],
            embedding_output,
        )
    else:
        if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]:
            raise ValueError(
                "Passing attention mask with a 3D/4D shape does not work with type "
                f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
            )
        encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
```
(without the checks on the dims). It would make it easy to remove all those functions later on if we keep things consistent across the models that have yet to get the mask interface. (And I have tested it quite thoroughly on all attention variations.)
run-slow: esm, evolla
This comment contains run-slow, running the specified jobs: models: ['models/esm', 'models/evolla']
Oops, I commented on the modular-generated file, but this should be carried over to esm. My bad ^^'
Just one small nit (the order of relative scaling), and could you change the mask creation per https://github.com/huggingface/transformers/pull/40370/files#r2297912836?
run-slow: esm, evolla
This comment contains run-slow, running the specified jobs: models: ['models/esm', 'models/evolla']
run-slow: esm, evolla
```python
self.scaling = 1.0  # For BC we apply scaling before RoPE
self.is_decoder = config.is_decoder
self.layer_idx = layer_idx
self.is_causal = self.is_decoder  # used only in FA2/FA3
```
SDPA uses this as well :D This is probably incorrect in the case of cross-attention; can you add an optional argument here instead?
Ah right, it also has cross-attention. I copied this from the current ESM attention, assuming it was working on the main branch.
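A minimal sketch of the suggested fix, assuming an optional constructor argument; the class name and exact signature below are illustrative, not the code that was merged:

```python
from types import SimpleNamespace
from typing import Optional

from torch import nn


class EsmStyleAttention(nn.Module):
    """Illustrative attention stub: `is_causal` can be overridden per layer."""

    def __init__(self, config, layer_idx: int, is_causal: Optional[bool] = None):
        super().__init__()
        self.scaling = 1.0  # For BC we apply scaling before RoPE
        self.is_decoder = config.is_decoder
        self.layer_idx = layer_idx
        # Default to causal only for decoder self-attention; cross-attention layers
        # pass is_causal=False explicitly (the flag is read by SDPA and FA2/FA3 paths).
        self.is_causal = self.is_decoder if is_causal is None else is_causal


config = SimpleNamespace(is_decoder=True)
self_attn = EsmStyleAttention(config, layer_idx=0)                     # is_causal == True
cross_attn = EsmStyleAttention(config, layer_idx=0, is_causal=False)   # is_causal == False
print(self_attn.is_causal, cross_attn.is_causal)
```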
[For maintainers] Suggested jobs to run (before merge)
run-slow: esm, evolla
Original PR #40370 by zucchini-nlp: huggingface/transformers#40370
Merged from original PR #40370: huggingface/transformers#40370
What does this PR do?
Addresses #34954 and updates ESM to support the attention API and modeling outputs.
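For readers unfamiliar with the pattern, the sketch below shows roughly what supporting the attention API means: the per-backend attention math is factored into interchangeable callables, and the module dispatches on config._attn_implementation instead of hard-coding eager attention. The function names and the dispatch dict here are a simplified, standalone illustration, not the exact transformers internals or this PR's code.

```python
from typing import Callable, Dict

import torch
import torch.nn.functional as F


def eager_attention_forward(query, key, value, attention_mask, scaling):
    # query/key/value: [batch, num_heads, seq_len, head_dim]
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask  # additive 4D mask
    attn_weights = F.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value), attn_weights


def sdpa_attention_forward(query, key, value, attention_mask, scaling):
    out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, scale=scaling)
    return out, None


# The model picks the backend from config._attn_implementation instead of
# hard-coding a single attention implementation per model.
ATTENTION_FUNCTIONS: Dict[str, Callable] = {
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
}

q = k = v = torch.randn(1, 4, 6, 8)  # [batch, heads, seq, head_dim]
attn_output, attn_weights = ATTENTION_FUNCTIONS["sdpa"](q, k, v, attention_mask=None, scaling=8 ** -0.5)
print(attn_output.shape)  # torch.Size([1, 4, 6, 8])
```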