[ROCm] Bug fix for flex attention configs avoiding ROCm path by jataylo · Pull Request #140270 · pytorch/pytorch · GitHub

Conversation

@jataylo
Collaborator

@jataylo jataylo commented Nov 11, 2024

@jataylo jataylo added ciflow/rocm Trigger "default" config CI on ROCm ciflow/inductor-rocm Trigger "inductor" config CI on ROCm labels Nov 11, 2024
@jataylo jataylo requested a review from bertmaher November 11, 2024 15:07
@pytorch-bot

pytorch-bot bot commented Nov 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140270

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 1b16483 with merge base 62eea62:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jataylo jataylo added topic: not user facing topic category and removed open source labels Nov 11, 2024
@jataylo jataylo requested a review from drisspg November 11, 2024 20:36
@jataylo jataylo added the rocm priority high priority ROCm PRs from performance or other aspects label Nov 11, 2024
@jataylo
Collaborator Author

jataylo commented Nov 11, 2024

Adding priority label to enable testing for triton 3.2/pytorch 2.6

default_config = None

if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0): # H100
if head_dim <= 256 and torch.version.hip:
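
For readers skimming this hunk: torch.version.hip is non-None only on ROCm builds, while torch.cuda.get_device_capability() returns a (major, minor) tuple that, on ROCm, reflects the GFX IP version, so MI200/MI300-class GPUs can also satisfy a bare >= (9, 0) check written for H100. A minimal sketch of guarding the ROCm path before the capability test (tile sizes are placeholders, not the values in this PR):

```python
import torch

def _default_flex_config(head_dim: int):
    # Sketch only: check for a HIP build before the compute-capability test so
    # ROCm never falls into the H100-tuned branch. Tile sizes are placeholders.
    if head_dim <= 256 and torch.version.hip:
        return (64, 64, 4, 1)    # (BLOCK_M, BLOCK_N, num_warps, num_stages) on ROCm
    if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
        return (128, 64, 8, 3)
    return (64, 64, 4, 2)
```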
Contributor

@drisspg drisspg Nov 11, 2024

I think it might be cleaner to instead create a _get_config_rocm and we branch at call site, even if there is overlap

TBH this flow is already hard to grok

Collaborator Author

We still need the conditionalisation on head_dim at least

Collaborator Author

@jataylo jataylo Nov 11, 2024

Do you have any specific recommendations @drisspg? I would like to get this bugfix in ASAP for Triton 3.2 debugging, but the config selection in this file definitely does need refactoring.

Contributor

I think we should copy and paste the entire function and create two versions, one for ROCm and one for CUDA. It feels to me that although you need branching on head_dims, the branching should likely be different given the reduced smem constraints and increased warp size on AMD GPUs.

I think let's just make two funcs and copy and paste as much as you need, until the dust settles some more.
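
Roughly what that split could look like, branching once at the call site. This is a sketch under the assumption of separate ROCm/CUDA helpers; the function names and tuning values are illustrative, not what ultimately landed in the PR:

```python
import torch

def _get_default_config_cuda(head_dim: int):
    # NVIDIA tuning; keeps the existing H100 special case.
    if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
        return (128, 64, 8, 3)
    return (64, 64, 4, 2)

def _get_default_config_rocm(head_dim: int):
    # AMD tuning can diverge freely here: 64-wide wavefronts and less shared
    # memory generally favor smaller tiles and fewer pipeline stages.
    if head_dim <= 256:
        return (64, 64, 4, 1)
    return (32, 32, 4, 1)

def _get_default_config(head_dim: int):
    # Single branch at the call site instead of HIP checks sprinkled inside one function.
    if torch.version.hip:
        return _get_default_config_rocm(head_dim)
    return _get_default_config_cuda(head_dim)
```

Duplicating the body keeps each vendor's tuning readable at the cost of some copy-paste, which is the trade-off being accepted here until the dust settles.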

Contributor

@bertmaher bertmaher left a comment

Looks like this has addressed @drisspg's concerns so I'm good to merge

@bertmaher
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 12, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@jataylo
Collaborator Author

jataylo commented Nov 12, 2024

There are associated failures, likely due to the new structure on the NV side... let me sanity check, and when we get a green signal I'll merge. cc: @bertmaher

@jataylo
Collaborator Author

jataylo commented Nov 14, 2024

Rebasing; the failure seems unrelated.

@jataylo
Collaborator Author

jataylo commented Nov 14, 2024

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased flex-fix onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout flex-fix && git pull --rebase)

@jataylo
Collaborator Author

jataylo commented Nov 14, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see the pytorch-bot wiki.

@jataylo
Collaborator Author

jataylo commented Nov 15, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…#140270)

Fixes pytorch#139755 pytorch#139621

Follow-up fix to pytorch#139883, which made the bulk of the changes required, but a logic error resulted in ROCm still using H100 configurations.

Pull Request resolved: pytorch#140270
Approved by: https://github.com/bertmaher
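
As a quick sanity check of why a capability test alone can misfire on ROCm (a hypothetical snippet for illustration, not part of the patch): on a ROCm build torch.version.hip is a version string, and gfx9 GPUs can report a device capability of (9, x), so the H100-style comparison also evaluates true there.

```python
import torch

# Hypothetical check, not part of the PR: prints the values that made the
# original H100 guard match on AMD hardware.
if torch.cuda.is_available():
    print("HIP build:        ", torch.version.hip is not None)
    print("device capability:", torch.cuda.get_device_capability())
    print(">= (9, 0)?        ", torch.cuda.get_device_capability() >= (9, 0))
```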

Labels

ciflow/inductor
ciflow/inductor-rocm Trigger "inductor" config CI on ROCm
ciflow/rocm Trigger "default" config CI on ROCm
ciflow/trunk Trigger trunk jobs on your pull request
Merged
module: inductor
module: rocm AMD GPU support for Pytorch
open source
rocm priority high priority ROCm PRs from performance or other aspects
topic: not user facing topic category


Development

Successfully merging this pull request may close these issues.

[ROCm] [Triton 3.2] OOM shmem issues on Inductor tests with new SW pipelining
