[CUDAGraph] Add getter for cuda graph exec by galv · Pull Request #161294 · pytorch/pytorch · GitHub

Conversation

@galv
Collaborator

@galv galv commented Aug 22, 2025

This is far simpler than #155164 since we never destroy the cudaGraphExec_t.

The request comes from TRT-LLM specifically. The motivation is that some power users would like to mutate specific kernel parameters via APIs like `cudaGraphExec*SetParams` after a cuda graph has been instantiated. For example, a common request has been the ability to change the sequence length of attention kernels after having captured a graph for the largest possible sequence length. The host overhead you eliminate via cuda graphs in LLM inference comes at the cost of extra computation time, because you have to size your kernels for the maximum possible sequence length (which I believe is done in both TRT-LLM and vLLM). Attention is the most problematic kernel because its computation time is quadratic in the sequence length, rather than linear.

This works only if your attention kernel can handle arbitrary shapes (not the case for all attention implementations! Many of them specialize with templates), and if you have a persistent kernel that allocates only as many blocks as you have SMs (so you don't have to figure out how many blocks to allocate for a specific sequence length). Using a conditional SWITCH node is a better generic approach to this problem, but that requires more infrastructure work.

Note that this requires knowing the exact location of the value to mutate in your kernel's parameter buffer. It won't work with arbitrary stream-capture code whose kernels you don't know beforehand, so I expect this code path to be rarely used.

Testing:

```
pytest -s -k raw_graph_exec test/test_cuda.py
```

cc @mcarilli @ezyang @eellison @penguinwu @BoyuanFeng

@galv galv requested review from eqy and syed-ahmed as code owners August 22, 2025 18:42
@pytorch-bot

pytorch-bot bot commented Aug 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161294

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 2e92573 with merge base 4651aaa:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@galv galv requested review from BoyuanFeng, eellison, eqy, ngimel and syed-ahmed and removed request for eqy and syed-ahmed August 22, 2025 18:42
graph = torch.cuda.CUDAGraph(keep_graph=False)
x = torch.zeros([2000], device="cuda")
y = torch.ones([2000], device="cuda")
with torch.cuda.graph(graph, capture_error_mode="relaxed"):
Collaborator

ooc why relaxed mode here?

Collaborator Author

I haven't done any "warmup" in this case, and it's possible for this test to run before any other test. On the chance that torch.ops.aten.add does some kind of "graph unsafe" operation (I don't remember whether that's the case, to be honest, but you can see I handled this in the keep_graph tests), doing stream capture in "relaxed" mode prevents such operations from causing an error.

Collaborator Author

If you want me to do a warmup first and then capture in thread-local mode, that is fine with me; just say so. I don't really like to sweat details like this.

not TEST_CUDA_GRAPH or not TEST_CUDA_PYTHON_BINDINGS,
"CUDA >= 11.0 or ROCM >= 5.3 required for graphs, cuda-bindings must be installed",
)
def test_cuda_graph_raw_graph_exec(self):
Collaborator

Seems like this test can be parametrized on keep_graph; there's very little difference.

Collaborator Author

Oh yes this is right. Will do.

@BoyuanFeng
Contributor

Just curious.

  1. Could you provide a pytorch eager example of how to use this feature with attention ops? Many users would be interested.
  2. For the variable-length attention use case: FWIW, users (e.g., vllm) usually capture CUDAGraphs for each batch size as a multiple of 8 and pad inputs to the nearest batch size. What's the advantage of the proposed feature?
  3. Any latency estimate for updating this config at runtime?

@galv
Collaborator Author

galv commented Aug 22, 2025

For the variable length attention use case. FWIW, users (e.g., vllm) usually capture CUDAGraphs for each batch size as multiples of 8 and pad inputs to nearest batch size. What's the advantage of the proposed feature?

That is padding on the batch size. However, it is not padding on the sequence length dimension at all. I am not aware of anyone who builds cuda graphs that are bucketed along sequence length as well as batch size today (but I could be wrong!). Using a SWITCH conditional node would allow us to pick different kernels (which might be the same kernel!) based on different sequence lengths. If you think I am wrong, please let me know, by the way. I am not always able to keep up with all of the decisions people are making in LLM inference software these days.

Could you provide a pytorch eager example on how to use this feature with attention ops? Many users would be interested.

That's not very easy... It would require me to assume that a kernel is robust to changing its parameters. If you consider flash attention, for example, it is common to choose a different template instantiation of a kernel based on the input parameters: https://github.com/vllm-project/flash-attention/blob/57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f/hopper/flash_fwd_combine_launch_template.h#L77-L85 In this situation, we cannot assume that it is safe to change a parameter (and template specialization is not the only thing that can make a kernel unsafe to mutate at runtime).

I will note, BTW, that this kind of code, where you're basically writing a switch-case statement in C++, is the kind of thing I would like to run with a SWITCH conditional node in a cuda graph at some point.

Since we should not assume anything about any particular aten op, I would have to add a custom C++ op with a kernel that is robust to changing its size parameter. Something like a persistent reduction (by reduction, I mean reducing a vector to a single scalar via a binary operator) would be a toy example. By persistent, I mean that the number of cuda thread blocks is constant regardless of the size of the input; if the block count weren't fixed, your update code would also need to know how to change it when you change the size of the input, which means duplicating logic between your kernel-node update code and your original launch code. That is one source of non-robustness (see the sketch after the struct example below). The other source of non-robustness is simpler to show with some code:

Let's suppose I need to update my_size in this code:

```
struct A {
    float val1;
    size_t my_size;
};

__global__ void foo_kernel(struct A a) { }
```

We can tell that my_size is at byte 8 of struct A, due to alignment requirements.

But suppose I did an update later:

```
struct A {
    float val1;
    float val2;
    float val3;
    size_t my_size;
};

__global__ void foo_kernel(struct A a) { }
```

Now my_size is at byte 16 of struct A! The parameter-update code has to change whenever the struct layout changes. (You could use offsetof() to get the offset robustly in C++, but that wouldn't work from Python.)
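
To make the persistent-kernel point concrete, here's a rough sketch of the kind of toy kernel I have in mind (untested, a hypothetical example rather than anything in this PR): a grid-stride reduction with a fixed grid, so changing the input size never requires changing the launch configuration, only the n argument in the parameter buffer.

```
__global__ void persistent_sum(const float* in, float* out, size_t n) {
    // Fixed grid: every thread strides over the whole input, so changing n
    // never requires changing gridDim/blockDim.
    float acc = 0.f;
    for (size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x; i < n;
         i += (size_t)gridDim.x * blockDim.x) {
        acc += in[i];
    }
    atomicAdd(out, acc);  // assumes *out is zeroed before each launch
}
```

With a kernel like this, the only thing the kernel-node update code ever has to touch is the n argument.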

Anyway, I can ask my coworkers if they can show a more precise example. No guarantee.

Any latency estimation of updating this config during runtime?

Each call to cudaGraphExecNodeSetParams() takes about 700 ns on the CPU. You can hide this behind other GPU work if there is already work queued on the GPU. You can verify this yourself here: https://gist.github.com/galv/866d93a6ebb73b9c29b871e6a3584e80

There is also an extra bit of latency added to cudaGraphLaunch(), where we send commands to the GPU to update its local copy of parameter values in device memory before it starts the cuda graph. This is trivially small. No more than a few microseconds.
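
For completeness, here's a rough C++ sketch (untested, CUDA 12 runtime API, error checking omitted) of what the whole update path looks like for the persistent_sum toy kernel above. To be clear, this PR does not add any of this; it only exposes the raw cudaGraphExec_t so power users can do the equivalent from Python via cuda-bindings.

```
// Build a one-node graph explicitly so we own the argument buffer and know
// exactly where the size argument lives.
float *d_in, *d_out;
size_t n = 2000;
cudaMalloc(&d_in, n * sizeof(float));
cudaMalloc(&d_out, sizeof(float));

cudaGraph_t graph;
cudaGraphCreate(&graph, 0);

void* args[] = {&d_in, &d_out, &n};  // args[i] points at kernel argument i
cudaKernelNodeParams kp{};
kp.func = (void*)persistent_sum;
kp.gridDim = dim3(132);              // fixed block count, e.g. one per SM
kp.blockDim = dim3(256);
kp.kernelParams = args;

cudaGraphNode_t node;
cudaGraphAddKernelNode(&node, graph, nullptr, 0, &kp);

cudaGraphExec_t exec;
cudaGraphInstantiate(&exec, graph, 0);  // CUDA 12 signature

// Later: shrink the problem size without re-capturing or re-instantiating.
n = 1000;                                           // args[2] still points at n
cudaGraphExecKernelNodeSetParams(exec, node, &kp);  // ~700 ns on the host
cudaGraphLaunch(exec, /*stream=*/0);
```

With this PR, you would instead capture the graph from PyTorch, pull the cudaGraphExec_t out of the torch.cuda.CUDAGraph object via the new getter, and issue the same SetParams call through cuda-bindings.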

@galv galv added the module: cuda graphs Ability to capture and then replay streams of CUDA kernels label Aug 22, 2025
@BoyuanFeng BoyuanFeng added ciflow/trunk Trigger trunk jobs on your pull request release notes: cuda release notes category labels Aug 22, 2025
@BoyuanFeng
Contributor

Thanks! If an attention example is hard, an example with simpler kernels (e.g., addition) would still be helpful. This would change how people think about cudagraph x dynamic shapes.

padding on the sequence length

Is this for LLM training or inference? For inference, the sequence length is usually fixed due to the KV cache, since people preallocate a fixed-length KV cache.

For training, I'm curious how much benefit we can see from CUDAGraph. #150567 prototyped a compiler padding solution for CUDAGraph and transformers. We tried it on torchtune (for LLM training) but only saw ~3% speedup.

@galv
Collaborator Author

galv commented Aug 25, 2025

@BoyuanFeng I am talking about inference in this case. In case you're curious, the padding I'm referring to in vllm is this padding up to the maximum sequence length used for cudagraphs: https://github.com/galv/vllm/blob/787cdb3829676504da2f612fad041db4b8acc271/vllm/worker/model_runner.py#L1011-L1018

BTW, after having worked on cudagraphs in pytorch over the past several months, I have a brief outline of a possible way to do dynamic shapes with cudagraphs properly. It will involve work on the cuda driver side, the torch.compile() side, and a little bit of inductor work. I can share it with you on the pytorch slack if you are curious.

@galv
Collaborator Author

galv commented Aug 25, 2025

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

galv and others added 2 commits August 25, 2025 15:58
@pytorchmergebot
Collaborator

Successfully rebased add-raw_cuda_graph-getter-to-cudagraph-viable onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout add-raw_cuda_graph-getter-to-cudagraph-viable && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the add-raw_cuda_graph-getter-to-cudagraph-viable branch from a15000a to 2e92573 Compare August 25, 2025 15:58
@ngimel
Collaborator

ngimel commented Aug 25, 2025

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m2-15)


markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025

Pull Request resolved: pytorch#161294
Approved by: https://github.com/ngimel, https://github.com/BoyuanFeng, https://github.com/eellison, https://github.com/eqy