Always disable ShardingPropagation cache if compiling #156868
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156868
Note: links to docs will display an error until the docs builds have been completed.
⏳ As of commit c043984 with merge base 900fba4: 1 pending, 1 unrelated failure (UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Nice - a few comments:
(1) I think there are a few other places where we rely on
(2) we can also kill the
(3) finally, let's add some tests. I mentioned that you can use the
(3a) we should try to add a test that would have failed without your PR. Let me know if you want to brainstorm on ways to exercise this path.
(3b) I would also like to generally beef up the amount of dynamic shape testing that we do for DTensor. This doesn't have to be part of this PR; we can talk more about it later. One option is just to manually add more testing in that file. Another could be to hook into existing DTensor OpInfo tests; if there are any, let's chat with the DTensor folks about it.
Force-pushed 72577cd to 0a43a51
Force-pushed c57e01c to c189f1a
After clarifying details, I tested several approaches for handling the cache skip. Because the change also impacts compiled autograd (and potentially custom backends?), I benchmarked each approach on the following workloads:

- Add
- Cat
- Compiled Autograd

Since we would like to prioritize eager runtime (adding compile time is not a big deal), try/except makes the most sense: it is the only approach that is both time-effective and error-free, adding significant overhead only when the except clause is entered. The results are noisy, so in general I expect little to no additional overhead from this try/except compared to the existing conditional check. You can find a hacky implementation of the benchmark code here: azahed98@e76885f
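For illustration only (not the code in this PR), here is a minimal sketch of the try/except cache-lookup strategy described above; `propagate_cached` and `expensive_propagation` are placeholder names:

```python
from functools import lru_cache

def expensive_propagation(schema_key):
    # Placeholder for the real sharding-propagation work.
    return f"sharding for {schema_key!r}"

@lru_cache(maxsize=None)
def propagate_cached(schema_key):
    return expensive_propagation(schema_key)

def propagate(schema_key):
    try:
        # Fast path: a plain cached call, with no extra per-call tracing check.
        return propagate_cached(schema_key)
    except TypeError:
        # Unhashable key (e.g. a shape containing SymInts): fall back to the
        # uncached path. The exception-handling cost is paid only on this path.
        return expensive_propagation(schema_key)
```

Compared to an explicit up-front conditional, the cache-hit path here is just the cached call itself, which is why this option best matches the eager-first priority.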
@bdhirsh There is one place that uses `has_symints` (in redistribute), but it recalculates it from the DTensorSpec directly (rather than searching for it), so I left it as is. If I'm understanding correctly, this is fine?
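For context, a minimal sketch of what recomputing this directly from a spec's shape could look like (illustrative only; `spec_shape` stands in for the DTensorSpec's tensor metadata, and this is not the actual redistribute code):

```python
import torch

def _shape_has_symints(spec_shape) -> bool:
    # True if any dimension is symbolic (a torch.SymInt) rather than a plain int.
    return any(isinstance(s, torch.SymInt) for s in spec_shape)
```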
After following up, the implementation was changed to add a global flag to compiled autograd.
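A simplified sketch of that global-flag approach (an illustration, not the actual `torch._dynamo.compiled_autograd` source): the flag is raised for the duration of the initial compiled-autograd trace, cleared on exit, and also cleared in a `reset()` hook (see the review comment further down):

```python
in_compiled_autograd_initial_trace = False

class _InitialTraceGuard:
    """Marks the region where compiled autograd performs its initial trace."""

    def __enter__(self):
        global in_compiled_autograd_initial_trace
        in_compiled_autograd_initial_trace = True
        return self

    def __exit__(self, exc_type, exc, tb):
        global in_compiled_autograd_initial_trace
        in_compiled_autograd_initial_trace = False


def reset():
    # Clear the flag here too so stale state cannot leak across sessions.
    global in_compiled_autograd_initial_trace
    in_compiled_autograd_initial_trace = False
```

Consumers (e.g. the sharding-propagation cache) would then treat an active flag the same as being under tracing and skip the cache.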
Force-pushed 078b87b to b24d943
log_pt2_compile_event=True,
)
self.compile_context.__exit__(None, None, None)
in_compiled_autograd_initial_trace = False
Please also add it to `torch._dynamo.compiled_autograd.reset`.
Please address the pending comments
Force-pushed 7e5418b to c83693a
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot revert -m="This is breaking tests in torchtitan https://github.com/pytorch/torchtitan/actions/runs/16656486277/job/47142656711"
❌ 🤖 pytorchbot command failed: Try |
@pytorchbot revert -m "This is breaking tests in torchtitan https://github.com/pytorch/torchtitan/actions/runs/16656486277/job/47142656711" -c nosignal
@pytorchbot successfully started a revert job. Check the current status here.
Reverting PR 156868 failed. Reason: Command failed (details for the Dev Infra team: raised by workflow job).
…on cache (#159671)
Fixes #159601
Unfortunately #156868 introduced a couple of regressions (see #159590 and #159601). This reverts the commit while I am working on a permanent fix. This means the `in_compiled_autograd_initial_trace` global flag will be removed, and `_are_we_tracing()` will instead be replaced with a symint preprocessing step during sharding-prop post-init.
Pull Request resolved: #159671
Approved by: https://github.com/xmfan
…ble keys (#160798)
Fixes #159590
This is similar to the reverted commit #156868, except it resolves an issue with two caches becoming misaligned, leading to incorrect objects for stateful placements (e.g. `_MaskPartial`) as in issue #159601. This adds little to no overhead in eager ([see past benchmarks](#156868 (comment))). It also handles cases such as #159590, where dynamo is disabled during tracing, by entering the Python Dispatcher ahead of the sharding propagation during compile. Tests are added/modified to cover these cases, as well as list/tuple inputs with the cat op.
Pull Request resolved: #160798
Approved by: https://github.com/bdhirsh
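As a rough illustration of the approach described in that commit message (placeholder names; not the actual `ShardingPropagator` implementation): skip the cache when the op schema carries symbolic shapes, and enter the Python Dispatcher before propagating while compiling, so the intended behavior is picked up even when dynamo is disabled:

```python
from torch._dispatch.python import enable_python_dispatcher

def _propagate_uncached(op_schema):
    ...  # full sharding-propagation work, bypassing the cache

def _propagate_cached(op_schema):
    ...  # cached fast path for static, hashable schemas

def propagate(op_schema, has_symints: bool, compiling: bool):
    if compiling:
        # Dynamo may be disabled at this point (#159590); entering the Python
        # Dispatcher keeps compile-time dispatch behavior consistent.
        with enable_python_dispatcher():
            return _propagate_uncached(op_schema)
    if has_symints:
        # Symbolic shapes are unhashable and can misalign the cache (#159601).
        return _propagate_uncached(op_schema)
    return _propagate_cached(op_schema)
```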
Fixes #151106
Addresses issue (2) in #152963: the DTensor sharding propagation cache being brittle under compile. The existing `_are_we_tracing` from `distributed._functional_collectives`, which mostly determines whether we are currently tracing based on Fake Tensor dispatch mode, is reused here.

Test Plan:
There are already tests for DTensor + compile with dynamic shapes (`test_dtensor_dynamic`, `test_dynamo_dtensor_from_local_dynamic_shapes`) that cover the change.
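For context, a rough sketch of what a Fake-Tensor-mode-based tracing check can look like (illustrative only; the real `_are_we_tracing` in `torch.distributed._functional_collectives` handles more cases):

```python
import torch

def _are_we_tracing_sketch() -> bool:
    # Dynamo reports that it is actively compiling the current frame.
    if torch.compiler.is_compiling():
        return True
    # During AOT tracing the inputs are FakeTensors, so a FakeTensorMode sits
    # on the dispatch-mode stack even when dynamo itself is disabled.
    fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
    return fake_mode is not None
```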
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @tianyu-l @XilunWu @xmfan