Add fp16 support for gemm on CPU #99498
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99498
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures) As of commit 3ae0be7 with merge base 5dcee01.
BROKEN TRUNK - The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Once the failures are fixed, we will provide some basic benchmark numbers.
Not complete yet. WIP. cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10
### Testing

Native matmul vs. mkldnn matmul on SPR (with avx512_fp16 support)

single core:

Input | Naïve impl / ms | oneDNN / ms | Speed up
-- | -- | -- | --
M: 128, N: 128, K: 128, trans_a: False, trans_b: False | 2010.387 | 64.700 | 31.072
M: 128, N: 256, K: 128, trans_a: False, trans_b: False | 4027.116 | 107.780 | 37.364
M: 8192, N: 768, K: 768, trans_a: False, trans_b: False | 28685868.488 | 90663.008 | 316.401

56 cores:

Input | Naïve impl / ms | oneDNN / ms | Speed up
-- | -- | -- | --
M: 128, N: 128, K: 128, trans_a: False, trans_b: False | 5.091 | 0.24 | 211.30
M: 128, N: 128, K: 128, trans_a: False, trans_b: True | 5.224 | 0.23 | 220.09
M: 128, N: 256, K: 128, trans_a: False, trans_b: False | 10.006 | 0.30 | 330.31
M: 8192, N: 768, K: 768, trans_a: False, trans_b: False | 29435.372 | 1.770 | 1662.80
M: 8192, N: 768, K: 768, trans_a: False, trans_b: True | 31464.961 | 1.728 | 18204.76
M: 8192, N: 768, K: 3072, trans_a: False, trans_b: False | 115035.849 | 7.990 | 14396.90
M: 8192, N: 768, K: 3072, trans_a: False, trans_b: True | 122981.023 | 7.725 | 15918.34
Batch: 768, M: 128, N: 64, K: 128 | 2032.523 | 0.705 | 2882.23

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov desertfire
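For context, a minimal sketch of how per-shape timings like the ones above might be collected. This is illustrative only, not the exact script behind these numbers; the shapes, warm-up, iteration count, and thread settings are assumptions.

```python
import timeit
import torch

torch.set_num_threads(1)  # 1 for the single-core runs; e.g. 56 for the multi-core runs

def bench_matmul(M, N, K, dtype=torch.half, trans_b=False, iters=100):
    # Build fp16 operands; emulate trans_b by transposing an (N, K) tensor.
    a = torch.randn(M, K).to(dtype)
    b = torch.randn(N, K).to(dtype).t() if trans_b else torch.randn(K, N).to(dtype)
    torch.matmul(a, b)  # warm-up
    t = timeit.timeit(lambda: torch.matmul(a, b), number=iters)
    return t / iters * 1e3  # average ms per matmul

for M, N, K in [(128, 128, 128), (128, 256, 128), (8192, 768, 768)]:
    print(f"M: {M}, N: {N}, K: {K} -> {bench_matmul(M, N, K):.3f} ms")
```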
@cpuhrsch Could you please review this PR? Thank you.
| "mkldnn_linear: bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq"); | ||
| } else if (self.scalar_type() == ScalarType::Half) { | ||
| TORCH_CHECK(mkldnn_fp16_device_check(), | ||
| "mkldnn_linear: fp16 path needs the cpu support avx_ne_convert or avx512_fp16"); |
Hmm, is this a correct statement? ARM CPUs support half-precision operations. Doesn't mkldnn support those?
Currently we are focused on x64; we are not sure how complete the support is on ARM. It may be added in later PRs.
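For illustration only (plain PyTorch usage, not code from this PR): the fp16 gemm this path covers is an ordinary half-precision matmul on CPU, and whether it is routed to oneDNN depends on the device check discussed above.

```python
import torch

# Half-precision gemm on CPU. On x64 CPUs that pass mkldnn_fp16_device_check()
# (e.g. with avx512_fp16 or avx_ne_convert), the intent of this PR is that the
# call goes through oneDNN instead of the naive reference kernel; otherwise it
# falls back to the existing path.
a = torch.randn(128, 128, dtype=torch.half)
b = torch.randn(128, 128, dtype=torch.half)
c = a @ b
print(c.dtype, c.shape)  # torch.float16 torch.Size([128, 128])
```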
```cpp
return (
    at::globalContext().userEnabledMkldnn() &&
    mkldnn_fp16_device_check());
```
Please remove extraneous brackets (not sure about the rest of the formatting)
Suggested change:
```diff
-  return (
-      at::globalContext().userEnabledMkldnn() &&
-      mkldnn_fp16_device_check());
+  return at::globalContext().userEnabledMkldnn() &&
+      mkldnn_fp16_device_check();
```
Revised as suggested
```cpp
TORCH_CHECK(mkldnn_bf16_device_check(),
    "mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq, or AWS Graviton3");
} else {
  TORCH_CHECK(mkldnn_fp16_device_check(),
```
Would be nice to check that it's called only for those two dtypes...
Suggested change:
```diff
-  TORCH_CHECK(mkldnn_fp16_device_check(),
+  TORCH_DEBUG_ASSERT(mat1.scalar_type() == at::kHalf);
+  TORCH_CHECK(mkldnn_fp16_device_check(),
```
Revised as suggested
```diff
 if (self.scalar_type() == ScalarType::BFloat16) {
   TORCH_CHECK(mkldnn_bf16_device_check(),
-      "mkldnn_linear: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
+      "mkldnn_linear: bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq");
```
Looks like mkldnn_matmul is supported on ARM devices. Are you sure about mkldnn_linear being the exception?
I split the bf16 check into a bf16 check on x64 and one on ARM.
I kept the previous message (adding the new ISA avx_ne_convert), which also does not mention ARM. Do you think we should include ARM info in all such messages (including matmul, conv, and deconv)?
https://github.com/pytorch/pytorch/pull/99498/files#diff-dd15cec62ddb24d690b58e1902d8822347003437c9c4c4ae51f9c02e281fa33eR88
ghstack-source-id: 2b527d1 Pull Request resolved: pytorch#99498
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
### Testing

Native matmul vs. mkldnn matmul on SPR (with avx512_fp16 support)

single core:

Input | Naïve impl / ms | oneDNN / ms | Speed up
-- | -- | -- | --
M: 128, N: 128, K: 128, trans_a: False, trans_b: False | 2010.387 | 64.700 | 31.072
M: 128, N: 256, K: 128, trans_a: False, trans_b: False | 4027.116 | 107.780 | 37.364
M: 8192, N: 768, K: 768, trans_a: False, trans_b: False | 28685868.488 | 90663.008 | 316.401

56 cores:

Input | Naïve impl / ms | oneDNN / ms | Speed up
-- | -- | -- | --
M: 128, N: 128, K: 128, trans_a: False, trans_b: False | 5.091 | 0.24 | 211.30
M: 128, N: 128, K: 128, trans_a: False, trans_b: True | 5.224 | 0.23 | 220.09
M: 128, N: 256, K: 128, trans_a: False, trans_b: False | 10.006 | 0.30 | 330.31
M: 8192, N: 768, K: 768, trans_a: False, trans_b: False | 29435.372 | 1.770 | 1662.80
M: 8192, N: 768, K: 768, trans_a: False, trans_b: True | 31464.961 | 1.728 | 18204.76
M: 8192, N: 768, K: 3072, trans_a: False, trans_b: False | 115035.849 | 7.990 | 14396.90
M: 8192, N: 768, K: 3072, trans_a: False, trans_b: True | 122981.023 | 7.725 | 15918.34
Batch: 768, M: 128, N: 64, K: 128 | 2032.523 | 0.705 | 2882.23
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @ngimel @desertfire