Enable TF32 as fp32 internal precision for matmul/linear/conv #157520
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157520
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (12 Unrelated Failures)
As of commit 342dc1f with merge base d7e1b8b:
BROKEN TRUNK - The following jobs failed but were present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures
UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
  float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
  setFloat32Precision("cuda", "matmul", "tf32");
- setFloat32Precision("mkldnn", "matmul", "ieee");
+ setFloat32Precision("mkldnn", "matmul", "tf32");
I don't quite understand what this change from "ieee" to "tf32" means.
In the description of #125888, it says:
We provide 3 fp32 compute precision can be set:
"ieee": Not allowed to use any other internal computation data types .
"tf32": Allowed to use tf32 as internal computation data types.
"bf16": Allowed to use bf16 as internal computation data types.
"none": Precision's are not set. Can be override by its father node.
"HIGHEST, HIGH, MEDIUM" is the legacy representation; the three values mean `ieee`, `tf32`, and `bf16` respectively.
So without this PR, the mkldnn backend only supports `ieee`, `bf16` and `none`. If the precision is set to `HIGH`, `tf32` is not supported in mkldnn, so `ieee` is used instead. With this PR, `tf32` is supported in mkldnn, so we can use `tf32` directly.
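The mapping and fallback described in this reply can be sketched in plain C++. This is only an illustration under stated assumptions: `legacy_to_fp32_precision` and `resolve_for_backend` are hypothetical helper names, not functions from ATen, and the supported lists are passed in by the caller rather than read from `Context.cpp`.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical helper (not an ATen function): translate a legacy
// float32 matmul precision name into the newer precision string.
inline std::string legacy_to_fp32_precision(const std::string& legacy) {
  if (legacy == "HIGHEST") return "ieee";
  if (legacy == "HIGH") return "tf32";
  if (legacy == "MEDIUM") return "bf16";
  return "none";  // assumption: unknown names are treated as "not set"
}

// Hypothetical helper: choose the precision a backend ends up using.
// If the requested precision is not in the backend's supported list,
// fall back to "ieee", as the reply describes for mkldnn before this PR.
inline std::string resolve_for_backend(
    const std::string& requested,
    const std::vector<std::string>& supported) {
  const bool ok = std::find(supported.begin(), supported.end(), requested) !=
      supported.end();
  return ok ? requested : "ieee";
}
```

With a supported list of `{"ieee", "bf16", "none"}` (mkldnn before the PR), `HIGH` resolves to `ieee`; once `tf32` is in the list, the same request resolves to `tf32`.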
just some suggestions for simplifying the code.
mat2.numel() != 0 &&
checksize(mat1, mat2));
}
This file has multiple functions with similar usage: `use_mkldnn_bf16_matmul`, `use_mkldnn_fp16_matmul`, `use_mkldnn_bf32_matmul` and `use_mkldnn_tf32_matmul`. Can we templatize them to simplify the code?
template <typename T>
bool use_mkldnn_matmul();
#if defined(__aarch64__)
template <>
bool use_mkldnn_matmul<at::BFloat16>();
#endif
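A self-contained sketch of the templated predicate the reviewer suggests might look as follows. Everything here is an assumption made for illustration: `BFloat16Tag` and `HalfTag` stand in for `at::BFloat16` and `at::Half`, `MatmulQuery` is a hypothetical simplification of the tensor arguments, and the conditions are placeholders rather than the checks used in the PR.

```cpp
#include <cstdint>

// Hypothetical stand-ins for at::BFloat16 / at::Half so the sketch
// compiles without ATen.
struct BFloat16Tag {};
struct HalfTag {};

// Hypothetical, simplified inputs; the real predicates inspect the
// tensors themselves (dtype, numel, sizes).
struct MatmulQuery {
  std::int64_t numel1;
  std::int64_t numel2;
  bool cpu_supports_type;
};

// Primary template: by default a type does not take the mkldnn path.
template <typename T>
bool use_mkldnn_matmul(const MatmulQuery&) {
  return false;
}

// Explicit specializations carry the per-type conditions.
template <>
inline bool use_mkldnn_matmul<BFloat16Tag>(const MatmulQuery& q) {
  return q.cpu_supports_type && q.numel1 != 0 && q.numel2 != 0;
}

template <>
inline bool use_mkldnn_matmul<HalfTag>(const MatmulQuery& q) {
  return q.cpu_supports_type && q.numel1 != 0 && q.numel2 != 0;
}
```

The shared non-empty checks could be factored into one helper so each specialization only states what is specific to its type.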
  const Tensor& result) {
- return (use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result) || use_mkldnn_bf32_matmul(mat1, mat2, result));
+ return (use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result) || use_mkldnn_bf32_matmul(mat1, mat2, result) || use_mkldnn_tf32_matmul(mat1, mat2, result));
  }
if you can template use_mkldnn_matmul<>, then you can do something like:
AT_DISPATCH_FLOATING_AND2(kBFloat16, kHalf, ..., [&] {
return use_mkldnn_matmul<scalar_t>(...);
});
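The idea behind the dispatch suggestion, mapping a runtime dtype to a compile-time template argument, can be sketched without the ATen macros. This is an illustration only: `DType`, `BF16`, `F16`, `can_use_mkldnn` and `can_use_mkldnn_dispatch` are hypothetical names, and a plain `switch` stands in for the dispatch macro.

```cpp
// Hypothetical runtime dtype tag and compile-time type tags.
enum class DType { Float, BFloat16, Half };
struct BF16 {};
struct F16 {};

// Templated predicate; the primary template rejects the type.
template <typename T>
bool can_use_mkldnn(bool /*hw_ok*/) {
  return false;
}

// Specializations accept the type when the hardware check passes.
template <>
inline bool can_use_mkldnn<BF16>(bool hw_ok) {
  return hw_ok;
}

template <>
inline bool can_use_mkldnn<F16>(bool hw_ok) {
  return hw_ok;
}

// A switch plays the role of the dispatch macro: it selects the
// template instantiation that matches the runtime dtype.
inline bool can_use_mkldnn_dispatch(DType dtype, bool hw_ok) {
  switch (dtype) {
    case DType::BFloat16:
      return can_use_mkldnn<BF16>(hw_ok);
    case DType::Half:
      return can_use_mkldnn<F16>(hw_ok);
    case DType::Float:
      return can_use_mkldnn<float>(hw_ok);
  }
  return false;
}
```

In this sketch `float` falls through to the primary template; handling the reduced-precision fp32 cases (bf32/tf32) would need an extra input for the configured precision, which is left out here.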
Done. Please take a look again.
aten/src/ATen/Context.cpp
Outdated
  const std::map<std::string, std::vector<std::string>> _fp32_precisions = {
  {"generic", {{"ieee", "tf32", "bf16", "none"}}},
- {"mkldnn", {{"ieee", "bf16", "none"}}},
+ {"mkldnn", {{"ieee", "bf16", "tf32", "none"}}},
Why is the ordering different from "generic"?
No specific reason. I will change it to match `generic`.
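For context, a table like `_fp32_precisions` lends itself to a simple membership check. The sketch below is an assumption about how such a lookup could be written, not the code in `Context.cpp`: `supported_fp32_precisions` and `is_supported_precision` are hypothetical names, and the `mkldnn` row is shown with the ordering aligned to `generic` as agreed in this thread.

```cpp
#include <algorithm>
#include <map>
#include <string>
#include <vector>

// Hypothetical copy of the backend -> allowed-precisions table.
inline const std::map<std::string, std::vector<std::string>>&
supported_fp32_precisions() {
  static const std::map<std::string, std::vector<std::string>> table = {
      {"generic", {"ieee", "tf32", "bf16", "none"}},
      {"mkldnn", {"ieee", "tf32", "bf16", "none"}},
  };
  return table;
}

// Hypothetical validator: true only if the backend is known and the
// precision appears in that backend's list.
inline bool is_supported_precision(
    const std::string& backend,
    const std::string& precision) {
  const auto& table = supported_fp32_precisions();
  const auto it = table.find(backend);
  if (it == table.end()) {
    return false;
  }
  const auto& allowed = it->second;
  return std::find(allowed.begin(), allowed.end(), precision) != allowed.end();
}
```

Because the check is a membership test, the order of entries does not change behaviour; keeping it consistent with `generic` is purely for readability.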
Hi @jansel, could you please also take a look at this PR? Similar to the previous PRs to enable BF32, we can easily extend the API to support TF32 for
Enable TF32 as fp32 internal precision for Linear
Enable TF32 as fp32 internal precision for conv
ghstack-source-id: 5365a8c
Pull Request resolved: pytorch#157520
Test failures?
The failures appeared after I rebased yesterday. Interestingly, they are the same as in #158209, which only updates a warning log. Let me try to find the cause. Update: The failures are caused by #150762. See #150762 (comment)
Hi @jansel, CI all passes after rebase. Could you please take a look again? Thanks!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Description
This PR enables TF32 as the fp32 internal precision for matmul/linear/conv in the `mkldnn` backend. Since we refined the fp32 precision API in #125888, we can easily extend the API to support TF32 for the `mkldnn` backend.
The related kernel and UT updates are done. The wrapper `bf32_on_and_off` is renamed to `reduced_f32_on_and_off`, and it can run tests 3 times: once with reduced_f32 OFF and twice with reduced_f32 ON (`bf32 ON` and `tf32 ON`).
Stack from ghstack (oldest at bottom):
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @Guobing-Chen @Xia-Weiwen @snadampal @voznesenskym @penguinwu @EikanWang @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov