allow user to pass in custom partitioner function by xuanzhang816 · Pull Request #157580 · pytorch/pytorch

Conversation

@xuanzhang816
Contributor

@xuanzhang816 xuanzhang816 commented Jul 3, 2025

@pytorch-bot

pytorch-bot bot commented Jul 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157580

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit be4d1e9 with merge base f4c33cd:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

xuanzhang816 added a commit that referenced this pull request Jul 3, 2025
ghstack-source-id: ee66afd
Pull Request resolved: #157580
@xuanzhang816 xuanzhang816 marked this pull request as draft July 3, 2025 19:30
@xuanzhang816
Contributor Author

@xuanzhang816 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 3, 2025
@xuanzhang816
Contributor Author

@xuanzhang816 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@xuanzhang816
Contributor Author

@pytorchbot label "topic: not user facing"

@xuanzhang816 xuanzhang816 marked this pull request as ready for review July 7, 2025 22:24
@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Jul 7, 2025
@mlazos
Contributor

mlazos commented Jul 8, 2025

Should you add a test for this?

@mlazos mlazos self-requested a review July 8, 2025 00:48
@xuanzhang816 xuanzhang816 requested a review from bdhirsh July 8, 2025 14:47
@bdhirsh bdhirsh requested review from a team and zou3519 July 8, 2025 14:52
@bdhirsh
Contributor

bdhirsh commented Jul 8, 2025

also cc @zou3519, this is the "custom partitioner" we talked about



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

Differential Revision: [D77753038](https://our.internmc.facebook.com/intern/diff/D77753038)

[ghstack-poisoned]
xuanzhang816 added a commit that referenced this pull request Jul 16, 2025
ghstack-source-id: ad9c69c
Pull Request resolved: #157580
@xuanzhang816
Contributor Author

@bdhirsh I did a rewrite of this to ensure cacheability (copying the idea of the custom graph pass; see the sketch after this comment) and to add tests.



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

Differential Revision: [D77753038](https://our.internmc.facebook.com/intern/diff/D77753038)

[ghstack-poisoned]
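
For reference, a minimal sketch of the custom-graph-pass pattern mentioned above, using the existing `torch._inductor.custom_graph_pass.CustomGraphPass` interface: inductor folds the pass's `uuid()` into its cache key, so a compiled artifact is not reused after the pass changes. The custom partitioner in this PR follows the same shape; its exact class and config names are in the diff, not reproduced here.

```python
# Sketch under the assumptions above; NopPass and its uuid string are
# illustrative names, not from this PR.
from typing import Any, Optional

import torch
from torch._inductor.custom_graph_pass import CustomGraphPass


class NopPass(CustomGraphPass):
    def __call__(self, graph: torch.fx.Graph) -> None:
        # Mutate the FX graph in place here (a no-op in this sketch).
        pass

    def uuid(self) -> Optional[Any]:
        # Hashed into inductor's cache key: change it whenever the pass
        # semantics change so stale cache entries are not reused.
        return "nop-pass-v1"
```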
xuanzhang816 added a commit that referenced this pull request Jul 16, 2025
ghstack-source-id: 70c0541
Pull Request resolved: #157580
[ghstack-poisoned]
xuanzhang816 added a commit that referenced this pull request Sep 4, 2025
ghstack-source-id: 19e89c5
Pull Request resolved: #157580
[ghstack-poisoned]
xuanzhang816 added a commit that referenced this pull request Sep 5, 2025
ghstack-source-id: a27df73
Pull Request resolved: #157580
Contributor

@bdhirsh bdhirsh left a comment


thanks for updating!

[ghstack-poisoned]
xuanzhang816 added a commit that referenced this pull request Sep 5, 2025
ghstack-source-id: a075bd3
Pull Request resolved: #157580
@xuanzhang816
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

daisyden pushed a commit to daisyden/pytorch that referenced this pull request Sep 8, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
eternalNight added a commit to openanolis/DeepSpeed that referenced this pull request Sep 26, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
eternalNight added a commit to openanolis/DeepSpeed that referenced this pull request Sep 29, 2025
eternalNight added a commit to openanolis/DeepSpeed that referenced this pull request Sep 30, 2025
eternalNight added a commit to openanolis/DeepSpeed that referenced this pull request Oct 3, 2025
eternalNight added a commit to openanolis/DeepSpeed that referenced this pull request Oct 3, 2025
tohtana pushed a commit to deepspeedai/DeepSpeed that referenced this pull request Oct 3, 2025
…phs (#7609)

# Motivation

PyTorch provides `min_cut_rematerialization_partition()` to partition a
joint graph while respecting recomputation annotations. That algorithm
forms a data-flow-like graph from the joint graph, assigns edge weights
based on recomputation-cost heuristics, and applies the min-cut
algorithm to determine which nodes to recompute. Users can force
recomputation of a node by setting its `node.meta["recompute"]` to
MUST_RECOMPUTE or PREFER_RECOMPUTE, as implemented in [1].
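
A minimal sketch of that annotation, assuming `torch.utils.checkpoint.CheckpointPolicy` (the enum the partitioner's must-recompute check reads) and a hypothetical `should_recompute` predicate:

```python
# Tag nodes in a joint FX graph so the min-cut partitioner recomputes
# them instead of saving them for backward. should_recompute is a
# hypothetical predicate; depending on the torch version, other
# activation-checkpointing metadata (e.g. node.meta["ac_graph_id"])
# may also be consulted by the partitioner.
from torch.utils.checkpoint import CheckpointPolicy


def mark_must_recompute(joint_graph, should_recompute):
    for node in joint_graph.nodes:
        if node.op == "call_function" and should_recompute(node):
            node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE
```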

While originally designed for activation checkpointing,
min_cut_rematerialization can also be used to recompute param aliases.
When partitioning a joint graph, we don't want to save for backward the
gathered parameters or the values computed from them via aliasing ops,
as saving an alias essentially means saving the gathered parameter
itself. Instead of customizing the partitioner or patching
`choose_saved_values_set`, we can achieve that by annotating such nodes
as MUST_RECOMPUTE.

Both eager and inductor backends can use min_cut_rematerialization
easily. The eager backend can use min-cut by customizing the
partition_fn for `aot_module_simplified`, and already does so for
graphs with activation checkpointing enabled. The inductor backend has
used that algorithm since torch 2.0.0 [2], and it remains the default
after the inductor partitioner was made configurable a few weeks ago [3].
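
A minimal sketch of that eager-backend route, assuming an identity "compiler" is enough to run the partitioned graphs eagerly; both entry points are importable from `functorch.compile`:

```python
# A torch.compile backend that runs AOTAutograd with the min-cut
# partitioner and no codegen: GraphModules are executed as-is.
import torch
from functorch.compile import (
    aot_module_simplified,
    min_cut_rematerialization_partition,
)


def eager_min_cut_backend(gm: torch.fx.GraphModule, example_inputs):
    def run_eagerly(gm, _example_inputs):
        return gm  # identity compiler: execute the FX graph directly

    return aot_module_simplified(
        gm,
        example_inputs,
        fw_compiler=run_eagerly,
        bw_compiler=run_eagerly,
        partition_fn=min_cut_rematerialization_partition,
    )


# usage: compiled = torch.compile(model, backend=eager_min_cut_backend)
```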

That approach also helps DeepCompile + torch autocast nicely. When
autocast is enabled, downcast parameters should preferably be
recomputed; it suffices to mark the casting nodes as must-recompute.

[1] https://github.com/pytorch/pytorch/blob/main/torch/_functorch/partitioners.py#L1813
[2] https://github.com/pytorch/pytorch/blob/v2.0.0/torch/_inductor/compile_fx.py#L459
[3] pytorch/pytorch#157580

# Proposal

Motivated by that flexibility and by the requirement of optimizing
DeepCompile + autocast, I propose switching to the min-cut-based
partitioner for both backends. This PR implements that switch, cleans
up dead code, and recomputes downcast parameters in the backward pass.

# Preliminary Evaluation

Here's a summary of the tests using
https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3 on
an 8x RTX 5090 node.

| Configuration       | Base Time (ms) | Base Mem (GB) | Time with this PR (ms) | Mem with this PR (GB) |
|---------------------|----------------|---------------|------------------------|-----------------------|
| eager + autocast    | 551.92         | 12.07         | 571.24                 | 9.96                  |
| eager + bf16        | 419.87         | 9.47          | 445.76                 | 7.30                  |
| inductor + autocast | 546.97         | 12.84         | 570.09                 | 13.04                 |
| inductor + bf16     | 444.03         | 10.01         | 444.70                 | 10.19                 |

## Reduced memory with eager backend

The initial goal of this PR is to reduce peak memory usage when torch
autocast is enabled. The first row of the table shows that this is
achieved, through several effects at once:

1. Downcast parameters produced during the forward pass are thrown away
and recomputed (by the fused cast + allgather) in the backward pass.
2. Without this PR, `fast_free_schedule` arranges most allgathers at
the beginning of the graph. That leads to an even higher peak during
the forward pass, which is no longer seen with this PR.
3. By diffing the graphs passed to `add_z3_gather_release`, I noticed
that the recomputations selected by min-cut are slightly different (the
test script has activation checkpointing enabled for the LLM module).
That can also impact computation time and memory usage.

Here's the shape of memory usage before this PR with the eager backend
+ torch autocast; eager + bf16 shows a similar shape. Numbers reported
in the table are the peak during the forward pass. Peak memory usage
during the backward pass decreases by ~0.7 GB in both cases.

<img width="1482" height="629" alt="image"
src="https://github.com/user-attachments/assets/7e7ec859-9a04-4ddd-ba37-c2d475a81058"
/>

After this PR:

<img width="1482" height="453" alt="image"
src="https://github.com/user-attachments/assets/f15c71b8-f823-4aa5-801a-a36188c5e866"
/>

## Similar memory with inductor backend

Unlike the eager backend, the inductor backend uses a similar amount of
memory with or without this PR. The memory usage pattern is as follows
and requires further analysis.

Before this PR:

<img width="1070" height="613" alt="image"
src="https://github.com/user-attachments/assets/317b9a58-d4ef-459f-ac7b-67ef2318a9de"
/>

After this PR:

<img width="911" height="536" alt="image"
src="https://github.com/user-attachments/assets/7e737a81-cf27-402c-aeea-dfe661043fc1"
/>

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Oct 4, 2025
@github-actions github-actions bot deleted the gh/xuanzhang816/19/head branch October 6, 2025 02:10
Liangliang-Ma pushed a commit to Liangliang-Ma/DeepSpeed that referenced this pull request Oct 13, 2025