FA2 can continue generation from cache by zucchini-nlp · Pull Request #39843 · huggingface/transformers

Conversation

@zucchini-nlp zucchini-nlp commented Aug 1, 2025

What does this PR do?

Fixes #39814

Don't merge yet: one of the models fails the test with an unknown CUDA-side error and messes up all subsequent tests. Trying to find out which model that is.
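
For context, a minimal sketch of the scenario this PR targets: build a KV cache with a plain forward pass, then continue generation from it with FA2 enabled. The checkpoint name and prompt are placeholders and the snippet assumes the prefix-caching pattern accepted by generate(); it is not this PR's test code.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B-Instruct"  # placeholder: any FA2-capable causal LM
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
).cuda().eval()

# Prefill: run a forward pass over a shared prefix to populate the KV cache.
prefix = "You are a helpful assistant."
prefix_inputs = tokenizer(prefix, return_tensors="pt").to("cuda")
with torch.no_grad():
    cache = model(**prefix_inputs, use_cache=True).past_key_values

# Continue generation from that cache: only the new tokens should be processed.
full_inputs = tokenizer(prefix + " Hello, how are you?", return_tensors="pt").to("cuda")
out = model.generate(**full_inputs, past_key_values=cache, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))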

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines -679 to -696
if "flash" in self.config._attn_implementation and self._supports_attention_backend:
tensor_kws = {"dtype": torch.int32, "device": self.device}
pos = model_inputs["position_ids"][:, -1]

cu_seq_lens_k = torch.cat([torch.zeros(1, **tensor_kws), pos.cumsum(0).add(1)], 0)
max_length_k = int(pos.max()) + 1

bs, seq_len = input_ids.size()
q_len = torch.ones(bs, **tensor_kws) if seq_len == 1 else pos.to(torch.int32).add(1)
cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0)], 0)
max_length_q = int(q_len.max())

model_inputs.update(
cu_seq_lens_q=cu_seq_lens_q.to(self.device),
cu_seq_lens_k=cu_seq_lens_k.to(self.device),
max_length_q=max_length_q,
max_length_k=max_length_k,
)
Collaborator

No, I don't think it makes sense to never generate them here. It is a lot more efficient to do this once than to do it at every layer, no?

Member Author

I didn't time it, but I guess the difference won't be big enough to justify duplicating code. I will check it.

IMO not bloating generate() is also important if we can do non-generation processing in other places.

Member Author

The wins from precomputing aren't big: around 60 ms with 256 new tokens generated.

Collaborator

I am surprised, because pad/unpad of the input is costly, and with larger batches I'm pretty sure it would deteriorate. It also does not make sense IMO not to have generate create them.
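
For readers following this thread, a rough sketch of the quantities being debated, derived from position_ids alone. Variable names mirror the removed hunk above; this is only an illustration of the computation, not the code the PR ships.

import torch

def build_cu_seq_lens(position_ids: torch.Tensor, query_length: int):
    # position_ids: (batch, query_length), absolute positions of this step's tokens.
    kws = {"dtype": torch.int32, "device": position_ids.device}
    batch_size = position_ids.size(0)
    last_pos = position_ids[:, -1]

    # Key side: with a cache, each sequence attends to positions [0, last_pos],
    # i.e. last_pos + 1 keys; cu_seq_lens_k is the running total over the batch.
    k_lens = last_pos.to(torch.int32) + 1
    cu_seq_lens_k = torch.cat([torch.zeros(1, **kws), k_lens.cumsum(0, dtype=torch.int32)], dim=0)
    max_length_k = int(k_lens.max())

    # Query side: exactly one token per sequence while decoding,
    # otherwise everything up to and including the last position (prefill).
    if query_length == 1:
        q_lens = torch.ones(batch_size, **kws)
    else:
        q_lens = last_pos.to(torch.int32) + 1
    cu_seq_lens_q = torch.cat([torch.zeros(1, **kws), q_lens.cumsum(0, dtype=torch.int32)], dim=0)
    max_length_q = int(q_lens.max())

    return cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k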

@zucchini-nlp zucchini-nlp enabled auto-merge (squash) August 7, 2025 10:01
@alessiodevoto

Hi @zucchini-nlp! Thanks for your effort on this issue! I originally opened this issue (#39814) because of a problem we have in NVIDIA/KVPress when continuing generation after a few forward passes. This is similar to the test you introduced in this PR here.

I tried to run the test locally with this PR installed, on Llama-3.2-1B with FA2 enabled, and it failed. I'm copy-pasting both the code (it should be equivalent to the test, but just for Llama) and the output of transformers env.

import torch
import os 

from transformers import AutoTokenizer, AutoModelForCausalLM


os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

def test_flash_attention_2_continue_generate_with_position_ids():
    """
    Tests that the given attention implementation can work with packed sequences and infers the mask
    from position ids. This test requires the model to use new attention mask API which handles packing.
    """

    max_new_tokens = 2
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", device_map="auto", torch_dtype=torch.bfloat16).eval().cuda()
    model.set_attn_implementation("flash_attention_2")
    inputs_dict = tokenizer("Hello, how are you?", return_tensors="pt")

    # make sure that all models have enough positions for generation
    dummy_input_ids = inputs_dict["input_ids"]
    if hasattr(model.config, "max_position_embeddings"):
        model.config.max_position_embeddings = max_new_tokens + dummy_input_ids.shape[1] + 1

    dummy_position_ids = torch.arange(dummy_input_ids.shape[1], device="cuda")
    dummy_position_ids = dummy_position_ids.unsqueeze(0).repeat(dummy_input_ids.shape[0], 1)

    # Store cache for the input prompt
    output = model(dummy_input_ids.cuda(), position_ids=dummy_position_ids.cuda(), use_cache=True)


    # create new input_ids and position_ids to continue generation re-using the cache
    new_input_ids = output.logits[:, -1, :].float().argmax(-1)[:, None]
    past_length = dummy_input_ids.shape[1]
    position_ids = torch.arange(past_length, past_length + new_input_ids.shape[1], device="cuda")
    position_ids = position_ids.unsqueeze(0).repeat(new_input_ids.shape[0], 1)

    output = model(
        input_ids=new_input_ids,
        past_key_values=output.past_key_values,
        position_ids=position_ids,
        use_cache=True,
    )
    next_token_logits = output.logits[:, -1, :].float()


    generate_kwargs = {
        "pad_token_id": -1,
        "eos_token_id": -1,
        "forced_eos_token_id": None,
        "use_cache": True,
        "do_sample": False,
        "return_dict_in_generate": True,
        "output_logits": True,
        "max_new_tokens": max_new_tokens,
    }
    generation_out = model.generate(dummy_input_ids.cuda(), **generate_kwargs)
    next_token_logits_from_generate = generation_out.logits[-1]



    # acceptable numerical instability
    # print(next_token_logits_from_generate, next_token_logits)
    tol = torch.finfo(torch.bfloat16).eps
    torch.testing.assert_close(next_token_logits_from_generate, next_token_logits, rtol=tol, atol=tol)


test_flash_attention_2_continue_generate_with_position_ids()

This fails non-deterministically: sometimes due to numerical differences, other times due to NaNs or CUDA errors.

Mismatched elements: 127618 / 128256 (99.5%)
Greatest absolute difference: 12.53125 at index (0, 6620) (up to 0.0078125 allowed)
Greatest relative difference: 3075824.0 at index (0, 97390) (up to 0.0078125 allowed)

or

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

transformers env

- `transformers` version: 4.55.0.dev0
- Platform: Linux-6.1.123+-x86_64-with-glibc2.39
- Python version: 3.12.3
- Huggingface_hub version: 0.34.3
- Safetensors version: 0.5.3
- Accelerate version: 1.9.0
- Accelerate config:    not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.7.1+cu126 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
- Using GPU in script?: <fill in>
- GPU type: NVIDIA H100 80GB HBM3

Do you have any idea why this is happening?

@alessiodevoto

I am also testing whether manual generation is equivalent to using .generate(). When I set attn_implementation="eager", they are equivalent, but with FA2 I again get a CUDA error. Here is the code:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

def test_generate_are_equivalent():
    """
    Tests that the given attention implementation can work with packed sequences and infers the mask
    from position ids. This test requires the model to use new attention mask API which handles packing.
    """

    max_new_tokens = 20
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", device_map="auto", torch_dtype=torch.bfloat16).eval().cuda()
    model.set_attn_implementation("flash_attention_2")
    inputs_dict = tokenizer("Hello, how are you?", return_tensors="pt")

    # Manual generation using for loop
    print("Manual generation:")
    position_ids = torch.arange(inputs_dict["input_ids"].shape[1], device="cuda")
    position_ids = position_ids.unsqueeze(0).repeat(inputs_dict["input_ids"].shape[0], 1)

    # Store cache for the input prompt
    output = model(inputs_dict["input_ids"].cuda(), position_ids=position_ids.cuda(), use_cache=True)
    past_key_values = output.past_key_values
    manual_tokens = inputs_dict["input_ids"].clone()

    # Generate tokens one by one
    for i in range(max_new_tokens):
        # Get next token
        new_input_ids = output.logits[:, -1, :].float().argmax(-1)[:, None]
        manual_tokens = torch.cat([manual_tokens, new_input_ids.cpu()], dim=1)

        # Print the newly generated token
        new_token_text = tokenizer.decode(new_input_ids[0].cpu().item())
        print(f"Token {i+1}: '{new_token_text}'")

        # Prepare for next iteration
        past_length = manual_tokens.shape[1] - 1
        position_ids = torch.arange(past_length, past_length + 1, device="cuda")
        position_ids = position_ids.unsqueeze(0).repeat(new_input_ids.shape[0], 1)

        output = model(
            input_ids=new_input_ids,
            past_key_values=past_key_values,
            position_ids=position_ids,
            use_cache=True,
        )
        past_key_values = output.past_key_values

    print("\nManual generation complete sequence:")
    print(f"'{tokenizer.decode(manual_tokens[0], skip_special_tokens=True)}'")

    # Generate using model.generate()
    print("\nGenerate method:")
    generate_kwargs = {
        "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1,
        "eos_token_id": -1,
        "forced_eos_token_id": None,
        "use_cache": True,
        "do_sample": False,
        "return_dict_in_generate": True,
        "output_logits": True,
        "max_new_tokens": max_new_tokens,
    }
    generation_out = model.generate(inputs_dict["input_ids"].cuda(), **generate_kwargs)

    # Print tokens from generate method
    generated_tokens = generation_out.sequences[0]
    original_length = inputs_dict["input_ids"].shape[1]

    for i in range(max_new_tokens):
        token_id = generated_tokens[original_length + i].cpu().item()
        token_text = tokenizer.decode(token_id)
        print(f"Token {i+1}: '{token_text}'")

    print("\nGenerate method complete sequence:")
    print(f"'{tokenizer.decode(generated_tokens, skip_special_tokens=True)}'")


test_generate_are_equivalent()

@zucchini-nlp
Member Author

Thanks for sharing, taking a look

@zucchini-nlp
Member Author

@alessiodevoto can you make sure you have checked out the current PR branch and that it is installed? I ran both scripts you provided and they pass for me on the branch.
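
One quick way to rule out a stale install shadowing the PR branch (a sanity check, unrelated to the fix itself) is to print which build actually gets imported:

import transformers

# The version should be the expected .dev0 build and the path should point at the
# editable checkout of the PR branch, not an older site-packages install.
print(transformers.__version__)
print(transformers.__file__)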

Collaborator

@ArthurZucker ArthurZucker left a comment

Let's keep the minimal cu_seq_lens q/k in generate; it helps debugging and should be more efficient overall.

Comment on lines +229 to +231
# cumulative seq lengths.
if query_length != kv_length:
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
Collaborator

It is nice to have the if/else here, taking manual generation into account!
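
A schematic reading of the distinction this branch handles (hypothetical helper and demo values, not the actual transformers code path): during prefill the query and key lengths match and every token is a query, while when continuing from a cache the query length is 1 (or a small chunk) but the key length includes the cached tokens, so there is only one query row per batch entry.

import torch

def select_query_indices(position_ids: torch.Tensor, query_length: int, kv_length: int):
    # Hypothetical sketch of the decode-vs-prefill distinction discussed above.
    if query_length != kv_length:
        # Continuing from cache: one (or few) new query tokens per sequence,
        # so the flattened query indices are simply one entry per batch row.
        indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
    else:
        # Prefill: every token is a query token; flatten (batch, seq_len) row-major.
        indices_q = torch.arange(position_ids.numel(), device=position_ids.device, dtype=torch.int32)
    return indices_q

# Decode step: batch of 2, one new token each, cache of length 5.
pos = torch.tensor([[5], [5]])
print(select_query_indices(pos, query_length=1, kv_length=6))   # indices 0..1

# Prefill: batch of 2, 6 tokens each, no cache.
pos = torch.arange(6).repeat(2, 1)
print(select_query_indices(pos, query_length=6, kv_length=6))   # indices 0..11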

@zucchini-nlp zucchini-nlp merged commit 7892257 into huggingface:main Aug 7, 2025
24 checks passed
@vasqu vasqu mentioned this pull request Aug 25, 2025

Successfully merging this pull request may close these issues:

Flash Attention fails with non aligned position_ids (#39814)