Allocate more shared memory to attention kernel by Yard1 · Pull Request #1154 · vllm-project/vllm

Conversation

@Yard1 (Collaborator) commented Sep 23, 2023

Makes use of the additional shared memory present on compute capability >=7.0 cards to support longer context lengths in the attention kernel.

See https://stackoverflow.com/questions/63757245/using-maximum-shared-memory-in-cuda for details.

As pointed out by @WoosukKwon offline, ideally we would also store the logits inside the kernel in float16 instead of float32, as the accuracy loss should be minimal. This would enable even longer context lengths.

Note that the buffer of 512 * sizeof(float32) may be too conservative, but this still results in more supported tokens than the previous ~11k limit. The attention test has been run successfully on A10 and A100.

With this PR, the supported context lengths with the current kernel (float32 logits) will be:

  • CC 7.5 (Turing): 64KiB shared memory -> 16328 tokens
  • CC 7.0 (Volta): 96KiB shared memory -> 24984 tokens
  • CC 8.6 (Ampere A10): 100KiB shared memory -> 25128 tokens
  • CC 8.0 (Ampere A100): 160KiB shared memory -> 39936 tokens
  • CC 9.0 (Hopper): 227KiB shared memory -> 57600 tokens

Closes #905

@Yard1 (Collaborator, Author) commented Sep 23, 2023

cc @WoosukKwon @LiuXiaoxuanPKU

@Yard1 (Collaborator, Author) commented Sep 23, 2023

We could also add a Python assertion for this? Edit: added.

@Yard1 Yard1 marked this pull request as draft September 23, 2023 02:09
@Yard1 Yard1 marked this pull request as ready for review September 23, 2023 02:31
@esmeetu (Member) commented Sep 23, 2023

Hi @Yard1, I have a question here: if I use dtype=float16 for model inference, will changing the logits buffer from float32 to float16 to support longer context affect accuracy?

@Yard1 (Collaborator, Author) commented Sep 23, 2023

I am not sure, @WoosukKwon would know best.

@WoosukKwon WoosukKwon self-requested a review September 25, 2023 23:36
@WoosukKwon (Collaborator) left a comment

@Yard1 Thanks for the quick fix! I was a bit worried about performance since we now manually adjust the shared memory size, but it seems performance is not affected by the fix. 👍

Left some questions and comments. Please take a look.

MAX_SEQ_LEN = 8192
float_bytes = torch.finfo(torch.float).bits / 8
# This will change depending on the compute capability.
# -7 as it will be padded to 8 anyway, -512 as a buffer
Collaborator

Could you elaborate more on this?

-7 as it will be padded to 8 anyway

Collaborator

A quick question: How did you choose 512 for the buffer size?

@Yard1 (Collaborator, Author) Sep 26, 2023

I think I got (64 + 256) * sizeof(float32) for other __shared__ variables by reading the CUDA kernel, so I just rounded it up to 512 * sizeof(float32) to be safe. But it may be too conservative.

Collaborator Author

@WoosukKwon I would appreciate it if you could provide a more accurate measurement :)

Collaborator

@Yard1 Do you mean

__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];

and

__shared__ float red_smem[2 * NUM_WARPS];

?

@Yard1 (Collaborator, Author) Sep 27, 2023

Yes, they are included in the shared memory usage; we should set the buffer to an upper bound of their combined size.

Collaborator

If my calculation is correct, the size of q_vecs is head_size * sizeof(scalar_t) <= 256 * 4 = 1024 bytes. The size of red_smem is obviously 64 * 4 = 256 bytes. In total, it's 1280 bytes (= 320 float elements). So 512 is actually a somewhat conservative upper bound. However, I think this is acceptable.
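
(For reference, a minimal sketch of that arithmetic, mirroring the numbers quoted in this thread rather than the actual kernel launcher; the worst-case head_size of 256 and 4-byte elements are assumptions taken from the comment above:)

float32_bytes = 4
q_vecs_bytes = 256 * float32_bytes         # head_size * sizeof(scalar_t), worst case
red_smem_bytes = 64 * float32_bytes        # 2 * NUM_WARPS floats
total_bytes = q_vecs_bytes + red_smem_bytes
assert total_bytes == 1280                 # = 320 float elements
assert 512 * float32_bytes >= total_bytes  # the 512-float buffer comfortably covers this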

vllm/utils.py Outdated
# Follows the logic in
# attention_kernels.cu::single_query_cached_kv_attention_launcher
max_shared_mem = get_max_shared_mem_bytes()
float32_bytes = torch.finfo(torch.float).bits // 8
Collaborator

Isn't this always 4?

Collaborator Author

It technically should be, but writing it this way ensures it's always true irrespective of the platform/implementation, and it is also self-documenting.

Collaborator

To my knowledge, the size of float is defined by an IEEE standard and is independent of the underlying machine architecture (unlike integer types). That being said, I like that this is self-documenting. Let's keep it!
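
(A quick, self-contained check of what torch.finfo reports, added here purely for illustration:)

import torch
assert torch.finfo(torch.float).bits // 8 == 4    # float32 (IEEE 754 binary32) is always 4 bytes
assert torch.finfo(torch.float16).bits // 8 == 2  # float16 logits would halve the per-token cost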

@Yard1 Yard1 requested a review from WoosukKwon September 26, 2023 21:49
@WoosukKwon (Collaborator) left a comment

@Yard1 LGTM! Thanks again for the PR! I left some very minor style comments; please fix them before merging.

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@esmeetu (Member) commented Sep 27, 2023

@Yard1 Great! I tested a long prompt using this PR, and it doesn't crash anymore.
Also, I think _check_if_can_support_max_seq_len should check the minimum of (max_num_batched_tokens, max_seq_len); otherwise the check always fails on my GPU.
Example:

max_seq_len:16384,
block_size:16,
max_shared_mem:65536,
float32_bytes:4,
padded_max_seq_len:16399.0, required_shared_mem:67644.0
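
(The values above are consistent with a check of roughly the following shape; this is a reconstruction from the printed numbers, not the exact vllm/utils.py code, and the variable names are taken from the output:)

block_size = 16
max_seq_len = 16384
max_shared_mem = 65536                   # 64 KiB on this GPU
float32_bytes = 4
padded_max_seq_len = ((max_seq_len + block_size - 1) / block_size) * block_size  # 16399.0
required_shared_mem = (padded_max_seq_len + 512) * float32_bytes                 # 67644.0
assert required_shared_mem > max_shared_mem  # 67644 > 65536, so the check fails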

@WoosukKwon (Collaborator)

Hi @esmeetu, thanks for reporting the issue. I think that's related to how max_num_batched_tokens is set, and thus it can be handled in a separate PR.

@WoosukKwon WoosukKwon merged commit cf5cb1e into vllm-project:main Sep 27, 2023
@esmeetu (Member) commented Sep 27, 2023

@WoosukKwon I didn't get what you mean about how to set that parameter. Isn't it set by the scheduler config? 🤔

Development

Successfully merging this pull request may close these issues.

vLLM doesn't support context length exceeding about 13k

3 participants