[inductor] enable mkldnn op weight pre-packing on aarch64 by snadampal · Pull Request #115037 · pytorch/pytorch · GitHub
Closed
5 changes: 5 additions & 0 deletions aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp
@@ -27,6 +27,10 @@ static bool is_mkldnn_fp16_supported() {
return mkldnn_fp16_device_check();
}

constexpr bool is_mkldnn_acl_supported() {
return AT_MKLDNN_ACL_ENABLED();
}

TORCH_LIBRARY(mkldnn, m) {
m.class_<ConvOpContext>(TORCH_SELECTIVE_CLASS("ConvOpContext"))
.def_pickle(
@@ -69,6 +73,7 @@ TORCH_LIBRARY(mkldnn, m) {
"mkldnn::_reorder_mkldnn_rnn_layer_weight(Tensor weight0, Tensor weight1, int hidden_size, bool reverse, bool has_biases, bool batch_first, int[]? input_size=None) -> Tensor[] Y"));
m.def("_is_mkldnn_bf16_supported", &is_mkldnn_bf16_supported);
m.def("_is_mkldnn_fp16_supported", &is_mkldnn_fp16_supported);
m.def("_is_mkldnn_acl_supported", &is_mkldnn_acl_supported);
}

TORCH_LIBRARY(mkldnn_prepacked, m) {
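With this registration, ACL availability becomes queryable from Python through the mkldnn op namespace (the same call the inductor pass below uses). A minimal check, assuming a PyTorch build that includes this change:

    import torch

    # True only when ATen was built against the Arm Compute Library
    # (AT_MKLDNN_ACL_ENABLED), i.e. typically an aarch64 build.
    if torch.ops.mkldnn._is_mkldnn_acl_supported():
        print("oneDNN is backed by ACL; fp32 linear weight pre-packing is available")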
23 changes: 16 additions & 7 deletions torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -822,9 +822,12 @@ def _is_packable_linear(match):
return False
batch_size = input_meta_value.shape[0]
is_bf16_weight = weight_meta_value.dtype == torch.bfloat16
# for fp32, mkl should be enabled and batch_size should not be a free symbol.
if not is_bf16_weight and (
(not torch._C.has_mkl) or has_free_symbols(batch_size)
# on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol.
# on aarch64, use mkldnn op for fp32 as well if acl is enabled
if (
not is_bf16_weight
and not mkldnn._is_mkldnn_acl_supported()
and ((not torch._C.has_mkl) or has_free_symbols(batch_size))
):
return False
for meta_value in [input_meta_value, weight_meta_value]:
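Net effect of the new condition: an fp32 linear is considered packable when either MKL is present and the batch size is static (the existing x86 path) or ACL is enabled (aarch64); bf16 weights pass this check regardless. A simplified restatement of the predicate as a standalone sketch (argument names are illustrative, not taken from the source):

    def fp32_linear_packable(has_mkl: bool, acl_enabled: bool, dynamic_batch: bool) -> bool:
        # x86: requires MKL and a static batch size; aarch64: ACL alone suffices.
        return acl_enabled or (has_mkl and not dynamic_batch)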
@@ -999,7 +1002,7 @@ def linear(match, *args, **kwargs):
batch_size = input.meta.get("val").shape[0]
if has_free_symbols(batch_size):
assert (
is_bf16_weight
is_bf16_weight or mkldnn._is_mkldnn_acl_supported()
), f"only bf16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
# For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
packed_weight_inputs = (
@@ -1010,15 +1013,15 @@
)
packed_weight_op = (
mkldnn._reorder_linear_weight
if is_bf16_weight
if (is_bf16_weight or mkldnn._is_mkldnn_acl_supported())
else torch.ops.mkl._mkl_reorder_linear_weight
)
packed_weight_node = graph.create_node(
"call_function", packed_weight_op, args=packed_weight_inputs
)

packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node)
if is_bf16_weight:
if is_bf16_weight or mkldnn._is_mkldnn_acl_supported():
packed_linear_inputs += (bias, "none", [], "")
packed_linear_op = mkldnn._linear_pointwise.default
else:
@@ -1070,7 +1073,13 @@ def forward(self, x):

@functools.lru_cache(None)
def _mkldnn_fusion_init():
if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available():
# TODO: aarch64: enable op fusion for ACL once it supports fused operators. Disabling it for now.
# Otherwise even the matmul or inner product cannot be accelerated with ACL.
if (
torch.backends.mkldnn.enabled
and torch.backends.mkldnn.is_available()
and not torch.ops.mkldnn._is_mkldnn_acl_supported()
):
_register_unary_fusion()
_register_inplace_fusion()
_register_binary_unary_fusion()
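A quick way to exercise the new path on an aarch64 build with ACL is to compile a small fp32 Linear module with inductor freezing enabled and confirm that the weight is reordered via mkldnn._reorder_linear_weight and the matmul lowered to mkldnn._linear_pointwise. A minimal sketch; the module and shapes are arbitrary and chosen only for illustration:

    import torch
    import torch._inductor.config as inductor_config

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(64, 32)

        def forward(self, x):
            return self.linear(x)

    # Weight pre-packing is applied as part of inductor's freezing passes.
    inductor_config.freezing = True

    mod = M().eval()
    x = torch.randn(8, 64)
    with torch.no_grad():
        # On aarch64 with ACL enabled, the fp32 weight should be pre-packed
        # instead of falling back to the plain aten linear.
        out = torch.compile(mod)(x)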