Generate patterns in fp16 and fp32 by eellison · Pull Request #109142 · pytorch/pytorch · GitHub

Conversation

@eellison
Contributor

@eellison eellison commented Sep 12, 2023

Stack from ghstack (oldest at bottom):

aten.softmax generates a different decomposition for fp16/bf16 than for fp32: when invoked in lower precision it upcasts the inputs to fp32 and downcasts the result afterward. This has been causing us to miss bf16 patterns. For example, Camembert improves 20% with this PR (as I'm sure many other models do).
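
A minimal hand-written sketch of the two decompositions (illustrative only, not the actual decomposition source) showing why an fp32-only pattern never matches the low-precision graph:

import torch

def softmax_decomp_fp32(x):
    # fp32 input: the decomposition computes directly in fp32.
    x = x - x.amax(dim=-1, keepdim=True)
    e = x.exp()
    return e / e.sum(dim=-1, keepdim=True)

def softmax_decomp_low_precision(x):
    # fp16/bf16 input: upcast to fp32, compute, downcast. The extra
    # convert nodes change the traced graph, so a pattern generated
    # only from the fp32 decomposition will not fire.
    x32 = x.float()
    x32 = x32 - x32.amax(dim=-1, keepdim=True)
    e = x32.exp()
    return (e / e.sum(dim=-1, keepdim=True)).to(x.dtype)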

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov

@pytorch-bot

pytorch-bot bot commented Sep 12, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109142

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ad284ca with merge base 518308a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

eellison added a commit that referenced this pull request Sep 12, 2023
ghstack-source-id: 9773914
Pull Request resolved: #109142
@eellison eellison added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 13, 2023
@eellison
Contributor Author

cc @drisspg: after rebasing past fv2, meta stride propagation is incorrect:

FAIL: test_sdpa_rewriter_1_cuda (__main__.SDPAPatternRewriterCudaTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/scratch/eellison/work/pytorch/torch/testing/_internal/common_utils.py", line 1283, in wrapper
    return fn(*args, **kwargs)
  File "test/inductor/test_fused_attention.py", line 115, in _test_sdpa_rewriter_1
    self._check_common(dot_prod_attention, dtype=dtype, atol=0.001, rtol=rtol)
  File "test/inductor/test_fused_attention.py", line 70, in _check_common
    result2, (source_code,) = run_and_get_code(
  File "/scratch/eellison/work/pytorch/torch/_inductor/utils.py", line 829, in run_and_get_code
    result = fn(*args, **kwargs)
  File "/scratch/eellison/work/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
    return fn(*args, **kwargs)
  File "test/inductor/test_fused_attention.py", line 100, in dot_prod_attention
    def dot_prod_attention(
  File "/scratch/eellison/work/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
    return fn(*args, **kwargs)
  File "/scratch/eellison/work/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/scratch/eellison/work/pytorch/torch/_functorch/aot_autograd.py", line 3905, in forward
    return compiled_fn(full_args)
  File "/scratch/eellison/work/pytorch/torch/_functorch/aot_autograd.py", line 1482, in g
    return f(*args)
  File "/scratch/eellison/work/pytorch/torch/_functorch/aot_autograd.py", line 2533, in runtime_wrapper
    all_outs = call_func_with_args(
  File "/scratch/eellison/work/pytorch/torch/_functorch/aot_autograd.py", line 1506, in call_func_with_args
    out = normalize_as_list(f(args))
  File "/scratch/eellison/work/pytorch/torch/_functorch/aot_autograd.py", line 1594, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/scratch/eellison/work/pytorch/torch/_inductor/codecache.py", line 396, in __call__
    return self.get_current_callable()(inputs)
  File "/scratch/eellison/work/pytorch/torch/_inductor/compile_fx.py", line 616, in run
    return model(new_inputs)
  File "/scratch/eellison/work/pytorch/torch/_inductor/codecache.py", line 423, in _run_from_cache
    return compiled_graph.compiled_artifact(inputs)
  File "/tmp/torchinductor_eellison/2m/c2mkma5h4mw3jpc2ixso7f6ccno2vqe3oqdzr3j7pkvtnc55lh6o.py", line 39, in call
    assert_size_stride(buf1, (4, 2, 16, 32), (1024, 512, 32, 1))
AssertionError: expected size 2==2, stride 32==512 at dim=1
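
For context, a small illustration of the mismatch (my reading of the assert, assuming the real flash-attention kernel materializes its output in (B, S, H, D) layout and views it back as (B, H, S, D)):

import torch

b, h, s, d = 4, 2, 16, 32

# What the compiled code asserted: a plain contiguous (B, H, S, D) output.
contig = torch.empty(b, h, s, d)
print(contig.stride())  # (1024, 512, 32, 1)

# What the kernel plausibly returns: (B, S, H, D) storage transposed back,
# so dim=1 has stride d == 32 instead of 512.
transposed = torch.empty(b, s, h, d).transpose(1, 2)
print(transposed.shape)    # torch.Size([4, 2, 16, 32])
print(transposed.stride())  # (1024, 32, 64, 1)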

@drisspg
Contributor

drisspg commented Sep 15, 2023

@eellison Do you have a repro/test I can use to fix this?

@eellison
Contributor Author

import torch
from torch._dynamo.testing import rand_strided

# q, k, v: contiguous (B, H, S, D) = (4, 2, 16, 32) fp16 CUDA tensors.
arg0_1 = rand_strided((4, 2, 16, 32), (1024, 512, 32, 1), device='cuda:0', dtype=torch.float16)
arg1_1 = rand_strided((4, 2, 16, 32), (1024, 512, 32, 1), device='cuda:0', dtype=torch.float16)
arg2_1 = rand_strided((4, 2, 16, 32), (1024, 512, 32, 1), device='cuda:0', dtype=torch.float16)

# CrossRefFakeMode cross-checks the real kernel's output metadata against
# the fake/meta propagation, so the incorrect stride propagation trips an
# assertion here.
with torch._subclasses.fake_utils.CrossRefFakeMode():
    torch.ops.aten._scaled_dot_product_flash_attention(arg0_1, arg1_1, arg2_1, scale=0.17677669529663687)

@drisspg
Contributor

drisspg commented Sep 15, 2023

#109346. I am confused; this isn't reproing for me locally. (Figured out offline with Elias.)

@eellison
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@eellison
Contributor Author

@pytorchbot revert -m "test failures on main"

@pytorch-bot

pytorch-bot bot commented Sep 19, 2023

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

@eellison
Contributor Author

@pytorchbot revert -m MESSAGE -c landrace

@eellison eellison reopened this Sep 19, 2023
@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@eellison your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Sep 19, 2023
This reverts commit 14994cc.

Reverted #109142 on behalf of https://github.com/eellison due to MESSAGE.
@eellison
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request Sep 20, 2023
Adds a 3d pattern that improves perf of HF Whisper from 1.3 -> 4.1. We could be matching more generally on 3d, but I'll leave that for another PR.

Thanks to @drisspg for helping me write the pattern.

Pull Request resolved: #109156
Approved by: https://github.com/yanboliang
ghstack dependencies: #109663, #108894, #108917, #109142
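
A minimal sketch of what "3d" means in #109156 above, assuming (as is common for HF models) a merged batch*heads leading dimension driven through bmm, which a 4d-only pattern would not match:

import torch

def attention_3d(q, k, v, scale):
    # q, k, v: (batch * heads, seq, head_dim) -- 3d, so the traced graph
    # contains bmm nodes rather than the 4d matmuls the existing
    # patterns expected.
    scores = torch.bmm(q, k.transpose(-2, -1)) * scale
    return torch.bmm(scores.softmax(dim=-1), v)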
pytorchmergebot pushed a commit that referenced this pull request Sep 20, 2023
The pretty printer is faster and more concise because it memoizes objects.

Pull Request resolved: #109066
Approved by: https://github.com/yanboliang
ghstack dependencies: #109663, #108894, #108917, #109142, #109156
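
A hypothetical sketch of the memoization idea in #109066 (not the actual code): objects that have already been printed are emitted as short back-references instead of being fully re-rendered.

def pprint_memoized(objs):
    seen = {}  # id(obj) -> short index assigned on first print
    lines = []
    for obj in objs:
        key = id(obj)
        if key in seen:
            # Already rendered once: emit a compact back-reference.
            lines.append(f"<ref {seen[key]}>")
        else:
            seen[key] = len(seen)
            lines.append(f"{seen[key]}: {obj!r}")
    return "\n".join(lines)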
@facebook-github-bot facebook-github-bot deleted the gh/eellison/539/head branch September 23, 2023 14:22