For addmm and bmm, check if config.autotune_fallback_to_aten before using aten as a fallback. Also fix bmm cutlass backend #147148
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147148
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure) As of commit a0dd187 with merge base be0df96. FLAKY: the following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/_inductor/kernel/bmm.py
Outdated
    len(choices) == 0
    and not use_aten_gemm_kernels()
    and inductor_config.autotune_fallback_to_aten
):
Super nit: since we use the same condition multiple times, can we make a utility function for it? E.g., we could add one in mm_common.py.
I agree, will add
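For reference, a minimal sketch of what such a helper in mm_common.py could look like, based on the condition from the diff above (import paths and the warning message are assumptions, not the exact PR code):

```python
import logging

from torch._inductor import config as inductor_config
from torch._inductor.utils import use_aten_gemm_kernels

log = logging.getLogger(__name__)


def should_fallback_to_aten(choices) -> bool:
    # Fall back only when autotuning produced no candidates, ATen was
    # not already among the requested backends, and the config opts in.
    if (
        len(choices) == 0
        and not use_aten_gemm_kernels()
        and inductor_config.autotune_fallback_to_aten
    ):
        log.warning("No GEMM choices left; falling back to ATen")
        return True
    return False
```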
if should_fallback_to_aten(choices):
    choices = [aten__int_mm.bind((mat1, mat2), layout)]

try:
@chenyang78 I think this try/catch is not needed anymore. I plan on removing it in a separate PR, since those changes are higher risk imo.
torch/_inductor/kernel/mm_common.py
Outdated
log = logging.getLogger(__name__)


def should_fallback_to_aten(choices) -> bool:
What's the type of the input here? At least `Sized` or `Sequence`, right?
-    return X.get_size()[1] == W.get_size()[0]
+    X_size, W_size = X.get_size(), W.get_size()
+    if len(X_size) != len(W_size):
+        log.info("X and W have different ranks")
| log.info("X and W have different ranks") | |
| log.info("X and W have different ranks: %d, %d", len(X_size), len(W_size)) |
This would be effectively free if you cached the lengths anyway?
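For context, %-style logging arguments are only interpolated when the record is actually emitted, so passing the cached lengths costs essentially nothing when INFO is disabled. A standalone illustration (not PR code):

```python
import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger(__name__)

X_size, W_size = [16, 32], [8, 32, 64]

# The %d arguments are only formatted if INFO is enabled for this logger;
# here INFO is disabled, so no string formatting happens at all.
log.info("X and W have different ranks: %d, %d", len(X_size), len(W_size))

# An f-string, by contrast, is always evaluated before the call,
# even when the log record is ultimately discarded.
log.info(f"X and W have different ranks: {len(X_size)}, {len(W_size)}")
```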
fallback_to_aten: bool = (
    len(choices) == 0
    and not use_aten_gemm_kernels()
    and inductor_config.autotune_fallback_to_aten
Do we need `autotune_fallback_to_aten` when `max_autotune_gemm_backends` exists? Isn't this duplicative?
> Do we need `autotune_fallback_to_aten` when `max_autotune_gemm_backends` exists? Isn't this duplicative?
I think the intent is to be super safe. Even if users specify only "CUTLASS" and CUTLASS fails, it will still keep things running. I think this safety logic predates the `autotune_fallback_to_aten` config.
But yeah, it has been pretty painful for me when working on CUTLASS, since a bunch of stuff fails silently.
If the user wants to fall back, they should set `max_autotune_gemm_backends="CUTLASS,ATEN"`. I don't think we should have two ways of doing the exact same thing.
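For reference, opting into an explicit ATen fallback is a one-line config change (a sketch; the backend-list spelling follows torch/_inductor/config.py):

```python
import torch
import torch._inductor.config as inductor_config

# Try CUTLASS kernels during autotuning and keep ATen as an explicit
# fallback, instead of relying on the implicit autotune_fallback_to_aten.
inductor_config.max_autotune_gemm_backends = "CUTLASS,ATEN"

compiled_mm = torch.compile(torch.mm, mode="max-autotune")
```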
@eellison I agree with the idea. I can commit to removing `autotune_fallback_to_aten` altogether and removing the silent fallback logic. But it will take a few PRs and a while. Does that sound good?
if len(X_size) == 2:
    return X_size[1] == W_size[0]
Why do we need this check? This should already be a precondition of lowering.
> Why do we need this check? This should already be a precondition of lowering.
The `_shape_match` in CUTLASS 2x is a bit different for sparse, so I guess that was the intention.
Let me know what you think. I am fine with removing it here for CUTLASS 3x.
EDIT: changed it to always return True.
@eellison updated
    return X_size[1] == W_size[0]
if len(X_size) == 3:
    # for bmm
    return X_size[2] == W_size[1]
Same here.
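For context, the rank-3 branch mirrors the `torch.bmm` shape contract, where the inner dimensions must match:

```python
import torch

A = torch.randn(8, 16, 32)  # (B, M, K)
B = torch.randn(8, 32, 64)  # (B, K, N)
C = torch.bmm(A, B)         # (B, M, N); valid because A.size(2) == B.size(1),
                            # i.e. X_size[2] == W_size[1] in the check above
```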
torch/_inductor/kernel/mm.py
Outdated
# The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking.
# See more details in https://github.com/pytorch/pytorch/pull/146293
else r"""
if torch.version.hip is None
Unintended format changes?
> Unintended format changes?

Yeah, let me remove that.
Please add a test for the newly passing bmm case.
There is an existing test.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: New commits were pushed while merging. Please rerun the merge command. (Details for Dev Infra team: raised by workflow job.)
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 1 check: trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-13). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…sing aten as a fallback. Also fix bmm cutlass backend (pytorch#147148)
This PR also fixes BMM, which was silently failing for a while.
Pull Request resolved: pytorch#147148
Approved by: https://github.com/eellison
Stack from ghstack (oldest at bottom):
This PR also fixes BMM, which was silently failing for a while.
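A minimal way to exercise the fixed bmm path (a sketch, assuming a CUDA build with CUTLASS support; shapes, dtypes, and tolerances are illustrative):

```python
import torch
import torch._inductor.config as inductor_config

# Restrict autotuning to CUTLASS so problems in that backend surface
# instead of being hidden by a silent ATen fallback.
inductor_config.max_autotune_gemm_backends = "CUTLASS"

a = torch.randn(4, 128, 64, device="cuda", dtype=torch.float16)
b = torch.randn(4, 64, 128, device="cuda", dtype=torch.float16)

compiled_bmm = torch.compile(torch.bmm, mode="max-autotune")
torch.testing.assert_close(compiled_bmm(a, b), torch.bmm(a, b), rtol=1e-2, atol=1e-2)
```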
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov