For addmm and bmm, check if config.autotune_fallback_to_aten before using aten as a fallback. Also fix bmm cutlass backend by henrylhtsang · Pull Request #147148 · pytorch/pytorch · GitHub

Conversation

@henrylhtsang (Contributor) commented Feb 13, 2025

@pytorch-bot (bot) commented Feb 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147148

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit a0dd187 with merge base be0df96:

FLAKY - The following job failed but was likely due to flakiness present on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

henrylhtsang added a commit that referenced this pull request Feb 13, 2025
@henrylhtsang added the "topic: not user facing" label Feb 13, 2025
@henrylhtsang changed the title from "check if config.autotune_fallback_to_aten before using aten as a fallback" to "For addmm and bmm, check if config.autotune_fallback_to_aten before using aten as a fallback. Also fix bmm cutlass backend" Feb 13, 2025
@henrylhtsang marked this pull request as draft February 14, 2025 00:06
@henrylhtsang marked this pull request as ready for review February 14, 2025 00:14
…en before using aten as a fallback. Also fix bmm cutlass backend "


This PR also fixes BMM, which was silently failing for a while.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
if (
    len(choices) == 0
    and not use_aten_gemm_kernels()
    and inductor_config.autotune_fallback_to_aten
):
Contributor:

super nit - since we use the same condition multiple times, can we make a utility function for it? e.g. we could add one in mm_common.py

Contributor Author:

I agree, will add
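
For reference, a minimal sketch of what such a helper could look like, assuming it lands in mm_common.py next to the other GEMM utilities; the condition is copied from the diff above, while the import paths are assumptions:

```python
# Sketch only: the condition comes from the diff above, but the exact
# location and import paths of the helper are assumptions.
from torch._inductor import config as inductor_config
from torch._inductor.utils import use_aten_gemm_kernels


def should_fallback_to_aten(choices) -> bool:
    # Fall back to aten only when autotuning produced no candidates,
    # aten kernels were not already among the choices, and the config
    # explicitly allows the fallback.
    return (
        len(choices) == 0
        and not use_aten_gemm_kernels()
        and inductor_config.autotune_fallback_to_aten
    )
```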

henrylhtsang added a commit that referenced this pull request Feb 14, 2025
if should_fallback_to_aten(choices):
    choices = [aten__int_mm.bind((mat1, mat2), layout)]

try:
Contributor Author:

@chenyang78 I think this try/except is not needed anymore. I plan on removing them in a separate PR, since those changes are higher risk imo.
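
A generic illustration (not inductor code) of why a broad try/except around kernel generation is risky, which is the silent-failure behavior this PR's description refers to:

```python
# Hypothetical, self-contained example of the hazard: a broad except
# turns a real codegen failure into an invisible fallback.
def build_choices():
    raise RuntimeError("cutlass codegen failed")

try:
    choices = build_choices()
except Exception:
    choices = []  # silently falls back; the cutlass failure never surfaces
```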

import logging

log = logging.getLogger(__name__)


def should_fallback_to_aten(choices) -> bool:
@Skylion007 (Collaborator) commented Feb 17, 2025:

What's the type of the input here? At least Sized or Sequence right?
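
A hedged sketch of the tighter annotation the question points at: only len() is called on choices, so collections.abc.Sized is the minimal structural bound; a narrower element type is an assumption about the call sites:

```python
from collections.abc import Sized

from torch._inductor import config as inductor_config
from torch._inductor.utils import use_aten_gemm_kernels


# Sized is the minimal bound since only len() is used; a narrower type
# such as Sequence[ChoiceCaller] is an assumption about the call sites.
def should_fallback_to_aten(choices: Sized) -> bool:
    return (
        len(choices) == 0
        and not use_aten_gemm_kernels()
        and inductor_config.autotune_fallback_to_aten
    )
```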

-    return X.get_size()[1] == W.get_size()[0]
+    X_size, W_size = X.get_size(), W.get_size()
+    if len(X_size) != len(W_size):
+        log.info("X and W have different ranks")
Collaborator:

Suggested change:
-    log.info("X and W have different ranks")
+    log.info("X and W have different ranks: %d, %d", len(X_size), len(W_size))

This would be effectively free if you cached the lens anyway?
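
A small self-contained example of why this is effectively free: with %-style logging, the arguments are only formatted when the record is actually emitted, so passing the cached lengths adds no cost when INFO is filtered out:

```python
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

# Illustrative sizes (hypothetical); the point is the lazy formatting.
X_size, W_size = [128, 64], [64, 32, 8]
# The %d arguments are formatted lazily, only if this record is emitted.
log.info("X and W have different ranks: %d, %d", len(X_size), len(W_size))
```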

fallback_to_aten: bool = (
    len(choices) == 0
    and not use_aten_gemm_kernels()
    and inductor_config.autotune_fallback_to_aten
)
Contributor:

Do we need autotune_fallback_to_aten when max_autotune_gemm_backends exists? Isn't this duplicative?

Contributor Author:

> Do we need autotune_fallback_to_aten when max_autotune_gemm_backends exists? Isn't this duplicative?

I think the intent is to be super safe. Even if users specify only "CUTLASS" and CUTLASS fails, it will still keep things running. I think this safety logic predates the autotune_fallback_to_aten config.

But yeah, it has been pretty painful for me when working on cutlass, since a bunch of stuff fails silently.

Contributor:

If the user wants to fall back, they should do max_autotune_gemm_backends="CUTLASS, ATEN". I don't think we should have two ways of doing the same exact thing.

Contributor Author:

@eellison I agree with the idea. I can commit to removing autotune_fallback_to_aten altogether and removing the silent fallback logic. But it will take a few PRs and a while. Does that sound good?
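
For reference, a hedged sketch of the explicit alternative described above, using the existing inductor config knobs; the backends string follows the comment, and the rest is illustrative:

```python
import torch
import torch._inductor.config as inductor_config

# Ask max-autotune to consider CUTLASS and, explicitly, ATEN as a
# fallback backend, instead of relying on autotune_fallback_to_aten.
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CUTLASS,ATEN"

compiled_mm = torch.compile(torch.mm)
```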

Comment on lines 1168 to 1169
if len(X_size) == 2:
    return X_size[1] == W_size[0]
Contributor:

Why do we need this check? This should already be a precondition of lowering.

@henrylhtsang (Contributor Author) commented Feb 19, 2025:

> Why do we need this check? This should already be a precondition of lowering.

The _shape_match in CUTLASS2x is a bit different for sparse. So I guess that was the intention.

Let me know what you think. I am fine with removing it here for cutlass 3x.

EDIT: changed it to always return True

Contributor Author:

@eellison updated

    return X_size[1] == W_size[0]
if len(X_size) == 3:
    # for bmm
    return X_size[2] == W_size[1]
Contributor:

same here
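
Pulling the fragments from these two threads together, a hedged reconstruction of the rank-aware K-dimension check under discussion; the function name follows the _shape_match mentioned above, and per the EDIT earlier in the thread the merged CUTLASS 3x version reportedly just returns True:

```python
import logging

log = logging.getLogger(__name__)


# Hedged reconstruction from the diff fragments in this conversation;
# the merged CUTLASS 3x code reportedly replaces all of this with
# `return True`.
def _shape_match(X, W) -> bool:
    X_size, W_size = X.get_size(), W.get_size()
    if len(X_size) != len(W_size):
        log.info("X and W have different ranks: %d, %d", len(X_size), len(W_size))
        return False
    if len(X_size) == 2:
        # mm: inner (K) dimensions must agree
        return X_size[1] == W_size[0]
    if len(X_size) == 3:
        # bmm: inner (K) dimensions must agree per batch
        return X_size[2] == W_size[1]
    # Ranks other than 2 or 3 are not handled here (assumption).
    return False
```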

# The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking.
# See more details in https://github.com/pytorch/pytorch/pull/146293
else r"""
if torch.version.hip is None
Contributor:

Unintended format changes?

Contributor Author:

> Unintended format changes?

yeah let me remove that.

henrylhtsang added a commit that referenced this pull request Feb 19, 2025
@eellison self-requested a review February 20, 2025 18:05
@eellison (Contributor) left a comment:

please add a test for the newly passing bmm case

@henrylhtsang (Contributor Author):

> please add a test for the newly passing bmm case

there is an existing test, test_max_autotune_cutlass_backend_simple_bmm
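
The real test lives in PyTorch's inductor test suite; a hedged, self-contained sketch of what a bmm test along those lines looks like (names, shapes, and tolerances here are illustrative assumptions, not the actual test body):

```python
import torch
import torch._inductor.config as inductor_config


# Illustrative sketch only: the actual test has extra setup such as
# CUDA/CUTLASS availability checks and cache clearing, and requires a GPU.
def test_simple_bmm_sketch():
    with inductor_config.patch(
        max_autotune=True,
        max_autotune_gemm_backends="CUTLASS",
    ):
        a = torch.randn(8, 64, 32, device="cuda", dtype=torch.float16)
        b = torch.randn(8, 32, 64, device="cuda", dtype=torch.float16)
        compiled_bmm = torch.compile(torch.bmm)
        torch.testing.assert_close(
            compiled_bmm(a, b), torch.bmm(a, b), rtol=1e-2, atol=1e-2
        )
```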

@henrylhtsang (Contributor Author):

@pytorchbot merge

@pytorch-bot (bot) added the "ciflow/trunk" label (Trigger trunk jobs on your pull request) Feb 20, 2025
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

henrylhtsang added a commit that referenced this pull request Feb 20, 2025
@pytorchmergebot (Collaborator):

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team: raised by a workflow job.

@henrylhtsang (Contributor Author):

@pytorchbot merge -i

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged while ignoring the following check: trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-13)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

This PR (#147148) was merged in 76ce194 but it is still open, likely due to a GitHub bug, so mergebot is closing it manually. If you think this is a mistake, please feel free to reopen and contact Dev Infra.

pull bot pushed a commit to A-Archives-and-Forks/pytorch that referenced this pull request Feb 21, 2025
…sing aten as a fallback. Also fix bmm cutlass backend (pytorch#147148)

This PR also fixes BMM, which was silently failing for a while.

Pull Request resolved: pytorch#147148
Approved by: https://github.com/eellison
@github-actions (bot) deleted the gh/henrylhtsang/9/head branch March 27, 2025 02:12