FA2 can continue generation from cache #39843
Conversation
| if "flash" in self.config._attn_implementation and self._supports_attention_backend: | ||
| tensor_kws = {"dtype": torch.int32, "device": self.device} | ||
| pos = model_inputs["position_ids"][:, -1] | ||
|
|
||
| cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0) | ||
| max_length_k = int(pos.max()) + 1 | ||
|
|
||
| bs, seq_len = input_ids.size() | ||
| q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1) | ||
| cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0) | ||
| max_length_q = int(q_len.max()) | ||
|
|
||
| model_inputs.update( | ||
| cu_seq_lens_q=cu_seq_lens_q.to(self.device), | ||
| cu_seq_lens_k=cu_seq_lens_k.to(self.device), | ||
| max_length_q=max_length_q, | ||
| max_length_k=max_length_k, | ||
| ) |
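For reference, a minimal standalone sketch (plain PyTorch, illustrative cache sizes, not the PR's code) of what these tensors represent during a single decode step: `cu_seq_lens_*` are prefix sums of per-sequence lengths starting at 0, and `max_length_*` are the largest per-sequence lengths, which is the metadata FA2's varlen kernels consume.

```python
import torch

# Illustrative decode step: a batch of 3 sequences whose caches hold 5, 8 and 8 tokens,
# so position_ids[:, -1] would be [4, 7, 7].
pos = torch.tensor([4, 7, 7])

# Key side: each query attends to its full cache plus the new token -> lengths pos + 1.
k_lens = pos + 1
cu_seq_lens_k = torch.cat([torch.zeros(1, dtype=torch.long), k_lens.cumsum(0)])
max_length_k = int(k_lens.max())

# Query side: exactly one new token per sequence while decoding.
q_lens = torch.ones_like(pos)
cu_seq_lens_q = torch.cat([torch.zeros(1, dtype=torch.long), q_lens.cumsum(0)])
max_length_q = int(q_lens.max())

print(cu_seq_lens_k.tolist(), max_length_k)  # [0, 5, 13, 21] 8
print(cu_seq_lens_q.tolist(), max_length_q)  # [0, 1, 2, 3] 1
```

(The diff above casts these to int32 before handing them to the attention kernels.)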
No, I don't think it makes sense to never generate them here. It is a lot more efficient to do this once than to do it at every layer, no?
I didn't time the runtime; I guess the difference won't be big enough to justify duplicating code. I will check it.
IMO not bloating generate() is also important if we can do non-generation processing in other places.
The wins aren't big when we precompute: roughly 60 ms with 256 new tokens generated.
I am surprised, because pad/unpad of the input is costly, and with more batches I'm pretty sure it would deteriorate. It also does not make sense not to have generate create them, IMO.
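To make the trade-off in this thread concrete, here is a rough sketch (hypothetical helper name, not code from this PR) of what gets saved by precomputing: the same metadata would otherwise have to be rederived from `position_ids` inside every attention layer on every decoding step.

```python
import torch

def build_varlen_metadata(position_ids: torch.Tensor):
    """Hypothetical helper: derive FA2 varlen metadata from 2D position_ids."""
    seq_lens = position_ids[:, -1].to(torch.int32) + 1
    cu_seq_lens = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0, dtype=torch.int32)])
    return cu_seq_lens, int(seq_lens.max())

# Option discussed here: call this once per generation step in generate's input
# preparation and pass the result down through model_inputs.
#
# Alternative: call it inside every attention layer's forward; for a model with
# N layers that is N small recomputations per generated token instead of one.
```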
Hi @zucchini-nlp! Thanks for your effort on this issue! I originally opened this issue (#39814) because of a problem we have in NVIDIA/KVPress when continuing generation after a few forward passes. This is similar to the test you introduced in this PR here. I tried to run the test locally, with this PR installed, on Llama3.2-1B with FA2 enabled, and it failed. I'm copy-pasting both the code (it should be equivalent to the test, but just for Llama) and the output:

```python
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


def test_flash_attention_2_continue_generate_with_position_ids():
    """
    Tests that the given attention implementation can work with packed sequences and infers the mask
    from position ids. This test requires the model to use new attention mask API which handles packing.
    """
    max_new_tokens = 2
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B-Instruct", device_map="auto", torch_dtype=torch.bfloat16
    ).eval().cuda()
    model.set_attn_implementation("flash_attention_2")

    inputs_dict = tokenizer("Hello, how are you?", return_tensors="pt")

    # make sure that all models have enough positions for generation
    dummy_input_ids = inputs_dict["input_ids"]
    if hasattr(model.config, "max_position_embeddings"):
        model.config.max_position_embeddings = max_new_tokens + dummy_input_ids.shape[1] + 1

    dummy_input_ids = inputs_dict["input_ids"]
    dummy_position_ids = torch.arange(dummy_input_ids.shape[1], device="cuda")
    dummy_position_ids = dummy_position_ids.unsqueeze(0).repeat(dummy_input_ids.shape[0], 1)

    # Store cache for the input prompt
    output = model(dummy_input_ids.cuda(), position_ids=dummy_position_ids.cuda(), use_cache=True)

    # create new input_ids and position_ids to continue generation re-using the cache
    new_input_ids = output.logits[:, -1, :].float().argmax(-1)[:, None]
    past_length = dummy_input_ids.shape[1]
    position_ids = torch.arange(past_length, past_length + new_input_ids.shape[1], device="cuda")
    position_ids = position_ids.unsqueeze(0).repeat(new_input_ids.shape[0], 1)

    output = model(
        input_ids=new_input_ids,
        past_key_values=output.past_key_values,
        position_ids=position_ids,
        use_cache=True,
    )
    next_token_logits = output.logits[:, -1, :].float()

    generate_kwargs = {
        "pad_token_id": -1,
        "eos_token_id": -1,
        "forced_eos_token_id": None,
        "use_cache": True,
        "do_sample": False,
        "return_dict_in_generate": True,
        "output_logits": True,
        "max_new_tokens": max_new_tokens,
    }
    generation_out = model.generate(dummy_input_ids.cuda(), **generate_kwargs)
    next_token_logits_from_generate = generation_out.logits[-1]

    # acceptable numerical instability
    # print(next_token_logits_from_generate, next_token_logits)
    tol = torch.finfo(torch.bfloat16).eps
    torch.testing.assert_close(next_token_logits_from_generate, next_token_logits, rtol=tol, atol=tol)


test_flash_attention_2_continue_generate_with_position_ids()
```

This fails sometimes due to differences, other times due to nans or CUDA errors (non-deterministically). Do you have any idea why this is happening?
I am also testing whether manual generation is equivalent to using `model.generate()`:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


def test_generate_are_equivalent():
    """
    Tests that the given attention implementation can work with packed sequences and infers the mask
    from position ids. This test requires the model to use new attention mask API which handles packing.
    """
    max_new_tokens = 20
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B", device_map="auto", torch_dtype=torch.bfloat16
    ).eval().cuda()
    model.set_attn_implementation("flash_attention_2")

    inputs_dict = tokenizer("Hello, how are you?", return_tensors="pt")

    # Manual generation using for loop
    print("Manual generation:")
    position_ids = torch.arange(inputs_dict["input_ids"].shape[1], device="cuda")
    position_ids = position_ids.unsqueeze(0).repeat(inputs_dict["input_ids"].shape[0], 1)

    # Store cache for the input prompt
    output = model(inputs_dict["input_ids"].cuda(), position_ids=position_ids.cuda(), use_cache=True)
    past_key_values = output.past_key_values
    manual_tokens = inputs_dict["input_ids"].clone()

    # Generate tokens one by one
    for i in range(max_new_tokens):
        # Get next token
        new_input_ids = output.logits[:, -1, :].float().argmax(-1)[:, None]
        manual_tokens = torch.cat([manual_tokens, new_input_ids.cpu()], dim=1)

        # Print the newly generated token
        new_token_text = tokenizer.decode(new_input_ids[0].cpu().item())
        print(f"Token {i+1}: '{new_token_text}'")

        # Prepare for next iteration
        past_length = manual_tokens.shape[1] - 1
        position_ids = torch.arange(past_length, past_length + 1, device="cuda")
        position_ids = position_ids.unsqueeze(0).repeat(new_input_ids.shape[0], 1)

        output = model(
            input_ids=new_input_ids,
            past_key_values=past_key_values,
            position_ids=position_ids,
            use_cache=True,
        )
        past_key_values = output.past_key_values

    print("\nManual generation complete sequence:")
    print(f"'{tokenizer.decode(manual_tokens[0], skip_special_tokens=True)}'")

    # Generate using model.generate()
    print("\nGenerate method:")
    generate_kwargs = {
        "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1,
        "eos_token_id": -1,
        "forced_eos_token_id": None,
        "use_cache": True,
        "do_sample": False,
        "return_dict_in_generate": True,
        "output_logits": True,
        "max_new_tokens": max_new_tokens,
    }
    generation_out = model.generate(inputs_dict["input_ids"].cuda(), **generate_kwargs)

    # Print tokens from generate method
    generated_tokens = generation_out.sequences[0]
    original_length = inputs_dict["input_ids"].shape[1]
    for i in range(max_new_tokens):
        token_id = generated_tokens[original_length + i].cpu().item()
        token_text = tokenizer.decode(token_id)
        print(f"Token {i+1}: '{token_text}'")

    print("\nGenerate method complete sequence:")
    print(f"'{tokenizer.decode(generated_tokens, skip_special_tokens=True)}'")


test_generate_are_equivalent()
```
Thanks for sharing, taking a look.
@alessiodevoto can you make sure you are checked out on the current PR branch and that it is installed? I ran the scripts you provided and both pass for me on the branch.
Let's keep the minimal cu_seqlens q/k in generate; it helps debugging and should be overall more efficient.
```python
# cumulative seq lengths.
if query_length != kv_length:
    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
```
It is nice to have the if/else here, taking manual generation into account!
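For readers following the thread: the point of the branch is that generate() can hand precomputed cu_seqlens down, while a manual forward pass against a cache (where query_length != kv_length) has nothing precomputed, so the metadata must be rebuilt from position_ids. A rough sketch of that fallback pattern, with hypothetical names rather than the PR's actual helper, assuming 2D position_ids:

```python
import torch
from typing import Optional, Tuple

def resolve_cu_seqlens(
    position_ids: torch.Tensor,                    # (batch, query_length) positions of the query tokens
    cu_seq_lens_q: Optional[torch.Tensor] = None,  # precomputed by generate(), when available
    cu_seq_lens_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical fallback: reuse metadata handed down by generate(), otherwise
    rebuild it from position_ids (the manual-generation case)."""
    if cu_seq_lens_q is not None and cu_seq_lens_k is not None:
        return cu_seq_lens_q, cu_seq_lens_k

    bs, q_len = position_ids.shape
    q_lens = torch.full((bs,), q_len, dtype=torch.int32)
    k_lens = position_ids[:, -1].to(torch.int32) + 1  # keys span the cache plus the new tokens
    cu_q = torch.cat([q_lens.new_zeros(1), q_lens.cumsum(0, dtype=torch.int32)])
    cu_k = torch.cat([k_lens.new_zeros(1), k_lens.cumsum(0, dtype=torch.int32)])
    return cu_q, cu_k
```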
What does this PR do?
Fixes #39814
Don't merge yet: one of the models fails the test with an unknown CUDA-side error and messes up all subsequent tests. Trying to find out which model that is.