Removing sdpa conv layout constraint by eellison · Pull Request #112045 · pytorch/pytorch · GitHub
Closed
24 changes: 0 additions & 24 deletions torch/_inductor/graph.py
@@ -332,30 +332,6 @@ def decide_layout_opt(gm) -> bool:
log.debug("Skip layout opt because all convolution channels are too small")
return False

# aten._scaled_dot_product_flash_attention requires the last stride of query/key/value
# to be 1. Check https://gist.github.com/shunting314/fa6eeab2aad8d1265c4d5e50b560d94f
# for more details.
#
# When a model contains aten._scaled_dot_product_flash_attention and we enable layout optimization,
# the op may get channels-last input and fail. Examples include: twins_pcpvt_base and xcit_large_24_p8_224
# for _scaled_dot_product_flash_attention, and xcit_large_24_p8_224 for _scaled_dot_product_efficient_attention.
#
# We disable layout optimization if a model contains aten._scaled_dot_product_flash_attention.
#
# An alternative is to do necessary layout conversion to make sure aten._scaled_dot_product_flash_attention's
# inputs have the layout needed. But that seems to have worse perf than disabling the layout opt.
# TODO(shunting) revisit if we can still apply layout optimization to models containing sdpa while
# bringing perf gains.
for n in gm.graph.nodes:
if n.target in (
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
):
log.debug(
"Skip layout optimization because sdpa (scaled dot product attention) is found"
)
return False

return True

def find_nodes_prefer_channels_last(self):
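For context on the constraint the removed check worked around: _scaled_dot_product_flash_attention expects the last stride of query/key/value to be 1, and a layout-optimized (channels-last) producer upstream can violate that. The snippet below is a minimal sketch of the stride requirement; it is not part of this PR, and the shapes are arbitrary examples.

import torch

# (batch, num_heads, seq_len, head_dim); contiguous, so the last stride is 1.
q = torch.randn(2, 8, 128, 64)
print(q.stride())    # (65536, 8192, 64, 1) -> last stride 1, acceptable for flash attention

# A transpose (standing in for a layout change from an upstream producer)
# leaves the last stride != 1, which is what the removed check guarded against.
q_t = q.transpose(-1, -2)
print(q_t.stride())  # (65536, 8192, 1, 64) -> last stride 64, would need a copy/conversion

With this PR, instead of disabling layout optimization for the whole graph, the required layout conversion is handled where the sdpa inputs are produced.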