[Inductor CUTLASS backend] Step 3: autotune_process, and CUDABenchmarkRequest by ipiszy · Pull Request #107901 · pytorch/pytorch · GitHub

Conversation

CUDABenchmarkRequest.

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Aug 24, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107901

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 73c7297 with merge base f9a250c (image):

BROKEN TRUNK - The following job failed but was already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ipiszy ipiszy changed the title [Inductor CUTLASS backend] Step 3: autotune_process, and [Inductor CUTLASS backend] Step 3: autotune_process, and CUDABenchmarkRequest Aug 24, 2023
ipiszy added 10 commits August 25, 2023 11:00
…UDABenchmarkRequest"

CUDABenchmarkRequest.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
@ipiszy ipiszy requested a review from jansel August 27, 2023 01:38
@ipiszy ipiszy marked this pull request as ready for review August 27, 2023 01:38
…UDABenchmarkRequest"

This is step 3 of adding CUTLASS as an alternative Inductor backend.
Full tests can be found in the last PR in the stack.

Feature request: #106991.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]

@ipiszy ipiszy left a comment


Thanks @jansel @aakhundov @kadeng , ptal~

Comment on lines 339 to 342
assert (
self.workspace_size == 0
), "Autotune cache needs to be updated to support non-zero workspace_size!"
self.workspace = torch.empty(

Fixed.

VarRanges = Dict[sympy.Expr, sympy.Expr]


def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float:

do_bench also measures CPU-side overhead.

For example, for the event sequence:
CUDA event begin -> CUDA launch kernel -> CUDA kernel execution -> CUDA event end
it measures the time between CUDA event begin and CUDA event end, which includes the "CUDA launch kernel" part. That part can take non-trivial CPU time, which is especially bad for CUTLASS kernels because they rely on ctypes to invoke C++ functions from Python.

do_bench_using_profiling, on the other hand, relies on the profiler to collect kernel device time only.
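A CPU-only sketch of the difference (made-up sleep durations stand in for launch overhead and kernel time; nothing here is the actual do_bench implementation):

```python
import time

LAUNCH_OVERHEAD_S = 0.0005  # pretend CPU-side launch cost (e.g. ctypes dispatch)
KERNEL_TIME_S = 0.002       # pretend device-side kernel time

def launch_kernel(device_times):
    time.sleep(LAUNCH_OVERHEAD_S)   # CPU-side "launch": falls inside the event span
    t0 = time.perf_counter()
    time.sleep(KERNEL_TIME_S)       # the "kernel" itself
    device_times.append(time.perf_counter() - t0)

def bench_event_style(n=20):
    """Like CUDA-event timing: measures the whole span, launch overhead included."""
    device_times = []
    start = time.perf_counter()
    for _ in range(n):
        launch_kernel(device_times)
    return (time.perf_counter() - start) / n

def bench_profiler_style(n=20):
    """Like profiler-based timing: sums only the per-kernel device time."""
    device_times = []
    for _ in range(n):
        launch_kernel(device_times)
    return sum(device_times) / n
```

The event-style number is always larger because the launch overhead sits inside the timed span.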

Comment on lines +73 to +74
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))

This is copied from Triton's bench utility. I think it makes sure that enough time is spent warming up the device (so that the GPU frequency ramps up to max). If we only passed an iteration count, the warm-up time might be too short for small kernels and too long for big ones. n_repeat similarly keeps the total measurement time roughly constant across small and big kernels.
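The scaling can be illustrated with a small sketch (compute_warmup_repeat is a hypothetical helper name; the defaults mirror the warmup=25 and rep=100 milliseconds from the snippet above):

```python
def compute_warmup_repeat(estimate_ms: float, warmup_ms: float = 25, rep_ms: float = 100):
    """Pick iteration counts so the total warm-up and measurement wall time
    stays roughly constant regardless of the per-call latency estimate."""
    n_warmup = max(1, int(warmup_ms / estimate_ms))
    n_repeat = max(1, int(rep_ms / estimate_ms))
    return n_warmup, n_repeat

# A fast kernel (0.5 ms/call) gets many iterations; a slow one (50 ms/call) gets few:
# compute_warmup_repeat(0.5) -> (50, 200)
# compute_warmup_repeat(50)  -> (1, 2)
```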


def test_do_bench(self):
res = do_bench(self._bench_fn)
log.error("do_bench result: %s", res)

Maybe use log.warning instead?


def test_do_bench_using_profiling(self):
res = do_bench_using_profiling(self._bench_fn)
log.error("do_bench_using_profiling result: %s", res)

assert something about the return value

cls._bench_fn = functools.partial(torch.nn.functional.linear, x, w)

def test_do_bench(self):
res = do_bench(self._bench_fn)

tests should assert something


There is a check that throws an exception inside do_bench_using_profiling(), so the test below makes sure that no exception is thrown. I also use these two tests to compare the latency collected by the two functions. We could skip the test_do_bench() test in CI though...
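For instance, the test could assert on the returned latency instead of only logging it (do_bench_stub is a hypothetical stand-in for the real do_bench, so the snippet is self-contained and CPU-only):

```python
import time

def do_bench_stub(fn, n_repeat=10):
    """Stand-in for do_bench: returns the mean latency of fn in milliseconds."""
    start = time.perf_counter()
    for _ in range(n_repeat):
        fn()
    return (time.perf_counter() - start) / n_repeat * 1e3

def test_do_bench():
    res = do_bench_stub(lambda: sum(range(1000)))
    # Assert on the result instead of just logging it.
    assert isinstance(res, float)
    assert res > 0.0
```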


ipiszy commented Sep 9, 2023

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Sep 9, 2023
pytorchmergebot pushed a commit that referenced this pull request Sep 12, 2023
This is step 4 of adding CUTLASS as an alternative Inductor backend.
Full tests can be found in the last PR in the stack.

Feature request: #106991.

Pull Request resolved: #107931
Approved by: https://github.com/aakhundov, https://github.com/jansel, https://github.com/kadeng
ghstack dependencies: #107802, #107847, #107901
pytorchmergebot pushed a commit that referenced this pull request Sep 12, 2023
This is step 5 of adding CUTLASS as an alternative Inductor backend.

Feature request: #106991.

Pull Request resolved: #108015
Approved by: https://github.com/kadeng, https://github.com/jansel, https://github.com/aakhundov
ghstack dependencies: #107802, #107847, #107901, #107931
ipiszy added a commit that referenced this pull request Sep 15, 2023
In #107901, CUDA-event-based profiling was changed to profiler-based
profiling to avoid counting CPU-side kernel launch overhead in the final
latency numbers. However, it turns out that torch.profiler is significantly
slower than CUDA events, which slows down model compilation quite
significantly. This PR changes back to CUDA-event-based profiling.

Follow-ups:
* Try CUDA event profiling with CUDAGraphs;
* Multi-GPU profiling;




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
@facebook-github-bot facebook-github-bot deleted the gh/ipiszy@gmail.com/3/head branch September 16, 2023 14:23
pytorchmergebot pushed a commit that referenced this pull request Sep 17, 2023
Pull Request resolved: #109338
Approved by: https://github.com/frank-wei
