[ROCm] Support new AMD triton stream pipeliner #139881
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139881
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 31093c7 with merge base d031d1b.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot rebase -s
^^ I want to see if the "0 active drivers" failures go away with a rebase.
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
We're green on CI now. @bertmaher @davidberard98 @aakhundov, please take a look if you agree this solution is suitable.
```diff
  {"config": (32, 32, 16, 1, 2), "cond": True},
- {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
+ {"config": (32, 32, 128, 2, 4), "cond": True},
```
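The `cond` flags above gate which autotune configs are even considered. A minimal sketch of that filtering pattern — the tuple layout `(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)` mirrors the snippet, but the helper name and the third entry are illustrative, not the actual inductor code:

```python
# Hypothetical sketch of conditional autotune-config filtering.
# Tuples are (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
configs = [
    {"config": (32, 32, 16, 1, 2), "cond": True},
    {"config": (32, 32, 128, 2, 4), "cond": True},
    {"config": (64, 64, 32, 2, 4), "cond": False},  # e.g. disabled on this platform
]

def enabled_configs(candidates):
    """Keep only the config tuples whose gating condition is truthy."""
    return [c["config"] for c in candidates if c["cond"]]

print(enabled_configs(configs))
```

Re-enabling a config in this scheme is just flipping its `cond` from a platform check to `True`, which is exactly what the diff above does.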
Curious: this config (and the two below) didn't work before but are going to work now? Also with the old Triton version?
We disabled these a long time ago because we were seeing runtime errors whenever a config hit a shared-memory (shmem) OOM. The autotune implementation now seems safe enough not to error out completely when it hits a config that throws OOM issues for us. CI is still green here on the older Triton, so I think we're safe.
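The defensive behavior described above — skipping configs that fail instead of aborting the whole autotune pass — can be sketched roughly like this (structure and names are assumptions for illustration, not the actual inductor autotuner):

```python
# Illustrative sketch: benchmark each candidate config and skip any that
# raise a runtime error (e.g. a shared-memory OOM), rather than letting
# one bad config abort the entire autotune pass.
def autotune(configs, benchmark):
    timings = {}
    for cfg in configs:
        try:
            timings[cfg] = benchmark(cfg)
        except RuntimeError:
            continue  # config is unusable on this hardware; move on
    if not timings:
        raise RuntimeError("all autotune configs failed")
    return min(timings, key=timings.get)

# Hypothetical benchmark where one config exhausts shared memory.
def fake_benchmark(cfg):
    if cfg == "big":
        raise RuntimeError("out of resource: shared memory")
    return {"small": 1.5, "medium": 1.2}[cfg]

print(autotune(["small", "medium", "big"], fake_benchmark))
```

Under this scheme a previously crashing config simply loses the benchmark race instead of taking down compilation.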
torch/_inductor/utils.py (Outdated)

```python
from .runtime.triton_helpers import get_backend_options

options = get_backend_options()
return options.get("num_stages", 2 if torch.version.hip else None)
```
Why is the default 2 here? Will this work with the old Triton version (pre-3.2) on ROCm?
We grab the default num_stages from the AMD backend via HIPOptions; pre-3.2 Triton will still return num_stages=0 as expected. https://github.com/triton-lang/triton/pull/4845/files#diff-33c9a103282c05c9d9d213b94450ae7481b6db8c3c6d810f54f175b4735a3c72
I just provided a default case for safety here, maybe some future-proofing in case the dict changes (pretty unlikely num_stages will be renamed, though...). 2 is the new recommended default for GEMMs.
I see, thanks for the explanation. This helper is not specifically marked / named for ROCm. So, if folks end up using this for NV, I assume it will still work as expected (i.e., will options contain num_stages on NV, too)? Also, should we @lru_cache this, as the value shouldn't change through one run?
Yeah, the same should work on NV here too, via CUDAOptions:
https://github.com/triton-lang/triton/blob/main/third_party/nvidia/backend/compiler.py#L98
Adding lru_cache makes sense; I'll make that change.
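The suggested memoization is just `functools.lru_cache` around the query. A small sketch of why it helps — the body below is a hypothetical stand-in for the backend lookup, which shouldn't change within one run:

```python
import functools

# Sketch: memoize the backend query so the (possibly expensive) driver
# lookup runs at most once per process. The function body is a
# hypothetical stand-in for the real Triton backend-options call.
call_count = 0

@functools.lru_cache(None)
def get_backend_num_stages():
    global call_count
    call_count += 1  # would be the actual backend query
    return 2         # assumed ROCm default

get_backend_num_stages()
get_backend_num_stages()
print(call_count)  # the second call is served from the cache
```

Since the backend options are fixed for the lifetime of the process, caching the result is safe.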
We could also return 3 as the default on the NV side rather than returning None.
@pytorchbot merge
```python
threads = torch.get_num_threads()
return threads


@functools.lru_cache(None)
```
Oops, linter needs one more blank line above this (sigh)
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. The first few are: Dig deeper by viewing the failures on hud
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes pytorch#139182
In Triton 3.2, num_stages=0 will be deprecated in Triton's AMD backend. Let's query the default num_stages from the relevant Triton backend.
Pull Request resolved: pytorch#139881
Approved by: https://github.com/bertmaher