Add fp16 support for gemm on CPU by CaoE · Pull Request #99498 · pytorch/pytorch

Conversation

@CaoE (Collaborator) commented Apr 19, 2023

Stack from ghstack (oldest at bottom):

### Testing

Native matmul vs. mkldnn matmul on SPR (Sapphire Rapids, with avx512_fp16 support)

single core:

Input | Naïve impl / ms | oneDNN / ms | Speed up
-- | -- | -- | --
M: 128, N: 128, K: 128, trans_a: False, trans_b: False | 2010.387 | 64.700 | 31.072
M: 128, N: 256, K: 128, trans_a: False, trans_b: False | 4027.116 | 107.780 | 37.364
M: 8192, N: 768, K: 768, trans_a: False, trans_b: False | 28685868.488 | 90663.008 | 316.401

56 cores:

Input | Naïve impl / ms | oneDNN / ms | Speed up
-- | -- | -- | --
M: 128, N: 128, K: 128, trans_a: False, trans_b: False | 5.091 | 0.24 | 211.30
M: 128, N: 128, K: 128, trans_a: False, trans_b: True | 5.224 | 0.23 | 220.09
M: 128, N: 256, K: 128, trans_a: False, trans_b: False | 10.006 | 0.30 | 330.31
M: 8192, N: 768, K: 768, trans_a: False, trans_b: False | 29435.372 | 1.770 | 1662.80
M: 8192, N: 768, K: 768, trans_a: False, trans_b: True | 31464.961 | 1.728 | 18204.76
M: 8192, N: 768, K: 3072, trans_a: False, trans_b: False | 115035.849 | 7.990 | 14396.90
M: 8192, N: 768, K: 3072, trans_a: False, trans_b: True | 122981.023 | 7.725 | 15918.34
Batch: 768, M: 128, N: 64, K: 128 | 2032.523 | 0.705 | 2882.23
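A minimal sketch of this kind of comparison (not the script used for the tables above; it assumes `torch.backends.mkldnn.enabled` gates the oneDNN path, consistent with the `userEnabledMkldnn()` check discussed later in this thread, and uses only the two small single-core shapes to keep runtime sane):

```python
import torch
from torch.utils import benchmark

def time_fp16_matmul(M, N, K, use_mkldnn=True):
    # Toggle the oneDNN path; the fp16 gemm dispatch consults
    # at::globalContext().userEnabledMkldnn() before taking it.
    torch.backends.mkldnn.enabled = use_mkldnn
    a = torch.randn(M, K, dtype=torch.half)
    b = torch.randn(K, N, dtype=torch.half)
    t = benchmark.Timer(stmt="torch.matmul(a, b)",
                        globals={"torch": torch, "a": a, "b": b})
    return t.blocked_autorange().median * 1e3  # milliseconds

torch.set_num_threads(1)  # mimic the single-core runs
for M, N, K in [(128, 128, 128), (128, 256, 128)]:
    native = time_fp16_matmul(M, N, K, use_mkldnn=False)
    onednn = time_fp16_matmul(M, N, K, use_mkldnn=True)
    print(f"M={M} N={N} K={K}: native {native:.3f} ms, "
          f"oneDNN {onednn:.3f} ms, speedup {native / onednn:.1f}x")
```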

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @ngimel @desertfire

@pytorch-bot bot commented Apr 19, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99498

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 3ae0be7 with merge base 5dcee01:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: linalg_frontend release notes category label Apr 19, 2023
@github-actions github-actions bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Apr 19, 2023
CaoE added a commit that referenced this pull request Apr 19, 2023
ghstack-source-id: a911370
Pull Request resolved: #99498
@CaoE CaoE marked this pull request as draft April 19, 2023 01:36
@mingfeima (Collaborator) commented:
So once the failures are fixed, we shall provide some basic benchmark numbers.

Not complete yet. WIP.

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

CaoE added a commit that referenced this pull request Apr 24, 2023
ghstack-source-id: 8df672a
Pull Request resolved: #99498
CaoE added a commit that referenced this pull request May 4, 2023
ghstack-source-id: af13ff2
Pull Request resolved: #99498
CaoE added a commit that referenced this pull request May 4, 2023
ghstack-source-id: 763ac89
Pull Request resolved: #99498
CaoE added a commit that referenced this pull request May 5, 2023
ghstack-source-id: eed6a6c
Pull Request resolved: #99498
CaoE added a commit that referenced this pull request May 6, 2023
ghstack-source-id: bca5db3
Pull Request resolved: #99498
CaoE added a commit that referenced this pull request Sep 24, 2023
ghstack-source-id: 63c9606
Pull Request resolved: #99498


CaoE added a commit that referenced this pull request Sep 25, 2023
ghstack-source-id: 2b527d1
Pull Request resolved: #99498
@CaoE CaoE requested a review from cpuhrsch September 25, 2023 07:05
@CaoE (Collaborator, Author) commented Sep 25, 2023

@cpuhrsch Could you please review this PR? Thank you.

"mkldnn_linear: bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq");
} else if (self.scalar_type() == ScalarType::Half) {
TORCH_CHECK(mkldnn_fp16_device_check(),
"mkldnn_linear: fp16 path needs the cpu support avx_ne_convert or avx512_fp16");
Contributor commented:

Hmm, is this a correct statement? ARM CPUs support half-precision operations. Doesn't mkldnn support those?

@CaoE (Collaborator, Author) replied Sep 26, 2023:

Currently we are focused on x64; I'm not sure how complete the support is on ARM. It may be done in later PRs.
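For anyone checking whether their machine advertises the ISAs named in these error messages, a small probe is sketched below. Assumptions: Linux, and the /proc/cpuinfo flag spellings "avx512_fp16" and "avx_ne_convert" (older kernels may not report them); this is a heuristic, not the check PyTorch itself performs.

```python
import torch

def cpu_flags():
    # Parse the first "flags" line from /proc/cpuinfo (Linux only).
    with open("/proc/cpuinfo") as f:
        for line in f:
            if line.startswith("flags"):
                return set(line.split(":", 1)[1].split())
    return set()

flags = cpu_flags()
print("avx512_fp16    :", "avx512_fp16" in flags)
print("avx_ne_convert :", "avx_ne_convert" in flags)
print("mkldnn available:", torch.backends.mkldnn.is_available())
```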

Comment on lines 82 to 84
```cpp
return (
    at::globalContext().userEnabledMkldnn() &&
    mkldnn_fp16_device_check());
```
Contributor commented:

Please remove extraneous brackets (not sure about the rest of the formatting).

Suggested change:

```diff
-return (
-    at::globalContext().userEnabledMkldnn() &&
-    mkldnn_fp16_device_check());
+return at::globalContext().userEnabledMkldnn() &&
+    mkldnn_fp16_device_check();
```

@CaoE (Collaborator, Author) replied:

Revised as suggested.

```cpp
  TORCH_CHECK(mkldnn_bf16_device_check(),
      "mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq, or AWS Graviton3");
} else {
  TORCH_CHECK(mkldnn_fp16_device_check(),
```
Contributor commented:

Would be nice to check that it's called only for those two dtypes...

Suggested change:

```diff
-  TORCH_CHECK(mkldnn_fp16_device_check(),
+  TORCH_DEBUG_ASSERT(mat1.scalar_type() == at::kHalf);
+  TORCH_CHECK(mkldnn_fp16_device_check(),
```

@CaoE (Collaborator, Author) replied:

Revised as suggested.
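A quick numerical sanity check for the fp16 gemm path, as a sketch (not part of the PR; the shape is one from the benchmark tables, and the tolerances are a rough guess for fp16 accumulation at K = 128, so they may need loosening):

```python
import torch

# Compare fp16 CPU matmul against an fp32 reference.
a = torch.randn(128, 128)
b = torch.randn(128, 128)
ref = a @ b                            # fp32 reference
out = (a.half() @ b.half()).float()    # fp16 gemm, upcast for comparison
print("max abs diff:", (out - ref).abs().max().item())
torch.testing.assert_close(out, ref, atol=1e-1, rtol=1e-2)
```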

```diff
 if (self.scalar_type() == ScalarType::BFloat16) {
   TORCH_CHECK(mkldnn_bf16_device_check(),
-      "mkldnn_linear: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
+      "mkldnn_linear: bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq");
```
Contributor commented:

Looks like mkldnn_matmul is supported on ARM devices. Are you sure about linear being the exception?

@CaoE (Collaborator, Author) replied:

I split the bf16 check into separate checks for x64 and ARM. I kept the previous message (adding the new ISA avx_ne_convert), which also does not mention ARM. Do you think we should include ARM info in all such messages (including matmul, conv and deconv)?
https://github.com/pytorch/pytorch/pull/99498/files#diff-dd15cec62ddb24d690b58e1902d8822347003437c9c4c4ae51f9c02e281fa33eR88

CaoE added a commit to CaoE/pytorch that referenced this pull request Sep 26, 2023
ghstack-source-id: 2b527d1
Pull Request resolved: pytorch#99498


CaoE added a commit that referenced this pull request Sep 26, 2023
ghstack-source-id: dd8ff80
Pull Request resolved: #99498
@CaoE (Collaborator, Author) commented Sep 28, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/CaoE/19/head branch October 1, 2023 14:23
@lezcano lezcano changed the title add fp16 support for gemm Add fp16 support for gemm on CPU Jan 14, 2024

Labels

ciflow/inductor, ciflow/mps, ciflow/periodic, ciflow/slow, ciflow/trunk, Merged, module: cpu, module: half, module: inductor, open source, release notes: linalg_frontend
