[ROCm] Support new AMD triton stream pipeliner by jataylo · Pull Request #139881 · pytorch/pytorch

Conversation

@jataylo
Collaborator

@jataylo jataylo commented Nov 6, 2024

@jataylo jataylo added ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR ciflow/rocm Trigger "default" config CI on ROCm labels Nov 6, 2024
@pytorch-bot

pytorch-bot bot commented Nov 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139881

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 31093c7 with merge base d031d1b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jataylo jataylo added ciflow/inductor-rocm Trigger "inductor" config CI on ROCm and removed module: rocm AMD GPU support for Pytorch module: inductor ciflow/inductor labels Nov 6, 2024
@jataylo jataylo marked this pull request as draft November 6, 2024 13:06
@jataylo jataylo marked this pull request as ready for review November 6, 2024 13:37
@davidberard98
Contributor

@pytorchbot rebase -s

@davidberard98
Contributor

^^ I want to see if the "0 active drivers" failures go away with a rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@jataylo
Collaborator Author

jataylo commented Nov 7, 2024

We're green on CI now. @bertmaher @davidberard98 @aakhundov, please approve if you agree this solution is suitable.

[
{"config": (32, 32, 16, 1, 2), "cond": True},
- {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
+ {"config": (32, 32, 128, 2, 4), "cond": True},
Contributor

Curious: this config (and the two below) didn't work before but are going to work now? Also with the old Triton version?

Collaborator Author

We disabled these a long time ago because they triggered runtime errors whenever a shared-memory (shmem) OOM was thrown. The autotune implementation now seems safe enough to not completely error out when it hits a config that OOMs for us. CI is green here on the older Triton as well, so I think we're safe.
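
As a hedged illustration of what "not completely error out" means in practice: an autotuner can time each candidate and treat a shared-memory OOM as an infinitely slow config. bench_config and the kernel_runner callable below are illustrative, not the actual Inductor code; the import location of OutOfResources also varies across Triton versions.

# Sketch only: discard configs that exhaust shared memory instead of aborting.
import math
from triton.testing import do_bench
from triton.runtime.errors import OutOfResources

def bench_config(kernel_runner):
    # kernel_runner: hypothetical zero-arg callable launching one candidate config
    try:
        return do_bench(kernel_runner)  # median runtime in ms
    except OutOfResources:
        # The config needs more shared memory than the device provides;
        # report an infinite time so autotuning quietly skips it.
        return math.inf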

from .runtime.triton_helpers import get_backend_options

options = get_backend_options()
return options.get("num_stages", 2 if torch.version.hip else None)
Contributor

@aakhundov aakhundov Nov 7, 2024

Why is the default 2 here? Will this work with the old Triton version (pre-3.2) on ROCm?

Collaborator Author

@jataylo jataylo Nov 7, 2024

We grab the default num_stages from the AMD backend via HIPOptions; pre-3.2 Triton will still return num_stages=0 as expected. https://github.com/triton-lang/triton/pull/4845/files#diff-33c9a103282c05c9d9d213b94450ae7481b6db8c3c6d810f54f175b4735a3c72

I just provided a default case for safety here, maybe some future-proofing in case the dict changes (it's pretty unlikely num_stages will be renamed, though...). 2 is the new recommended value for gemms.
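
For context, the backend query could look roughly like this; the make_backend/parse_options usage is an assumption based on Triton 3.x's compiler API, not a quote of the PR:

# Sketch: ask the active Triton backend for its default compile options.
# On ROCm these come from HIPOptions, on NV from CUDAOptions.
import triton
from triton.runtime import driver

def get_backend_options():
    target = driver.active.get_current_target()              # current GPU target
    backend = triton.compiler.compiler.make_backend(target)  # HIP or CUDA backend
    options = backend.parse_options({})                      # pure defaults, no overrides
    return options.__dict__                                  # plain dict of option fields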

Contributor

I see, thanks for the explanation. This helper is not specifically marked / named for ROCm. So, if folks end up using this for NV, I assume it will still work as expected (i.e., will options contain num_stages on NV, too)? Also, should we @lru_cache this, as the value shouldn't change through one run?

Collaborator Author

Yeah, the same should work on NV here too:

CUDAOptions
https://github.com/triton-lang/triton/blob/main/third_party/nvidia/backend/compiler.py#L98

Adding lru_cache makes sense, I'll make that change.
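
A sketch of the cached version (same assumed Triton APIs as the sketch above); since the target can't change within a run, the single no-arg call is safe to memoize:

import functools

import triton
from triton.runtime import driver

@functools.lru_cache(None)  # backend defaults are fixed for the whole run
def get_backend_options():
    target = driver.active.get_current_target()
    backend = triton.compiler.compiler.make_backend(target)
    return backend.parse_options({}).__dict__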

Collaborator Author

We could also return 3 as the default on the NV side rather than None.
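
A minimal sketch of that suggestion, with fallback values taken from this discussion (2 is the new ROCm gemm recommendation, 3 is CUDAOptions' default); the helper name is hypothetical:

import torch

def default_num_stages_fallback():
    # Used only if "num_stages" is ever missing from the options dict.
    return 2 if torch.version.hip else 3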

@bertmaher
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 8, 2024
threads = torch.get_num_threads()
return threads

@functools.lru_cache(None)
Contributor

Oops, linter needs one more blank line above this (sigh)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team
Raised by workflow job

Failing merge rule: Core Maintainers

@jataylo
Collaborator Author

jataylo commented Nov 8, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Fixes pytorch#139182

In Triton 3.2, num_stages=0 will be deprecated in Triton's AMD backend. Let's query the default num_stages from the relevant Triton backend instead.

Pull Request resolved: pytorch#139881
Approved by: https://github.com/bertmaher


Development

Successfully merging this pull request may close these issues.

[ROCm] [Upstream Triton] num_stages=0 deprecation with stream pipeliner v2
