Always disable ShardingPropagation cache if compiling by azahed98 · Pull Request #156868 · pytorch/pytorch · GitHub

Conversation

@azahed98 (Contributor) commented Jun 25, 2025

Fixes #151106

Addresses issue (2) in #152963: the DTensor sharding propagation cache is brittle under compile. This PR reuses the existing `_are_we_tracing` from `distributed._functional_collectives`, which determines whether we are currently tracing based primarily on the Fake Tensor dispatch mode.
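
A minimal sketch of the idea (not the exact diff in this PR); `ShardingPropagatorSketch`, `_cached_propagate`, and `_propagate_uncached` are illustrative placeholder names, while `_are_we_tracing` is the real helper named above:

```python
from functools import lru_cache

from torch.distributed._functional_collectives import _are_we_tracing


class ShardingPropagatorSketch:
    """Illustrative stand-in for DTensor's sharding propagator."""

    def __init__(self):
        # lru_cache wrapping the uncached path, built per instance so the
        # cache does not outlive the propagator.
        self._cached_propagate = lru_cache(maxsize=None)(self._propagate_uncached)

    def _propagate_uncached(self, op_schema):
        # Placeholder for the real sharding-rule computation.
        raise NotImplementedError

    def propagate(self, op_schema):
        # When tracing/compiling (detected mainly via the fake-tensor dispatch
        # mode), skip the cache entirely instead of special-casing symint inputs.
        if _are_we_tracing():
            return self._propagate_uncached(op_schema)
        return self._cached_propagate(op_schema)
```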

Test Plan:
There are already tests for DTensor + compile with dynamic shapes (`test_dtensor_dynamic`, `test_dynamo_dtensor_from_local_dynamic_shapes`) that cover the change.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @tianyu-l @XilunWu @xmfan

pytorch-bot bot commented Jun 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156868

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 1 Pending, 1 Unrelated Failure

As of commit c043984 with merge base 900fba4:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor and oncall: distributed labels Jun 25, 2025
@azahed98 azahed98 added the module: dtensor and release notes: distributed (dtensor) labels Jun 25, 2025
@azahed98 azahed98 requested a review from bdhirsh June 25, 2025 19:35
@bdhirsh (Contributor) commented Jun 25, 2025

Nice - a few comments:

(1) I think there are a few other places where we rely on has_symints throughout dtensor code, not just that one call site?

(2) we can also kill the has_symint calculation that DTensor does, now that we no longer need it

(3) finally, let's add some tests. I mentioned that you can use the test_dtensor_compile.py tests for inspiration, but in particular:

(3a) we should try to add a test that would have failed without your PR. let me know if you want to brainstorm on ways to exercise this path

(3b) I would also like to generally beef up the amount of dynamic shape testing that we do for DTensor. It doesn't have to be part of this PR; we can talk more about it later. One option is just to manually add more testing in that file. Another could be to hook into existing DTensor OpInfo tests if there are any; let's chat with DTensor folks about it.

@azahed98 azahed98 force-pushed the feat/dtensor_compile_cache branch 2 times, most recently from 72577cd to 0a43a51 Compare June 26, 2025 18:31
@azahed98 azahed98 force-pushed the feat/dtensor_compile_cache branch from c57e01c to c189f1a Compare July 8, 2025 08:01
@azahed98 (Contributor, Author) commented Jul 8, 2025

After clarifying details, I tested several approaches for handling the cache skip (a rough sketch of approaches 3 and 4 follows this list):

  1. has_symints - The existing approach, which only checks the top level of the args for symints. The check is computed in the post-init of the op info instead of at sharding propagation time.
  2. _are_we_tracing - Replaces the has_symints check with a call to the existing _are_we_tracing function, which inspects the dispatcher mode and assumes fake tensor dispatch means we are compiling, globally disabling the cache if so.
  3. Try/except fallback - Catches the TypeError raised by the LRU cache attempt and falls back to the non-cached version.
  4. Hashable conditional - Checks at sharding propagation time whether the arguments are hashable. The check is done deeply, but still requires a try/except for unrecognized leaf types since Hashable alone is not trustworthy.
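
Illustrative sketches of approaches (3) and (4); `propagate_cached` / `propagate_uncached` and the key layout are hypothetical stand-ins, not the actual DTensor methods:

```python
from collections.abc import Hashable


def propagate_with_fallback(prop, op_schema):
    # (3) Try/except fallback: attempt the lru_cache'd path and fall back to
    # the uncached path when lru_cache cannot hash the schema.
    try:
        return prop.propagate_cached(op_schema)      # hypothetical cached method
    except TypeError:                                # raised for unhashable keys
        return prop.propagate_uncached(op_schema)    # hypothetical uncached method


def is_deeply_hashable(arg) -> bool:
    # (4) Hashable conditional: walk containers; isinstance(x, Hashable) alone
    # is not trustworthy (a tuple of lists passes it but hash() still fails),
    # so a try/except is still needed for unrecognized leaf types.
    if isinstance(arg, (list, tuple)):
        return all(is_deeply_hashable(a) for a in arg)
    if not isinstance(arg, Hashable):
        return False
    try:
        hash(arg)
    except TypeError:
        return False
    return True
```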

Because the change also impacts compiled autograd (and potentially custom backends?), I tested:

  1. Add
  2. Cat (new test that fails with the previous conditional)
  3. Compiled autograd of a simple model (existing test that fails with just changing has_symints to _are_we_tracing)

The following benchmarks are in seconds and do not include operation time, only the cache logic overhead (including checking the cache).

Add

| Approach | Avg Time Eager (s) | Avg Time Compile (s) | Avg Diff (E - C) |
| --- | --- | --- | --- |
| has_symints | 1.51904e-05 | 3.17651e-06 | 1.20139e-05 |
| _are_we_tracing | 2.69765e-05 | 1.05024e-05 | 1.64741e-05 |
| try/except | 1.56521e-05 | 5.89978e-05 | -4.33457e-05 |
| If deep hashable | 2.24493e-05 | 3.96081e-05 | -1.71588e-05 |
| _are_we_tracing + try/except | 2.81515e-05 | 1.06873e-05 | 1.74642e-05 |
| _are_we_tracing + if hashable | 3.58346e-05 | 1.16638e-05 | 2.41709e-05 |

Cat

| Approach | Avg Time Eager (s) | Avg Time Compile (s) | Avg Diff (E - C) |
| --- | --- | --- | --- |
| has_symints | 1.66608e-05 | | |
| _are_we_tracing | 2.12326e-05 | 9.20193e-06 | 1.20307e-05 |
| try/except | 1.29369e-05 | 3.99154e-05 | -2.69785e-05 |
| If deep hashable | 2.01795e-05 | 8.59986e-05 | -6.58191e-05 |
| _are_we_tracing + try/except | 2.08283e-05 | 9.31555e-06 | 1.15127e-05 |
| _are_we_tracing + if hashable | 2.90011e-05 | 1.14933e-05 | 1.75078e-05 |

Compiled Autograd

| Approach | Avg Time Eager (s) | Avg Time Compile (s) | Avg Diff (E - C) |
| --- | --- | --- | --- |
| has_symints | 1.17937e-05 | 2.25911e-05 | -1.07974e-05 |
| _are_we_tracing | | | |
| try/except | 1.74108e-05 | 2.11901e-05 | -3.77927e-06 |
| If deep hashable | 2.44827e-05 | 3.29761e-05 | -8.49338e-06 |
| _are_we_tracing + try/except | 2.75316e-05 | 3.73415e-06 | 2.37974e-05 |
| _are_we_tracing + if hashable | 3.50689e-05 | 4.09760e-06 | 3.09713e-05 |

Since we want to prioritize eager runtime (adding compile time is not a big deal), try/except makes the most sense: it is the only approach that is both time-effective and error-free, and it only adds significant overhead when the except clause is actually entered. The results are noisy, so in general I expect little to no additional overhead from this try/except compared to the existing conditional.

You can find a hacky implementation of the benchmark code here azahed98@e76885f
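
Not the linked benchmark, just a minimal sketch (all names synthetic) of how one could time only the caching/dispatch wrapper rather than the op itself:

```python
import timeit
from functools import lru_cache


def uncached(key):
    return key  # stand-in for the real sharding propagation


cached = lru_cache(maxsize=None)(uncached)


def with_try_except(key):
    # The cache logic under test: hit the cache if possible, otherwise fall back.
    try:
        return cached(key)
    except TypeError:
        return uncached(key)


hashable_key = ("aten.add.Tensor", (8, 8))
unhashable_key = ("aten.cat.default", [8, 8])  # a list defeats lru_cache

print("cache hit:", timeit.timeit(lambda: with_try_except(hashable_key), number=100_000))
print("fallback: ", timeit.timeit(lambda: with_try_except(unhashable_key), number=100_000))
```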

@azahed98 (Contributor, Author) commented Jul 8, 2025

> (1) I think there are a few other places where we rely on has_symints throughout dtensor code, not just that one call site?

@bdhirsh There is one place that uses has_symints (in redistribute), but it recalculates it from the DTensorSpec directly (rather than searching the args for it), so I left it as is. If I'm understanding correctly, that is fine?

@azahed98 (Contributor, Author) commented:

After following up, the implementation was changed to add a global flag to compiled autograd, `torch._dynamo.compiled_autograd.in_compiled_autograd_initial_trace`. This handles the compiled autograd failure with minimal additional overhead (similar to just `_are_we_tracing()` above).
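
A rough sketch of the global-flag approach; only the flag name comes from this PR, everything else here is illustrative:

```python
from torch.distributed._functional_collectives import _are_we_tracing

# Conceptually lives in torch/_dynamo/compiled_autograd.py; the flag was added
# by this PR (and later removed again in #159671).
in_compiled_autograd_initial_trace = False


def run_initial_compiled_autograd_trace(trace_fn):
    # Illustrative wrapper: mark the initial compiled-autograd trace so that
    # downstream code (sharding propagation) knows a cache hit is unsafe.
    global in_compiled_autograd_initial_trace
    in_compiled_autograd_initial_trace = True
    try:
        return trace_fn()
    finally:
        in_compiled_autograd_initial_trace = False


def should_skip_sharding_prop_cache() -> bool:
    # Roughly what the cache check becomes with this change.
    return _are_we_tracing() or in_compiled_autograd_initial_trace
```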

@azahed98 azahed98 marked this pull request as ready for review July 10, 2025 22:07
@azahed98 azahed98 force-pushed the feat/dtensor_compile_cache branch 2 times, most recently from 078b87b to b24d943 Compare July 14, 2025 18:35
@azahed98 azahed98 requested review from wconstab and xmfan July 14, 2025 21:58
An inline review comment (from a Member) on the diff hunk that adds the new flag:

    log_pt2_compile_event=True,
    )
    self.compile_context.__exit__(None, None, None)
    in_compiled_autograd_initial_trace = False

The reviewer wrote:

please also add it to torch._dynamo.compiled_autograd.reset
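
Roughly what is being asked for, assuming the flag is a plain module-level boolean (a sketch, not the actual diff):

```python
def reset() -> None:
    global in_compiled_autograd_initial_trace
    ...  # existing reset logic in torch/_dynamo/compiled_autograd.py
    in_compiled_autograd_initial_trace = False
```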

@xmfan (Member) left a review:

Please address the pending comments

@azahed98 azahed98 force-pushed the feat/dtensor_compile_cache branch from 7e5418b to c83693a Compare July 16, 2025 16:55
@azahed98 (Contributor, Author) commented:

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Jul 16, 2025
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@azahed98 (Contributor, Author) commented Aug 1, 2025

@pytorchbot revert -m="This is breaking tests in torchtitan https://github.com/pytorch/torchtitan/actions/runs/16656486277/job/47142656711"

pytorch-bot bot commented Aug 1, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

@azahed98 (Contributor, Author) commented Aug 1, 2025

@pytorchbot revert -m "This is breaking tests in torchtitan https://github.com/pytorch/torchtitan/actions/runs/16656486277/job/47142656711" -c nosignal

@pytorchmergebot (Collaborator) commented:

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator) commented:

Reverting PR 156868 failed

Reason: Command git -C /home/runner/work/pytorch/pytorch revert --no-edit f6d138807f138868de0397936e2bee482c1fb987 returned non-zero exit code 1

    Auto-merging torch/distributed/_functional_collectives.py
    Auto-merging torch/distributed/tensor/_op_schema.py
    Auto-merging torch/distributed/tensor/_sharding_prop.py
    CONFLICT (content): Merge conflict in torch/distributed/tensor/_sharding_prop.py
    error: could not revert f6d138807f1... Always disable ShardingPropagation cache if compiling (#156868)
    hint: After resolving the conflicts, mark them with
    hint: "git add/rm <pathspec>", then run
    hint: "git revert --continue".
    hint: You can instead skip this commit with "git revert --skip".
    hint: To abort and get back to the state before "git revert",
    hint: run "git revert --abort".
    hint: Disable this message with "git config set advice.mergeConflict false"

azahed98 added a commit to azahed98/pytorch that referenced this pull request Aug 1, 2025
azahed98 added a commit to azahed98/pytorch that referenced this pull request Aug 1, 2025
pytorchmergebot pushed a commit that referenced this pull request Aug 4, 2025
…on cache (#159671)

Fixes #159601

Unfortunately #156868 introduced a couple of regressions (see #159590 and #159601). This reverts the commit while I am working on a permanent fix. This means the `in_compiled_autograd_initial_trace` global flag will be removed, and `_are_we_tracing()` will be replaced with the symint preprocessing step during sharding prop post-init.

Pull Request resolved: #159671
Approved by: https://github.com/xmfan
pytorchmergebot pushed a commit that referenced this pull request Sep 9, 2025
…ble keys (#160798)

Fixes #159590

This is similar to the reverted commit #156868, except it resolves an issue with two caches becoming misaligned, leading to incorrect objects for stateful placements (i.e. `_MaskPartial`) as in issue #159601. This adds little to no overhead in eager ([see past benchmarks](#156868 (comment))).

This also handles cases such as #159590 where dynamo is disabled during tracing, by entering the Python Dispatcher ahead of the sharding propagation during compile. Tests are added/modified to cover these cases, as well as list/tuple inputs with the cat op.

Pull Request resolved: #160798
Approved by: https://github.com/bdhirsh
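
A hedged sketch of the "hashable keys" idea described in #160798; the helper names and key layout here are assumptions, not the actual implementation. The point is to preprocess op arguments so list inputs (e.g. the tensor list passed to `aten.cat`) become hashable tuples before they are used as a cache key:

```python
def make_hashable(arg):
    # Recursively convert lists/tuples into tuples so the resulting key can be
    # hashed by the sharding propagation cache.
    if isinstance(arg, (list, tuple)):
        return tuple(make_hashable(a) for a in arg)
    return arg


def sharding_prop_cache_key(op_name, args, kwargs):
    # Hypothetical key layout: op name plus hashable-ified args/kwargs.
    return (
        op_name,
        make_hashable(args),
        tuple(sorted((k, make_hashable(v)) for k, v in kwargs.items())),
    )
```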
markc-614 pushed commits to markc-614/pytorch that referenced this pull request Sep 17, 2025 (the same #159671 and #160798 changes quoted above).
mansiag05 and cleonard530 pushed the same #160798 commit to their forks Sep 22, 2025, as did dsashidh Sep 26, 2025.
Successfully merging this pull request may close these issues:

distributed/tensor/_op_schema has_symints does not check args_schema