[onnx] support attn_mask fp16 type by rui-ren · Pull Request #110306 · pytorch/pytorch · GitHub

@rui-ren rui-ren commented Sep 29, 2023

When users define a custom attention mask with dtype=torch.float16, e.g.

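import torch
from torch.nn import functional as F

# Most negative finite fp16 value, used as the additive "masked out" score.
float_min = torch.finfo(torch.float16).min

# attention_mask is a boolean tensor; True positions are filled with float_min.
attention_mask_fp16 = (attention_mask * 1.0).masked_fill(attention_mask, float_min).to(torch.float16)

# Positional arguments after value_layer_ are attn_mask and dropout_p.
attn_output = F.scaled_dot_product_attention(
    query_layer_, key_layer_, value_layer_, attention_mask_fp16, 0.0, is_causal=False
)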

the ONNX graph cannot be exported.
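
To make the mask construction concrete, here is a minimal sketch with a hypothetical 2x2 boolean mask (the values are invented for illustration and do not come from the linked issue). It shows only what the resulting fp16 additive mask looks like; it does not call scaled_dot_product_attention or the exporter.

```python
import torch

# Hypothetical boolean mask for illustration: in this construction,
# True marks the positions that end up filled with float_min.
bool_mask = torch.tensor([[False, True], [False, False]])

float_min = torch.finfo(torch.float16).min

# Same construction as in the description: cast to float, fill the True
# positions with the most negative fp16 value, then convert to fp16.
additive_mask = (bool_mask * 1.0).masked_fill(bool_mask, float_min).to(torch.float16)

print(additive_mask.dtype)
print(additive_mask)
```

The result is an fp16 tensor holding 0 where the boolean mask is False and the most negative fp16 value where it is True, which is the kind of attn_mask the exporter previously rejected.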

When q, k, and v are fp16, attn_mask can be fp16 as well. This PR supports that by changing the dtype check on attn_mask from an equality test against FLOAT to a membership test:

elif _type_utils.JitScalarType.from_value(attn_mask) in (
    _type_utils.JitScalarType.FLOAT,
    _type_utils.JitScalarType.HALF,
):

With this change, the ONNX graph exports successfully.
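
The difference between the two forms of the dtype check can be sketched in plain Python. The `ScalarType` enum and the three helper functions below are hypothetical stand-ins, not the real `_type_utils.JitScalarType` or exporter code; they only illustrate how equality, membership, and an accidental chained comparison behave for an fp16 mask.

```python
from enum import Enum, auto


class ScalarType(Enum):
    """Hypothetical stand-in for _type_utils.JitScalarType (illustration only)."""

    FLOAT = auto()
    HALF = auto()
    BOOL = auto()


def accepts_equality(mask_type: ScalarType) -> bool:
    # Equality against FLOAT only: an fp16 (HALF) mask is not matched.
    return mask_type == ScalarType.FLOAT


def accepts_membership(mask_type: ScalarType) -> bool:
    # Membership test: both fp32 and fp16 masks are matched.
    return mask_type in (ScalarType.FLOAT, ScalarType.HALF)


def accepts_chained(mask_type: ScalarType) -> bool:
    # Python evaluates this as (mask_type == FLOAT) and (FLOAT in (...)),
    # so it behaves like the equality check, not the membership check.
    return mask_type == ScalarType.FLOAT in (ScalarType.FLOAT, ScalarType.HALF)


for t in ScalarType:
    print(t.name, accepts_equality(t), accepts_membership(t), accepts_chained(t))
```

Only the membership form matches HALF, so the `== FLOAT` comparison has to be replaced by the `in (...)` test rather than left in front of it.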

Fixes #109336

pytorch-bot bot commented Sep 29, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110306

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6149727 with merge base d04b35e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Sep 29, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

pytorch-bot bot added the "release notes: onnx" label (torch.onnx related changes that should show up in the release notes) Sep 29, 2023

titaiwangms commented Sep 29, 2023

This looks good!
Please sign CLA and lint the code (https://github.com/pytorch/pytorch/wiki/lintrunner).

titaiwangms self-assigned this Sep 29, 2023
titaiwangms added the "module: onnx" (Related to torch.onnx) and "topic: improvements" (topic category) labels Sep 29, 2023
rui-ren marked this pull request as ready for review September 29, 2023 23:59

titaiwangms left a comment

LGTM! Thanks!

titaiwangms added the "ciflow/trunk" label (Trigger trunk jobs on your pull request) Oct 1, 2023

titaiwangms commented

@pytorchbot merge

pytorchmergebot commented

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

rui-ren deleted the rui-ren/onnx-support-attn-mask-fp16-dtype branch October 1, 2023 15:44

Labels

ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
module: onnx (Related to torch.onnx)
open source
release notes: onnx (torch.onnx related changes that should show up in the release notes)
topic: improvements (topic category)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[onnx exporter] Falcon-7b onnx graph exporter issue from huggingface source

4 participants