ABI stable fa3 #1791
Conversation
header only ScalarType is landing! pytorch/pytorch#159416
Left some minor nits --> I'm guessing the commented out stuff will be removed?
That's amazing! Thank you for working on this :)
I have a few comments, in the name of making the life of extension maintainers easier, as this will be important to convince libraries (with many many kernels sometimes) to make the switch to stable.
A lot of these changes are essentially just renaming stuff. Would it be possible to have a namespace that would mirror as much as possible the existing class hierarchy?
For example, as a library maintainer I would like to just set a compile flag like -DTORCH_LIMITED_API=0x02090000 and have it hide all non-stable APIs.
Then maybe I can just write using namespace torch::stable;, and this namespace would contain torch::stable::at::ScalarType, torch::stable::at::Tensor, torch::stable::at::cuda::DeviceGuard, etc.
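A rough sketch of what that mirrored namespace could look like (this is not an existing PyTorch API; the header paths and alias targets below are assumptions for illustration only):

```cpp
// Hypothetical sketch: an umbrella namespace that re-exports the stable types
// under at::-style names so existing extension code keeps its familiar spellings.
// Header paths and type locations are assumptions, not a committed API.
#include <torch/csrc/stable/tensor.h>          // assumed: torch::stable::Tensor
#include <torch/headeronly/core/ScalarType.h>  // assumed: torch::headeronly::ScalarType

namespace torch::stable::at {
using Tensor = torch::stable::Tensor;
using ScalarType = torch::headeronly::ScalarType;
}  // namespace torch::stable::at

// A library could then write:
//   using namespace torch::stable;
//   at::Tensor out = ...;
//   at::ScalarType dtype = at::ScalarType::Half;
```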
hopper/flash_api.cpp
Outdated
```cpp
void boxed_mha_fwd_get_scheduler_metadata(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs
) {
  auto batch_size = to<int64_t>(stack[0]);
  auto max_seqlen_q = to<int64_t>(stack[1]);
  auto max_seqlen_k = to<int64_t>(stack[2]);
  auto num_heads = to<int64_t>(stack[3]);
  auto num_heads_k = to<int64_t>(stack[4]);
  auto headdim = to<int64_t>(stack[5]);
  auto headdim_v = to<int64_t>(stack[6]);
  auto qkv_dtype = to<torch::headeronly::ScalarType>(stack[7]);
  auto seqused_k = to<Tensor>(stack[8]);
  auto cu_seqlens_q = to<std::optional<Tensor>>(stack[9]);
  auto cu_seqlens_k = to<std::optional<Tensor>>(stack[10]);
  auto cu_seqlens_k_new = to<std::optional<Tensor>>(stack[11]);
  auto seqused_q = to<std::optional<Tensor>>(stack[12]);
  auto leftpad_k = to<std::optional<Tensor>>(stack[13]);
  auto page_size = to<std::optional<int64_t>>(stack[14]);
  auto max_seqlen_k_new = to<int64_t>(stack[15]);
  auto is_causal = to<bool>(stack[16]);
  auto window_size_left = to<int64_t>(stack[17]);
  auto window_size_right = to<int64_t>(stack[18]);
  auto attention_chunk = to<int64_t>(stack[19]);
  auto has_softcap = to<bool>(stack[20]);
  auto num_splits = to<int64_t>(stack[21]);
  auto pack_gqa = to<std::optional<bool>>(stack[22]);
  auto sm_margin = to<int64_t>(stack[23]);

  auto scheduler_metadata = mha_fwd_get_scheduler_metadata(
      batch_size, max_seqlen_q, max_seqlen_k, num_heads, num_heads_k,
      headdim, headdim_v, qkv_dtype, seqused_k, cu_seqlens_q, cu_seqlens_k,
      cu_seqlens_k_new, seqused_q, leftpad_k, page_size, max_seqlen_k_new,
      is_causal, window_size_left, window_size_right, attention_chunk,
      has_softcap, num_splits, pack_gqa, sm_margin);

  stack[0] = from(scheduler_metadata);
}
```
Do you think there is a way to automate this? It could be quite error prone.
Probably we can do something with C++ template magic?
We do want to eventually automate this and hide it from the user completely! We ultimately want to support this in the dispatcher which requires a bit of a lift as we support more IValues. Maybe it could be worth adding an intermediary template solution in the meantime.
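For what it's worth, a minimal sketch of such an intermediary template solution, assuming the same StableIValue, to<T>(), from(), and STD_TORCH_CHECK helpers used in the diff above (the wrapper names here are hypothetical, not an existing PyTorch API):

```cpp
// Hypothetical sketch: derive the boxed wrapper from an unboxed kernel's
// signature so the per-argument stack unpacking is generated by the compiler.
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>

// Call fn with arguments converted from stack[0..sizeof...(Args)-1].
template <typename Ret, typename... Args, std::size_t... I>
Ret call_unboxed_impl(Ret (*fn)(Args...), StableIValue* stack,
                      std::index_sequence<I...>) {
  return fn(to<std::decay_t<Args>>(stack[I])...);
}

// Assumes a single, non-void return value, written back into stack[0].
template <typename Ret, typename... Args>
void call_boxed(Ret (*fn)(Args...), StableIValue* stack, uint64_t num_args) {
  STD_TORCH_CHECK(num_args == sizeof...(Args), "argument count mismatch");
  Ret result = call_unboxed_impl(fn, stack, std::index_sequence_for<Args...>{});
  stack[0] = from(result);
}

// The hand-written wrapper above could then shrink to:
void boxed_mha_fwd_get_scheduler_metadata(
    StableIValue* stack, uint64_t num_args, uint64_t /*num_outputs*/) {
  call_boxed(&mha_fwd_get_scheduler_metadata, stack, num_args);
}
```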
hopper/flash_api.cpp
Outdated
```diff
-TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
+STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
```
Could we avoid having to make this kind of change? Like, could we just replace TORCH_CHECK's definition with STD_TORCH_CHECK's definition, maybe when a given compile flag is set?
We haven't yet explored having a compile time flag control what gets exposed, and that will be a part of our consideration when we figure out how to set up ABI target versions and such :D
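As a stopgap, a shim along these lines is conceivable on the extension's own side; this is a sketch under the assumption that only stable headers are included when the flag is set, not something PyTorch ships today:

```cpp
// Hypothetical shim kept in the extension's own headers: route TORCH_CHECK onto
// STD_TORCH_CHECK when building against the stable ABI only. The flag name
// mirrors the -DTORCH_STABLE_ONLY define used in the setup.py diff below.
#ifdef TORCH_STABLE_ONLY
  #ifdef TORCH_CHECK
    #undef TORCH_CHECK
  #endif
  #define TORCH_CHECK(...) STD_TORCH_CHECK(__VA_ARGS__)
#endif
```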
hopper/flash_api.cpp
Outdated
```cpp
std::deque<std::once_flag> device_flags;
std::vector<cudaDeviceProp> device_properties;

void initVectors() {
  static bool init_flag [[maybe_unused]] = []() {
    int device_count;
    cudaError_t err = cudaGetDeviceCount(&device_count);
    if (err != cudaSuccess) {
      STD_TORCH_CHECK(false, "cudaGetDeviceCount failed: " +
                      std::string(cudaGetErrorString(err)));
    }
    device_flags.resize(device_count);
    device_properties.resize(device_count);
    return true;
  }();
}

void initDeviceProperty(int device_index) {
  cudaDeviceProp device_prop{};
  cudaError_t err = cudaGetDeviceProperties(&device_prop, device_index);
  if (err != cudaSuccess) {
    STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " +
                    std::string(cudaGetErrorString(err)));
  }
  device_properties[device_index] = device_prop;
}

// Helper function to get device properties using raw CUDA APIs
cudaDeviceProp* get_device_prop() {
  initVectors();
  int device_index;
  cudaError_t err = cudaGetDevice(&device_index);
  if (err != cudaSuccess) {
    STD_TORCH_CHECK(false, "cudaGetDevice failed: " +
                    std::string(cudaGetErrorString(err)));
  }

  std::call_once(device_flags[device_index], initDeviceProperty, device_index);
  return &device_properties[device_index];
}
```
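For illustration, a call site consuming this helper might look like the following; the wrapper function here is hypothetical and only touches standard cudaDeviceProp fields from the CUDA runtime:

```cpp
// Illustrative usage of the get_device_prop() helper above; this exact
// function is not part of the PR.
#include <cuda_runtime.h>

int current_device_num_sms() {
  const cudaDeviceProp* prop = get_device_prop();
  return prop->multiProcessorCount;  // e.g. an input to num_splits heuristics
}
```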
I can see multiple libraries needing this kind of function. Would it be possible to have getCurrentDeviceProperties/getCurrentCUDAStream/... also be part of the stable ABI? These should be stable anyway, because the CUDA ABI is stable between minor versions.
We explicitly chose to support the accelerator agnostic variations of these APIs in stable for 2.9, as we want people to eventually migrate to generic accelerator APIs, but maybe we should revisit this if there's interest! cc @albanD
Yes, it is an interesting balance between having generic enough APIs vs not.
tbh in this case, if the CUDA API were good enough we wouldn't need this...
But to get the current stream, you can use the generic API (and convert the object to cudaStream_t if ever needed).
To get properties it's trickier to have a generic API, as each vendor provides a different struct there (even though they contain mostly the same things).
Thanks @danthe3rd for your comments! They very much align with where we're headed, especially as @mikaylagawarecki and I start enabling other repos with many kernels soon. That said, the goal of this PR is to update FA3 to be libtorch-stable as soon as possible while staying readable, where the level of readability was based on the initial proposal PR: https://github.com/Dao-AILab/flash-attention/pull/1685/files. To that end, I think this PR does mostly adhere to the initial plan, and so we'd like to verify with @tridao that we can make FA3 libtorch stable from 2.9 onwards with the current PR. From there, we do plan on gradually improving the UX with utilities and additional APIs (e.g., for libraries that have more than one kernel :)).
We are exploring ABI version build targets and ABI versioning (cc @mikaylagawarecki), and I think to some extent we want to offer ways for users to opt into only a subset of stable APIs! So yes! And then regarding the namespace change, I think it would be a bit confusing (even if convenient) to expose namespaces like at in torch::stable. We are aiming for a north star where one day libtorch namespace separations would make sense, and people would use the right APIs from the get-go. We understand that this may make enabling more libraries feel daunting, so we intend to get in the weeds and help library authors through the migration process!
hopper/flash_api.cpp
Outdated
```diff
 }

-TORCH_LIBRARY(flash_attn_3, m) {
+void boxed_mha_fwd(
```
I don't quite understand why we need the "boxed" version
This is a limitation of STABLE_TORCH_LIBRARY for now, where it only accepts the boxed kernel. But echoing Jane's comment:
"We do want to eventually automate this and hide it from the user completely! We ultimately want to support this in the dispatcher which requires a bit of a lift as we support more IValues."
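For context, the registration that forces this boxed shape looks roughly like the following; the schema string is elided and the exact macro usage is assumed from the stable-ABI examples rather than copied from this PR:

```cpp
// Hedged sketch: STABLE_TORCH_LIBRARY-style registration takes a boxed function
// with the (StableIValue*, uint64_t, uint64_t) signature. Macro spellings and
// the elided schema below are assumptions, not lines from this PR.
STABLE_TORCH_LIBRARY(flash_attn_3, m) {
  m.def("mha_fwd_get_scheduler_metadata(/* ...schema elided... */) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) {
  m.impl("mha_fwd_get_scheduler_metadata", &boxed_mha_fwd_get_scheduler_metadata);
}
```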
The API looks great to me!
hopper/setup.py
Outdated
```python
if torch_version_parsed > target_version:
    flash_api_source = "flash_api_stable.cpp"
    stable_args = ["-DTORCH_STABLE_ONLY"]
```
could just append to feature_args? no need to create new args?
I don't think we need to pass this to nvcc extra compile args, hence I separated it out
Related to pytorch/pytorch#154908 and #1730
This is an ABI-stable version of hopper/flash_api.cpp; it should be ABI stable from the 08/30 nightly onwards.
Test results
- pytest -q -s test_flash_attn.py results on main
- pytest -q -s test_flash_attn.py results on this branch
- Results of python benchmark_attn.py on main and on this branch
Note
The CHECK_SHAPE macro is rewritten rather weirdly using .size() rather than .sizes(), as we do not have a stable version of IntArrayRef yet, but we can fix this in a later PR.
cc @janeyx99 @tridao
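For reference, a minimal sketch of the per-dimension workaround mentioned above, assuming the stable Tensor exposes dim() and size(i) and reusing STD_TORCH_CHECK; the macro body committed in the PR may differ:

```cpp
// Hedged sketch, not the committed macro: check a tensor's shape one dimension
// at a time via .size(i), since there is no stable .sizes()/IntArrayRef yet.
#include <cstdint>
#include <string>

#define CHECK_SHAPE(x, ...)                                                     \
  do {                                                                          \
    const int64_t expected_sizes[] = {__VA_ARGS__};                             \
    const int64_t expected_dim =                                                \
        static_cast<int64_t>(sizeof(expected_sizes) / sizeof(int64_t));         \
    STD_TORCH_CHECK((x).dim() == expected_dim,                                  \
                    #x " has an unexpected number of dimensions");              \
    for (int64_t i = 0; i < expected_dim; ++i) {                                \
      STD_TORCH_CHECK((x).size(i) == expected_sizes[i],                         \
                      #x " has an unexpected size at dim " + std::to_string(i)); \
    }                                                                           \
  } while (0)
```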