[Graph Partition] allow sharing default device context by BoyuanFeng · Pull Request #162873 · pytorch/pytorch · GitHub

Conversation

@BoyuanFeng
Contributor

@BoyuanFeng BoyuanFeng commented Sep 13, 2025

Entering a device context takes 30 us and exiting a device context takes 11 us. If all graph partitions and cudagraph-unsafe ops happen on the same device, we can share the device context.

Trace

Use vLLM as an example. The first trace shows dynamo graph partition.
[trace screenshot]

The second trace shows inductor graph partition prior to this PR.
[trace screenshot]

Compared with the fx graph partition, the inductor graph partition shows extra overhead from entering/exiting device contexts (13+6 us -> 30+11 us), but smaller runtime overhead (13 us -> 7 us). This motivates this PR to share the default device context.

The third trace shows inductor graph partition after this PR. The extra overhead from entering/exiting device contexts is gone, while the smaller runtime overhead is preserved.
[trace screenshot]

cc @mcarilli @ezyang @eellison @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos

@BoyuanFeng BoyuanFeng added module: cuda graphs Ability to capture and then replay streams of CUDA kernels ciflow/trunk Trigger trunk jobs on your pull request release notes: inductor ci-no-td Do not run TD on this PR labels Sep 13, 2025
@pytorch-bot

pytorch-bot bot commented Sep 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162873

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2bb9647 with merge base d2f6daf:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

     # is important for nested subgraph codegening.
     def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str:
-        self.write_get_raw_stream_header_once()
+        self.write_get_raw_stream_header()
Contributor Author

standalone_compile may need additional imports of get_raw_stream

Contributor

Why are additional imports needed? Can those be added as a cache key if repeated imports are needed?

Contributor Author

@BoyuanFeng BoyuanFeng commented Sep 15, 2025

We need imports of get_raw_stream for both the subgraphs and the main graph, in both the final output code and the autotune_at_compile_time blocks.

Currently, write_get_raw_stream_header_once relies on cache_on_self to import only once. When generating a subgraph, it calls SubgraphPythonWrapperCodegen.write_get_raw_stream_header_once() (cache_on_self), which calls PythonWrapperCodegen.write_get_raw_stream_header_once() (cache_on_self).

# SubgraphPythonWrapperCodegen: delegates to the parent wrapper
@cache_on_self
def write_get_raw_stream_header_once(self) -> None:
    # TODO: Uncomment in future. This will be needed to support subgraph
    # codegen for cpp wrapper.
    # if config.triton.autotune_at_compile_time:
    #     self.kernel_autotune_calls.writeline(
    #         V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
    #     )
    self.parent_wrapper.write_get_raw_stream_header_once()

# PythonWrapperCodegen: writes the header at most once
@cache_on_self
def write_get_raw_stream_header_once(self) -> None:
    self.write_get_raw_stream_header()

As a result, the main graph will skip PythonWrapperCodegen.write_get_raw_stream_header_once() (cache_on_self) since it has just been called. However, this leads to errors since the main graph code also needs this import.
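
A toy model of the caching interaction, under the simplifying assumption that cache_on_self makes a zero-argument method run at most once per instance (an illustration, not the actual inductor decorator):

import functools

def cache_on_self(fn):
    # Simplified semantics: run the wrapped zero-arg method at most once per instance.
    flag = f"_ran_{fn.__name__}"

    @functools.wraps(fn)
    def wrapper(self):
        if not getattr(self, flag, False):
            setattr(self, flag, True)
            fn(self)

    return wrapper

class PythonWrapperCodegen:
    @cache_on_self
    def write_get_raw_stream_header_once(self) -> None:
        print("write get_raw_stream import for the main graph")

# If subgraph codegen already triggered the parent's method, a later call made
# while generating the main graph is silently skipped, which is the failure
# mode described above.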

So I call write_get_raw_stream_header here whenever we need it, and only write to self.imports when self.imports does not contain import_get_raw_stream_str.
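
A minimal sketch of that guarded write, assuming a simple list-like imports buffer (the real wrapper codegen buffers and names differ):

class WrapperImports:
    """Toy buffer: append the get_raw_stream import on demand, at most once."""

    def __init__(self) -> None:
        self.imports: list[str] = []

    def write_get_raw_stream_header(self, import_get_raw_stream_str: str) -> None:
        # Safe to call from both subgraph and main-graph codegen: the import is
        # only appended if it is not already present.
        if import_get_raw_stream_str not in self.imports:
            self.imports.append(import_get_raw_stream_str)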

Contributor

Currently, write_get_raw_stream_header_once relies on cache_on_self to import only once. When generating a subgraph, it calls SubgraphPythonWrapperCodegen.write_get_raw_stream_header_once() (cache_on_self), which calls PythonWrapperCodegen.write_get_raw_stream_header_once() (cache_on_self).

Is the import for the subgraph wrapper not added to the global scope?

Contributor Author

@BoyuanFeng BoyuanFeng commented Sep 16, 2025

No. New wrapper code is generated for each subgraph.

def _codegen_partition_wrapper(
    self,
    partition: PartitionType,
    signature: GraphPartitionSignature,
) -> None:
    """Codegen a partition given its inputs/outputs"""
    from .codegen.wrapper import SubgraphPythonWrapperCodegen

    parent_wrapper_code = V.graph.wrapper_code
    graph_partition_id = next(self._graph_partition_counter)
    with V.graph.set_current_wrapper_code():
        V.graph.init_wrapper_code(
            is_subgraph=True,
            subgraph_name=f"partition_{graph_partition_id}",
            parent_wrapper_code=parent_wrapper_code,
            partition_signatures=signature,
        )
        self._codegen(partition)

Contributor

No. New wrapper code is generated for each subgraph.

But all wrapper code is still dumped to the same .py file, right?

Contributor Author

Yes for the output_code. But not for config.triton.autotune_at_compile_time, which generates separate code to run, IIUC.

Contributor

@shunting314 shunting314 left a comment

Nice finding!

Is it possible to improve the device context manager instead? I.e., if we are already on the mentioned device, make entering/exiting it cheaper.

Also, are there any numbers for the overall improvement for inference?

     # is important for nested subgraph codegening.
     def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str:
-        self.write_get_raw_stream_header_once()
+        self.write_get_raw_stream_header()
Contributor

Why are additional imports needed? Can those be added as a cache key if repeated imports are needed?

@BoyuanFeng
Contributor Author

BoyuanFeng commented Sep 15, 2025

Is it possible to improve the device context manager instead? I.e., if we are already on the mentioned device, make entering/exiting it cheaper.

Prior to this PR, we actually only codegen entering/exiting a device context if the device has changed.

if device != self.current_device:
    if self.current_device and device_need_guard(
        self.current_device.type
    ):
        V.graph.wrapper_code.codegen_device_guard_exit()
    self.current_device = device
    if device_need_guard(device.type):
        assert device.index is not None, "device should have an index"
        V.graph.wrapper_code.codegen_device_guard_enter(device.index)

The issue is the assumption that _codegen always starts with a None device. This is generally true for the non-graph-partition case: we only have one graph, so we start from the CPU device and need to enter a device context whenever the device changes.

However, this assumption is NOT true for graph partitions. We have many graph partitions, so we can reuse the existing device context. The main purpose of this PR is to relax that assumption a bit.

self.current_device = None
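
A rough sketch of the relaxed assumption, under the simplifying assumption that one tracker survives across partition codegen instead of being reset to None per partition (the names below are illustrative, not the PR's exact code):

from typing import Optional

class DeviceGuardTracker:
    """Toy model: emit device-guard enter/exit only when the device actually changes."""

    def __init__(self) -> None:
        # Persists across graph partitions so same-device partitions share one context.
        self.current_device: Optional[str] = None

    def codegen_device_switch(self, device: str) -> list[str]:
        lines: list[str] = []
        if device != self.current_device:
            if self.current_device is not None:
                lines.append("# exit previous device context")
            self.current_device = device
            lines.append(f"# enter device context for {device}")
        return lines

tracker = DeviceGuardTracker()
print(tracker.codegen_device_switch("cuda:0"))  # first partition: enter the context
print(tracker.codegen_device_switch("cuda:0"))  # next partition: no-op, context is shared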

@BoyuanFeng
Contributor Author

Also any number for the overall improvement for inference?

For vllm Qwen/Qwen3-0.6B, this PR reduces latency from 0.34 seconds to 0.328 seconds, around 3.5% speedup.

Command to repro:
vllm bench latency -O.splitting_ops=[] -O.cudagraph_mode=PIECEWISE -O.use_inductor_graph_partition=true

We can easily see the speedup from the trace: torch/_inductor/output_code.py __call__ reduces from 5.5 ms to 5 ms.

Trace before the PR vs. trace after the PR

@shunting314
Contributor

shunting314 commented Sep 16, 2025

For vllm Qwen/Qwen3-0.6B, this PR reduces latency from 0.34 seconds to 0.328 seconds, around 3.5% speedup.

I want to dive a bit more into the benchmarking. How many subgraphs do we generate for Qwen/Qwen3-0.6B? Let's say we save X device.enter/exit pairs each time we call the wrapper, and assume a pair of device.enter/exit costs Y us (~40 us according to your measurement). Then the total saving for a wrapper call is (X * Y) us, and (X * Y) / 3.5% should roughly match the latency of one inference call.

Also, what blocks us from going one step further and calling device.enter/exit only once across multiple wrapper calls?

@shunting314 shunting314 self-requested a review September 16, 2025 00:39
@shunting314
Contributor

btw, the attention kernel does not seem to be captured in the cuda graphs, judging by the gap seen in the trace

[trace screenshot]

@BoyuanFeng
Contributor Author

btw, the attention kernel does not seem to be captured in the cuda graphs, judging by the gap seen in the trace

This is expected. The attention kernel is cudagraph-unsafe, so it is explicitly marked with torch._C.Tag.cudagraph_unsafe. More docs here.

@BoyuanFeng
Contributor Author

what blocks us from going one step further and calling device.enter/exit only once across multiple wrapper calls?

Yes, we currently call device.enter/exit only once across multiple wrapper calls (see tlparse and output code). There are 29 partitions and 28 calls to torch.ops.vllm.unified_attention_with_output.default, but only one "with torch.cuda._DeviceGuard(0):" block in def call().
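
For illustration, a hedged sketch of the shape of the generated call() under this scheme; partition_0 and the attention stand-in below are placeholders, not the actual generated names:

import torch

def partition_0(args):       # stand-in for one inductor graph partition
    return args

def unified_attention(x):    # stand-in for the cudagraph-unsafe attention op
    return x

def call(args):
    # A single shared device context wraps all same-device partitions and the
    # cudagraph-unsafe ops between them, instead of one guard per partition.
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        x = partition_0(args)
        x = unified_attention(x)   # runs inside the same shared context
        x = partition_0(x)         # reuse the stand-in as a second partition
        return x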

@BoyuanFeng
Contributor Author

For Qwen3-0.6B:

Before this PR:
[trace screenshot]

After this PR:
[trace screenshot]

Overall, the output_code call reduces from 5.5 ms to 5 ms. Since there are 29 subgraphs and 28 attention ops, we expect each subgraph/attention op to save about 500/28 ≈ 18 us. From the trace, we see around (10+27) - (4+9) = 24 us of saving. Note that the saving varies a bit across attention and partition calls.

This 10% saving is larger than the end-to-end 3.5% saving because the model also has several other components, such as update_states, prepare_inputs, etc.
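
A quick arithmetic check of the numbers above, following the X * Y framing from earlier in the thread (values taken from the measurements quoted here):

# Per-boundary savings estimate for the shared device context (Qwen3-0.6B).
subgraphs = 29                  # partitions in the generated output code
boundaries = 28                 # attention ops between partitions
total_saving_us = 500           # output_code call: 5.5 ms -> 5.0 ms
expected_per_boundary_us = total_saving_us / boundaries    # ~17.9 us
observed_per_boundary_us = (10 + 27) - (4 + 9)             # ~24 us from the trace
wrapper_saving_fraction = total_saving_us / 5500           # ~9%, the "10% saving" above
print(expected_per_boundary_us, observed_per_boundary_us, wrapper_saving_fraction)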

@shunting314
Contributor

ok, I didn't realize that there are so many partitions due to not capturing attention kernels in cudagraphs

@BoyuanFeng
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
@BoyuanFeng BoyuanFeng added this to the 2.9.0 milestone Sep 19, 2025
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
@Camyll
Contributor

Camyll commented Oct 1, 2025

@pytorchbot cherry-pick --onto release/2.9 --c critical

@pytorchbot
Collaborator

Cherry picking #162873

Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x 505ee42570dd45247ecc77b65e44399f43777ea5 returned non-zero exit code 1

Auto-merging test/inductor/test_torchinductor.py
Auto-merging torch/_inductor/codegen/wrapper.py
Auto-merging torch/_inductor/scheduler.py
Auto-merging torch/_inductor/utils.py
The previous cherry-pick is now empty, possibly due to conflict resolution.
If you wish to commit it anyway, use:

    git commit --allow-empty

Otherwise, please use 'git cherry-pick --skip'
On branch cherry-pick-162873-by-pytorch_bot_bot_
You are currently cherry-picking commit 505ee42570d.
  (all conflicts fixed: run "git cherry-pick --continue")
  (use "git cherry-pick --skip" to skip this patch)
  (use "git cherry-pick --abort" to cancel the cherry-pick operation)

nothing to commit, working tree clean
Details for Dev Infra team (raised by workflow job)

@Camyll
Contributor

Camyll commented Oct 1, 2025

Resolved in #163097

No need to cherry-pick.

@atalman atalman removed this from the 2.9.0 milestone Oct 2, 2025
