ABI stable fa3 by mikaylagawarecki · Pull Request #1791 · Dao-AILab/flash-attention

Conversation

@mikaylagawarecki (Contributor) commented Jul 31, 2025

Related to pytorch/pytorch#154908 and #1730

This is an ABI-stable version of hopper/flash_api.cpp that should remain ABI stable against PyTorch nightlies from 08/30 onwards.

pip install --pre torch==2.9.0.dev20250830+cu129 --index-url https://download.pytorch.org/whl/nightly/cu129

Test results

pytest -q -s test_flash_attn.py results on main

pytest -q -s test_flash_attn.py results on this branch

Results of python benchmark_attn.py

main

### headdim = 128, causal = False, seqlen = 8192 ###
Fav2 fwd: 2.972ms, 369.9 TFLOPS
Fav2 bwd: 9.460ms, 290.6 TFLOPS
Fav3 fwd: 1.668ms, 659.1 TFLOPS
Fav3 bwd: 4.962ms, 553.9 TFLOPS

### headdim = 128, causal = True, seqlen = 8192 ###
Fav2 fwd: 1.935ms, 284.1 TFLOPS
Fav2 bwd: 4.788ms, 287.1 TFLOPS
Fav3 fwd: 0.873ms, 630.0 TFLOPS
Fav3 bwd: 2.585ms, 531.6 TFLOPS

this branch

### headdim = 128, causal = False, seqlen = 8192 ###
Fav2 fwd: 2.963ms, 371.1 TFLOPS
Fav2 bwd: 9.515ms, 288.9 TFLOPS
Fav3 fwd: 1.667ms, 659.6 TFLOPS
Fav3 bwd: 4.878ms, 563.5 TFLOPS

### headdim = 128, causal = True, seqlen = 8192 ###
Fav2 fwd: 1.954ms, 281.4 TFLOPS
Fav2 bwd: 4.755ms, 289.1 TFLOPS
Fav3 fwd: 0.872ms, 630.5 TFLOPS
Fav3 bwd: 2.586ms, 531.5 TFLOPS

Note

The CHECK_SHAPE macro is rewritten somewhat awkwardly using .size() rather than .sizes(), as we do not have a stable version of IntArrayRef yet; we can fix this in a later PR.
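
For context, here is a minimal sketch of what such a per-dimension check can look like (an assumed form for illustration, not the exact macro in this PR), comparing each dimension via .size(i) instead of comparing .sizes() against an IntArrayRef:

// Sketch only (assumed form, not this PR's macro): per-dimension shape check
// that avoids needing a stable IntArrayRef.
#define CHECK_SHAPE(x, ...) \
  do { \
    const int64_t expected_sizes[] = {__VA_ARGS__}; \
    constexpr size_t expected_dim = sizeof(expected_sizes) / sizeof(int64_t); \
    STD_TORCH_CHECK((x).dim() == static_cast<int64_t>(expected_dim), \
                    #x " must have shape (" #__VA_ARGS__ ")"); \
    for (size_t i = 0; i < expected_dim; ++i) { \
      STD_TORCH_CHECK((x).size(i) == expected_sizes[i], \
                      #x " must have shape (" #__VA_ARGS__ ")"); \
    } \
  } while (0)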

cc @janeyx99 @tridao

@mikaylagawarecki changed the title from "[skip_ci] ABI stable" to "[skip_ci] ABI stable fa3" on Jul 31, 2025
@janeyx99 (Contributor)

Header-only ScalarType is landing! pytorch/pytorch#159416

@janeyx99 (Contributor) left a comment

Left some minor nits --> I'm guessing the commented out stuff will be removed?

@mikaylagawarecki marked this pull request as ready for review on August 20, 2025, 22:01
@mikaylagawarecki changed the title from "[skip_ci] ABI stable fa3" to "ABI stable fa3" on Aug 26, 2025
@danthe3rd (Contributor) left a comment

That's amazing! Thank you for working on this :)
I have a few comments, in the name of making life easier for extension maintainers, as this will be important to convince libraries (sometimes with many, many kernels) to make the switch to stable.

A lot of these changes are essentially just renaming stuff. Would it be possible to have a namespace that would mirror as much as possible the existing class hierarchy?
Like I would like (as a library maintainer) to just set a compile flag -DTORCH_LIMITED_API=0x02090000, and it would hide all non-stable API.
Then maybe I can just set using namespace torch::stable;, and this namespace contains torch::stable::at::ScalarType, torch::stable::at::Tensor, torch::stable::at::cuda::DeviceGuard etc...
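
Purely to illustrate the request (these aliases do not exist today; the spellings below are hypothetical), such a mirror namespace could be a thin layer of using-declarations over the stable/header-only types:

// Hypothetical compatibility layer, not an existing torch API: mirror familiar
// spellings onto stable/header-only types behind an (assumed) TORCH_LIMITED_API flag.
#ifdef TORCH_LIMITED_API
namespace torch::stable::at {
  using Tensor = torch::stable::Tensor;              // assumed: the stable Tensor type this PR uses
  using ScalarType = torch::headeronly::ScalarType;  // the header-only ScalarType mentioned earlier
}  // namespace torch::stable::at
#endif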

Comment on lines 1844 to 1877
void boxed_mha_fwd_get_scheduler_metadata(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs
) {
  auto batch_size = to<int64_t>(stack[0]);
  auto max_seqlen_q = to<int64_t>(stack[1]);
  auto max_seqlen_k = to<int64_t>(stack[2]);
  auto num_heads = to<int64_t>(stack[3]);
  auto num_heads_k = to<int64_t>(stack[4]);
  auto headdim = to<int64_t>(stack[5]);
  auto headdim_v = to<int64_t>(stack[6]);
  auto qkv_dtype = to<torch::headeronly::ScalarType>(stack[7]);
  auto seqused_k = to<Tensor>(stack[8]);
  auto cu_seqlens_q = to<std::optional<Tensor>>(stack[9]);
  auto cu_seqlens_k = to<std::optional<Tensor>>(stack[10]);
  auto cu_seqlens_k_new = to<std::optional<Tensor>>(stack[11]);
  auto seqused_q = to<std::optional<Tensor>>(stack[12]);
  auto leftpad_k = to<std::optional<Tensor>>(stack[13]);
  auto page_size = to<std::optional<int64_t>>(stack[14]);
  auto max_seqlen_k_new = to<int64_t>(stack[15]);
  auto is_causal = to<bool>(stack[16]);
  auto window_size_left = to<int64_t>(stack[17]);
  auto window_size_right = to<int64_t>(stack[18]);
  auto attention_chunk = to<int64_t>(stack[19]);
  auto has_softcap = to<bool>(stack[20]);
  auto num_splits = to<int64_t>(stack[21]);
  auto pack_gqa = to<std::optional<bool>>(stack[22]);
  auto sm_margin = to<int64_t>(stack[23]);

  auto scheduler_metadata = mha_fwd_get_scheduler_metadata(batch_size, max_seqlen_q, max_seqlen_k, num_heads, num_heads_k, headdim, headdim_v, qkv_dtype, seqused_k, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, seqused_q, leftpad_k, page_size, max_seqlen_k_new, is_causal, window_size_left, window_size_right, attention_chunk, has_softcap, num_splits, pack_gqa, sm_margin);

  stack[0] = from(scheduler_metadata);
}
Contributor

Do you think there is a way to automate this? It could be quite error prone.
Probably we can do something with C++ template magic?

Contributor

We do want to eventually automate this and hide it from the user completely! We ultimately want to support this in the dispatcher which requires a bit of a lift as we support more IValues. Maybe it could be worth adding an intermediary template solution in the meantime.
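
For illustration, one rough shape such an intermediary template solution could take (a sketch only, reusing the StableIValue / to<T>() / from() helpers that appear in the diff above; nothing like this is added by the PR itself):

#include <cstdint>
#include <type_traits>
#include <utility>

// Sketch: generate the boxed wrapper from an unboxed function's signature so
// the per-argument to<>() conversions are not written out by hand.
template <auto Fn>
struct BoxedKernel;

template <typename Ret, typename... Args, Ret (*Fn)(Args...)>
struct BoxedKernel<Fn> {
  static void call(StableIValue* stack, uint64_t /*num_args*/, uint64_t /*num_outputs*/) {
    unbox_and_call(stack, std::index_sequence_for<Args...>{});
  }

 private:
  template <std::size_t... Is>
  static void unbox_and_call(StableIValue* stack, std::index_sequence<Is...>) {
    // Unbox each stack slot with the type the unboxed function expects,
    // invoke it, then box the result back into slot 0.
    Ret result = Fn(to<std::decay_t<Args>>(stack[Is])...);
    stack[0] = from(result);
  }
};

For example, BoxedKernel<&mha_fwd_get_scheduler_metadata>::call could then stand in for the hand-written boxed_mha_fwd_get_scheduler_metadata above.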

Comment on lines 1143 to 1206
TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
Contributor

Could we avoid having to make this kind of change? Like, could we just replace TORCH_CHECK's definition with STD_TORCH_CHECK's definition, maybe when a given compile flag is set?

Contributor

We haven't yet explored having a compile time flag control what gets exposed, and that will be a part of our consideration when we figure out how to set up ABI target versions and such :D
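
A minimal sketch of the kind of shim being suggested (hypothetical, not part of this PR; gated here on the TORCH_STABLE_ONLY define that shows up later in setup.py):

// Hypothetical shim: when building against the stable ABI, route existing
// TORCH_CHECK call sites to STD_TORCH_CHECK so kernel sources need no textual changes.
#ifdef TORCH_STABLE_ONLY
  #ifdef TORCH_CHECK
    #undef TORCH_CHECK
  #endif
  #define TORCH_CHECK(...) STD_TORCH_CHECK(__VA_ARGS__)
#endif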

Comment on lines 31 to 70
std::deque<std::once_flag> device_flags;
std::vector<cudaDeviceProp> device_properties;

void initVectors() {
  static bool init_flag [[maybe_unused]] = []() {
    int device_count;
    cudaError_t err = cudaGetDeviceCount(&device_count);
    if (err != cudaSuccess) {
      STD_TORCH_CHECK(false, "cudaGetDeviceCount failed: " +
                      std::string(cudaGetErrorString(err)));
    }
    device_flags.resize(device_count);
    device_properties.resize(device_count);
    return true;
  }();
}

void initDeviceProperty(int device_index) {
  cudaDeviceProp device_prop{};
  cudaError_t err = cudaGetDeviceProperties(&device_prop, device_index);
  if (err != cudaSuccess) {
    STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " +
                    std::string(cudaGetErrorString(err)));
  }
  device_properties[device_index] = device_prop;
}

// Helper function to get device properties using raw CUDA APIs
cudaDeviceProp* get_device_prop() {
  initVectors();
  int device_index;
  cudaError_t err = cudaGetDevice(&device_index);
  if (err != cudaSuccess) {
    STD_TORCH_CHECK(false, "cudaGetDevice failed: " +
                    std::string(cudaGetErrorString(err)));
  }

  std::call_once(device_flags[device_index], initDeviceProperty, device_index);
  return &device_properties[device_index];
}
Contributor

I can see that multiple libraries could need this kind of function. Would it be possible to have getCurrentDeviceProperties/getCurrentCUDAStream/... also be part of the stable ABI? This should be stable anyway, because the CUDA ABI should be stable between minor versions.

Contributor

We explicitly chose to support the accelerator-agnostic variations of these APIs in stable for 2.9, as we want people to eventually migrate to generic accelerator APIs, but maybe we should revisit this if there's interest! cc @albanD


Yes, it is an interesting balance between having generic-enough APIs vs. not.
tbh, in this case, if the CUDA API were good enough we wouldn't need this...

But to get the current stream, you can use the generic API (and convert the object to cudaStream_t if ever needed).
To get properties it's trickier to have a generic API, as each vendor provides a different struct there (even though they contain mostly the same thing).

@janeyx99 (Contributor)

Thanks @danthe3rd for your comments! They very much align with where we're headed, especially as @mikaylagawarecki and I start enabling other repos with many kernels soon. That said, the goal of this PR is to update FA3 to be libtorch-stable as soon as possible and still be readable, where the level of readability was based on the initial proposal PR: https://github.com/Dao-AILab/flash-attention/pull/1685/files.

To that end, I think this PR does mostly adhere to the initial plan, and so we'd like to verify with @tridao that we can make FA3 libtorch-stable from 2.9 onwards with the current PR. From there, we do plan on gradually improving the UX with utilities and additional APIs (e.g., for libraries that have more than one kernel :)).

> A lot of these changes are essentially just renaming stuff. Would it be possible to have a namespace that would mirror as much as possible the existing class hierarchy?
> Like I would like (as a library maintainer) to just set a compile flag -DTORCH_LIMITED_API=0x02090000, and it would hide all non-stable API.
> Then maybe I can just set using namespace torch::stable;, and this namespace contains torch::stable::at::ScalarType, torch::stable::at::Tensor, torch::stable::at::cuda::DeviceGuard etc...

We are exploring ABI version build targets and ABI versioning (cc @mikaylagawarecki), and I think to some extent we want to offer ways for users to opt into only a subset of stable APIs! So yes! And then regarding the namespace change, I think it would be a bit confusing (even if convenient) to expose namespaces like at in torch::stable. We are aiming for a north star where one day libtorch namespace separations would make sense, and people would use the right APIs from the get-go. We understand that this may make enabling more libraries feel daunting, so we intend to get in the weeds to help library authors in the migration process!

}

TORCH_LIBRARY(flash_attn_3, m) {
void boxed_mha_fwd(
Member

I don't quite understand why we need the "boxed" version

@mikaylagawarecki (Contributor Author) commented Sep 3, 2025

This is a limitation of STABLE_TORCH_LIBRARY for now, where it only accepts the boxed kernel. But echoing Jane's comment:

> We do want to eventually automate this and hide it from the user completely! We ultimately want to support this in the dispatcher which requires a bit of a lift as we support more IValues.

@tridao (Member) commented Sep 3, 2025

The API looks great to me!

hopper/setup.py (Outdated)

if torch_version_parsed > target_version:
flash_api_source = "flash_api_stable.cpp"
stable_args = ["-DTORCH_STABLE_ONLY"]
Contributor

could just append to feature_args? no need to create new args?

Contributor Author

I don't think we need to pass this to nvcc extra compile args, hence I separated it out
