[ESM] support attention API by zucchini-nlp · Pull Request #40370 · huggingface/transformers · GitHub

Conversation

@zucchini-nlp
Member

What does this PR do?

Addresses #34954 and updates ESM to support the attention API and modeling outputs.
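
For context, the attention API referred to here is the per-backend dispatch used across transformers models; the rough shape of the pattern (illustrative only, not the exact ESM changes in this PR) is:

```python
# Rough sketch of the transformers attention-interface pattern that this PR adopts
# (illustrative only; the real changes live in the ESM modeling files of this PR).
import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # plain softmax(Q K^T * scaling) V fallback used when _attn_implementation == "eager"
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output.transpose(1, 2).contiguous(), attn_weights


def dispatch_attention(module, query, key, value, attention_mask, scaling, **kwargs):
    # pick the configured backend (eager / sdpa / flash_attention_2 / flex_attention)
    attention_interface = eager_attention_forward
    if module.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[module.config._attn_implementation]
    return attention_interface(module, query, key, value, attention_mask, scaling=scaling, **kwargs)
```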

@zucchini-nlp
Member Author

run-slow: esm

@zucchini-nlp
Member Author

run-slow: esm

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/esm']
quantizations: [] ...

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp
Member Author

run-slow: esm, evolla

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/esm', 'models/evolla']
quantizations: [] ...

Contributor

@vasqu vasqu left a comment


Some initial thoughts: it would be nice if we could align things with #38301, as we can then refactor more easily afterwards.

Comment on lines +724 to 742
```python
if self.config._attn_implementation != "flash_attention_2":
    batch_size, seq_length = inputs_embeds.shape[:-1]
    if attention_mask is None:
        attention_mask = torch.ones(((batch_size, seq_length)), device=inputs_embeds.device)
    else:
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape=(batch_size, seq_length)
        )

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
    encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
    encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
    if encoder_attention_mask is None:
        encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
    encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
    encoder_extended_attention_mask = None
```
Contributor


Can we align the mask creation with the approach in #38301?

It will make refactoring easier, and those mask creations are "more proven".

Member Author


Yeah, I also thought of it, but after seeing that it supports non-causal masks as well, I think we need a cleaner approach for that in the future: some kind of small function that decides which mask to construct and returns it.
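
(A hypothetical sketch of what such a helper could look like; the name `_create_attention_mask` and its exact signature are illustrative, not something this PR adds, and it reuses the `create_causal_mask` / `_update_full_mask` helpers mentioned below.)

```python
# Hypothetical helper that picks which mask to construct; name and signature are illustrative only.
from transformers.masking_utils import create_causal_mask


def _create_attention_mask(self, attention_mask, embedding_output, cache_position=None, past_key_values=None):
    if self.config.is_decoder:
        # decoder-style usage: causal 4D mask with padding folded in
        return create_causal_mask(
            config=self.config,
            input_embeds=embedding_output,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )
    # encoder-style usage: expand the 2D padding mask for the configured backend
    return self._update_full_mask(attention_mask, embedding_output)
```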

Contributor


I don't mean the attention mask interface yet, as that definitely needs an update for non-causal variants :D I'm thinking of:

```python
if attention_mask is None:
    # required mask seq length can be calculated via length of past cache
    mask_seq_length = past_key_values_length + seq_length
    attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
if self.config.is_decoder and encoder_hidden_states is not None and encoder_attention_mask is None:
    encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=device)

if attention_mask.dim() == 2:
    if self.config.is_decoder:
        attention_mask = create_causal_mask(
            config=self.config,
            input_embeds=embedding_output,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )
    else:
        attention_mask = self._update_full_mask(
            attention_mask,
            embedding_output,
        )
elif attention_mask.dim() == 3:
    if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]:
        raise ValueError(
            "Passing attention mask with a 3D/4D shape does not work with type "
            f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
        )
    attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

if encoder_attention_mask is not None:
    if encoder_attention_mask.dim() == 2:
        encoder_attention_mask = self._update_cross_attn_mask(
            encoder_hidden_states,
            encoder_attention_mask,
            embedding_output.shape[:2],
            embedding_output,
        )
    else:
        if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]:
            raise ValueError(
                "Passing attention mask with a 3D/4D shape does not work with type "
                f"{self.config._attn_implementation} - please use either `sdpa` or `eager` instead."
            )
        encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
```

(without the checks on the dims). It would make it easy to remove all those functions later on if we keep it consistent across models that have yet to get the mask interface. (And I have tested it quite thoroughly on all attention variations)
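
(For reference, the `_update_full_mask` helper used above, as it appears in models already refactored under #38301, looks roughly like the sketch below; treat it as orientation, not the verbatim implementation.)

```python
# Rough shape of the _update_full_mask helper in #38301-style models
# (sketch, not verbatim; the flex_attention branch is omitted here).
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_attention_mask_for_sdpa,
)


def _update_full_mask(self, attention_mask, inputs_embeds):
    if attention_mask is not None:
        if self.config._attn_implementation == "flash_attention_2":
            # FA2 consumes the 2D padding mask directly (or None when there is no padding)
            attention_mask = attention_mask if 0 in attention_mask else None
        elif self.config._attn_implementation == "sdpa":
            attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
        else:
            # eager: expand the 2D padding mask to a 4D additive mask
            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
    return attention_mask
```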

@zucchini-nlp
Member Author

run-slow: esm, evolla

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/esm', 'models/evolla']
quantizations: [] ...

Contributor

@vasqu vasqu left a comment


Oops, I commented on the modular-generated file, but it should be carried over to esm - my bad ^^'

@zucchini-nlp zucchini-nlp requested a review from vasqu August 26, 2025 13:24
Contributor

@vasqu vasqu left a comment


Just one small nit (order of relative scaling) and could you change the mask creation per https://github.com/huggingface/transformers/pull/40370/files#r2297912836

LGTM otherwise! cc @pstjohn re #40211

@zucchini-nlp
Member Author

run-slow: esm, evolla

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/esm', 'models/evolla']
quantizations: [] ...

@zucchini-nlp
Member Author

run-slow: esm, evolla

```python
self.scaling = 1.0  # For BC we apply scaling before RoPE
self.is_decoder = config.is_decoder
self.layer_idx = layer_idx
self.is_causal = self.is_decoder  # used only in FA2/FA3
```
Contributor


SDPA uses this as well :D This is probably incorrect in the case of cross-attention; can you add an optional argument here instead?

Member Author


Ah right, it also has cross-attention. I copied this from the current ESM attention, assuming it was working on the main branch.
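
(A sketch of what the suggested optional argument could look like; the `is_cross_attention` argument name is illustrative, not necessarily what the final diff uses.)

```python
import torch.nn as nn


class EsmAttentionSketch(nn.Module):
    # Illustrative only: an optional flag so cross-attention never ends up causal.
    def __init__(self, config, layer_idx=None, is_cross_attention=False):
        super().__init__()
        self.scaling = 1.0  # For BC we apply scaling before RoPE
        self.is_decoder = config.is_decoder
        self.layer_idx = layer_idx
        # causal only for decoder self-attention; cross-attention attends to the full encoder sequence
        self.is_causal = self.is_decoder and not is_cross_attention
```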

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: esm, evolla

@zucchini-nlp zucchini-nlp merged commit ed5dd29 into huggingface:main Aug 27, 2025
24 checks passed