[MPS] Improve performance of max_pool3d #157875
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157875
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit 09330ba with merge base ee09928. UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 263a0cc Pull-Request: pytorch#157875
To check how the changes from this PR and #157874 affect performance, I wrote a script here: https://github.com/kurtamohler/pytorch-perf-test-scripts/blob/55ef32a127c746d13d7310375068a6b300bda92d/max_pool_mps/perf.py
Before these two PRs, I get the following median run times for various cases: After applying these two PRs, I get the following measurements: Every one of the cases is improved.
ghstack-source-id: ad2ac45 Pull-Request: pytorch#157875
Please avoid hardcoding types (unless you need them for some reason); instead, rely on the compiler to infer them. That way, when you decide to change a type later, you don't have to change it throughout the codebase.
Also, use Metal's native uint3/uint4 data types to pass fixed-size arrays around (they can be accessed both by index and by the .x/.y suffixes).
constant int32_t* input_sizes,
constant int32_t* input_strides,
thread int32_t (&work_pooling_dim_indices)[3],
constant int32_t* kernel_size,
Have you considered using vectorized types? They result in faster code and look more natural.
Suggested change:
- constant int32_t* kernel_size,
+ constant uint3& kernel_size,
I tried using vectorized types with the following diff, but it seemingly did not change the performance at all. I'm guessing that it compiles to something that is equivalent. I opted not to put it in the PR because I think it makes the code a little less readable. But if you still prefer it, I can put it in this PR.
diff --git a/aten/src/ATen/native/mps/kernels/Pooling.metal b/aten/src/ATen/native/mps/kernels/Pooling.metal
index 18982559a34..53a4bf8f2a7 100644
--- a/aten/src/ATen/native/mps/kernels/Pooling.metal
+++ b/aten/src/ATen/native/mps/kernels/Pooling.metal
@@ -28,6 +28,20 @@ IterBounds<int32_t> get_input_iter_bounds(
return IterBounds<int32_t>{start, end};
}
+IterBounds<int3> get_input_iter_bounds(
+ constant int3& input_sizes,
+ thread int3& pooling_dim_indices,
+ constant int3& kernel_size,
+ constant int3& stride,
+ constant int3& padding,
+ constant int3& dilation) {
+ auto start = stride * pooling_dim_indices - padding;
+ auto end = min(start + kernel_size * dilation, input_sizes);
+ auto start_correction = dilation * ((-start - 1 + dilation) / dilation);
+ start += select(int3(0), start_correction, start < int3(0));
+ return IterBounds<int3>{start, end};
+}
+
// Iterates through all the input elements that this kernel needs to
// apply max to. Specialized for 3 pooling dimensions.
// TODO: Support any number of pooling dims
@@ -88,6 +102,55 @@ void max_pool_3d_input_iter(
}
}
+template <typename T>
+void max_pool_3d_input_iter(
+ constant T* input,
+ device T* output,
+ device int64_t* indices,
+ constant int3& input_sizes,
+ constant int3& input_strides,
+ thread int3& pooling_dim_indices,
+ constant int3& kernel_size,
+ constant int3& stride,
+ constant int3& padding,
+ constant int3& dilation,
+ bool return_indices) {
+ auto bounds = get_input_iter_bounds(
+ input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation);
+
+ T max_value = input
+ [input_strides.x * bounds.start.x + input_strides.y * bounds.start.y +
+ input_strides.z * bounds.start.z];
+ auto size12 = input_sizes.y * input_sizes.z;
+ auto max_index =
+ bounds.start.x * size12 + bounds.start.y * input_sizes.z + bounds.start.z;
+
+ for (auto i0 = bounds.start.x; i0 < bounds.end.x; i0 += dilation.x) {
+ auto offset0 = input_strides.x * i0;
+
+ for (auto i1 = bounds.start.y; i1 < bounds.end.y; i1 += dilation.y) {
+ auto offset1 = input_strides.y * i1;
+
+ for (auto i2 = bounds.start.z; i2 < bounds.end.z; i2 += dilation.z) {
+ auto offset2 = input_strides.z * i2;
+ auto input_value = input[offset0 + offset1 + offset2];
+ bool is_greater = input_value > max_value;
+
+ max_value = is_greater ? input_value : max_value;
+
+ if (return_indices) {
+ auto input_index = i0 * size12 + i1 * input_sizes.z + i2;
+ max_index = is_greater ? input_index : max_index;
+ }
+ }
+ }
+ }
+ *output = max_value;
+ if (return_indices) {
+ *indices = max_index;
+ }
+}
+
struct PoolOffsets {
int32_t output;
int32_t indices;
@@ -212,18 +275,41 @@ kernel void max_pool(
output += offsets.output;
indices += offsets.indices;
input += offsets.input_leading;
+ input_sizes += leading_dims;
+ input_strides += leading_dims;
+
+ constant int3& input_sizes_ = *reinterpret_cast<constant int3*>(input_sizes);
+ constant int3& input_strides_ = *reinterpret_cast<constant int3*>(input_strides);
+ thread int3& pooling_dim_indices_ = *reinterpret_cast<thread int3*>(pooling_dim_indices);
+ constant int3& kernel_size_ = *reinterpret_cast<constant int3*>(kernel_size);
+ constant int3& stride_ = *reinterpret_cast<constant int3*>(stride);
+ constant int3& padding_ = *reinterpret_cast<constant int3*>(padding);
+ constant int3& dilation_ = *reinterpret_cast<constant int3*>(dilation);
+
+ //max_pool_3d_input_iter<T>(
+ // input,
+ // output,
+ // indices,
+ // input_sizes,
+ // input_strides,
+ // pooling_dim_indices,
+ // kernel_size,
+ // stride,
+ // padding,
+ // dilation,
+ // return_indices);
max_pool_3d_input_iter<T>(
input,
output,
indices,
- input_sizes + leading_dims,
- input_strides + leading_dims,
- pooling_dim_indices,
- kernel_size,
- stride,
- padding,
- dilation,
+ input_sizes_,
+ input_strides_,
+ pooling_dim_indices_,
+ kernel_size_,
+ stride_,
+ padding_,
+ dilation_,
return_indices);
}
Do you mind mentioning it for posterity in the PR description, and maybe also integrating it into
ghstack-source-id: 0dad99c Pull-Request: pytorch#157875
I've made a few more changes that improved performance, and updated a bunch of places in the code to use
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
stride != 1 (#157876)
To check how the changes from this PR affect performance, I wrote a script here: https://github.com/kurtamohler/pytorch-perf-test-scripts/blob/55ef32a127c746d13d7310375068a6b300bda92d/max_pool_mps/perf.py.
Before this PR, I get this:
After this PR, I get this:
Every case is improved.