-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[cudagraphs] Fix issue in collecting static_input_idxs #152287
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152287
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 8ccbcc7 with merge base 8e2e06b ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for looking into this
related to #152275 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sgtm
): | ||
if dynamic: | ||
self.assertEqual(static_input_idxs, [0, 1, 2, 3, 4]) | ||
self.assertEqual(static_input_idxs, [2, 3, 4]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test looks strictly more correct now than previously. For sanity, this is the signature of the AOT graph in this test:
def forward(self, arg0_1: "Sym(s25)", arg1_1: "f32[s25][1]cpu", arg2_1: "f32[s25][1]cpu", arg3_1: "f32[s25][1]cpu", arg4_1: "Sym(s25)", arg5_1: "f32[s25][1]cpu"):
Where indices [2,3] correspond to the two static tensor inputs that mapped to the static TwoTensor subclass.
One thing that is wrong in this test though is that:
(1) in the dynamic shapes variant of this test, we have extra SymInt graph args that correspond to the symbolic sizes of the subclass
(2) we are marking those inputs as static indices as well, which is happening here: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/subclass_utils.py#L308
This seems wrong. It might turn out not cause too many problems, if inductor has logic to properly filter out SymInts from the "static input indices" list later (given that integers have no memory address and get burned into cudagraphs anyway). But we should probably fix it either way. cc @mlazos
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes sense, I can take a look at the issue.
torch/_functorch/aot_autograd.py
Outdated
"_dynamo_static_input_type", None | ||
): | ||
static_inputs_log.debug( | ||
"Adding static input pos %s for source %s", pos, source_name |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we update this log call as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point
related to #152275 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
related to #152275 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…)" This reverts commit 75a5646. Reverted #152287 on behalf of https://github.com/wdvr due to causing ao failures - discussed with author ([comment](#152287 (comment)))
@anijain2305 thank you for the quick bugfix! I've applied the changes in |
related to #152275 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
related to #152275 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
# add on non param inputs | ||
preserved_arg_indices.extend(range(len(flat_params), len(params))) | ||
# is this necessary ? | ||
fw_metadata.static_input_indices = static_indices_new |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these are the new changes to update the static_input_indices list under freezing. As @eellison pointed out, it might be a good idea to stash this info on the graph placeholders directly in the future, so we don't need to worry about updating this list after ever calling convention change in inductor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks these should be easy to merge in 2.7
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
@pytorchbot merge -f "stuck merge" |
The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
@pytorchbot cherry-pick --onto release/2.7 -c critical |
related to #152275 Pull Request resolved: #152287 Approved by: https://github.com/bdhirsh, https://github.com/eellison Co-authored-by: Brian Hirsh <hirsheybar@fb.com> (cherry picked from commit 4a63cab)
Cherry picking #152287The cherry pick PR is at #152768 and it is recommended to link a critical cherry pick PR with an issue. The following tracker issues are updated: Details for Dev Infra teamRaised by workflow job |
[cudagraphs] Fix issue in collecting static_input_idxs (#152287) related to #152275 Pull Request resolved: #152287 Approved by: https://github.com/bdhirsh, https://github.com/eellison (cherry picked from commit 4a63cab) Co-authored-by: Brian Hirsh <hirsheybar@fb.com>
Stack from ghstack (oldest at bottom):
related to #152275
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov