[Quant][Inductor][X86] add fusion pass for linear_dynamic_fp16 #141549
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141549
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures) As of commit 67037d6 with merge base 795f28a:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @jerryzh168 Could you please review? Thanks.
    )
    def linear_dynamic_fp16_weight_prepack(match: Match, *args, **kwargs):
        """
        Match the pattern:
This is complicated; is there any way we can get the pattern by tracing a higher-level pattern?
Thanks for the suggestion. Do you have any examples for reference?
For example, this is how we get the pattern for QAT:

    match_pattern = _get_aten_graph_module_for_pattern(
Thanks. However, here we use the pattern matcher in Inductor, and the pattern is described with things like CallFunction, so it looks like we cannot generate the pattern from tracing. For now we use hand-written patterns, as defined in _generate_linear_dynamic_fp16_pattern above.
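For reference, a minimal sketch of what such a hand-written pattern looks like with the Inductor pattern matcher's CallFunction/KeywordArg. This is illustrative only; the op choices and variable names here are not the exact ones used in this PR:

```
# Illustrative only: a hand-written Inductor pattern describing
# weight -> to_fp16 -> to_fp32 -> t -> mm(x, ...), built from CallFunction nodes.
import torch
from torch._inductor.pattern_matcher import CallFunction, KeywordArg

aten = torch.ops.aten
prims = torch.ops.prims

# weight cast to fp16 and back to fp32 (the "dynamic fp16" round trip)
dequantized_weight = CallFunction(
    prims.convert_element_type.default,
    CallFunction(prims.convert_element_type.default, KeywordArg("w"), torch.float16),
    torch.float32,
)

# fp32 activation matmul'd with the transposed, dequantized weight
linear_dynamic_fp16_pattern = CallFunction(
    aten.mm.default,
    KeywordArg("x"),
    CallFunction(aten.t.default, dequantized_weight),
)
```

The real patterns also have to cover the reshape/expand, addmm, and bmm variants described in the PR description below, which is why they are generated by a helper rather than written out once.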
Hi @jerryzh168 Do you have more comments on this? Thanks.
@sanchitintel Thanks for the pointer. Is that pattern generated by tracing or written by hand?
Individual patterns defined in fuse_attention.py with the torch API are traced to produce serialized patterns of the kind you were alluding to (a CallFunction that may be nested). This approach avoids manually writing nested CallFunctions.
How do I do the tracing? Thanks.
Please refer to this code flow as an example:

pytorch/torch/_inductor/fx_passes/fuse_attention.py, lines 912 to 914 in b31d3b2:

    def _sfdp_init():
        for key, register_replacement_kwargs in _get_sfdp_patterns():
            gen_register_replacement(key, **register_replacement_kwargs)
You could try playing around with this code, as it may help narrow down which underlying Inductor API you could use for your use case to trace nested CallFunctions corresponding to high-level patterns written with the torch API.
I'm guessing it's probably this method, but I haven't verified:

pytorch/torch/_inductor/pattern_matcher.py, line 1461 in 653efe1:

    pattern = gen_pattern(search_fn, example_inputs, trace_fn, scalar_workaround)
Thanks!
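To make the suggestion concrete, here is a rough sketch of the trace-based route, assuming the register_replacement entry point in torch/_inductor/pattern_matcher.py; the pattern and replacement functions, example inputs, and pass object below are made up for illustration and are not the ones this PR would register:

```
# Rough sketch, not this PR's implementation: write the pattern and its
# replacement as plain torch functions and let the pattern matcher trace them,
# instead of hand-writing nested CallFunction structures.
import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

def _linear_fp16_pattern(x, w):
    # fp32 activation; weight round-tripped through fp16
    return torch.mm(x, w.to(torch.float16).to(torch.float32).t())

def _linear_fp16_replacement(x, w):
    # placeholder for the fused op; illustration only
    return torch.nn.functional.linear(x, w)

example_inputs = [torch.randn(4, 8), torch.randn(16, 8)]
patterns = PatternMatcherPass()

# register_replacement traces the pattern function into the nested
# CallFunction form and registers it into the given pass.
register_replacement(
    _linear_fp16_pattern,
    _linear_fp16_replacement,
    example_inputs,
    fwd_only,
    patterns,
)
```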
It looks quite complicated. I will try to understand it later. Thanks.
let me import to check internal CI

@jerryzh168 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…elu (#141556)

**Description**
Fuse and prepack weight for `linear_dynamic_fp16` with post op relu. In Inductor, the pattern we see is

```
fp32 activation
  |
(reshape)
  |
mm/addmm <- t <- to_fp32 <- to_fp16 <- weight
  |
(reshape) <- relu
```

Or

```
fp32 activation
  |
expand
  |
bmm <- expand <- t <- to_fp32 <- to_fp16 <- weight
  |
(add) <- relu
```

The second pattern is for x.ndim > 2 and x is not contiguous; the first pattern covers all other cases.

Fuse the pattern with weight prepack, and we get

```
fp32 activation
  |
onednn.linear_relu_dynamic_fp16 <- onednn.linear_prepack_fp16 <- weight
```

After freezing, the prepack op is gone.

**Test plan**
```
python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_relu_dynamic_fp16
```

Pull Request resolved: #141556
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: #141549
**Description**
For `linear_dynamic_fp16`, we insert `quantize` and `dequantize` between x/w and linear to have the following pattern:

```
x
  |
linear <- to_fp32 <- to_fp16 <- w
```

In Inductor, the pattern we finally see will be

```
fp32 activation
  |
(reshape)
  |
mm/addmm <- t <- to_fp32 <- to_fp16 <- weight
  |
(reshape)
```

Or

```
fp32 activation
  |
expand
  |
bmm <- expand <- t <- to_fp32 <- to_fp16 <- weight
  |
(add)
```

The second pattern is for x.ndim > 2 and x is not contiguous; the first pattern covers all other cases.

Fuse the pattern with weight prepack, and we get

```
fp32 activation
  |
onednn.linear_dynamic_fp16 <- onednn.linear_prepack_fp16 <- weight
```

After freezing, the prepack op is gone.

**Test plan**
```
python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_dynamic_fp16
```

Differential Revision: [D66802159](https://our.internmc.facebook.com/intern/diff/D66802159)
Pull Request resolved: #141549
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
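For context, below is a minimal illustrative sketch of an eager computation shaped like the unfused pattern above, compiled with Inductor freezing enabled, which is where the prepack/fusion would apply. The module, shapes, and config usage are assumptions for illustration, not the actual output of the quantization flow:

```
# Illustrative sketch only: the weight is round-tripped through fp16
# (to_fp16 -> to_fp32) before a plain fp32 linear, which is the shape of
# graph the fusion pass targets.
import torch
import torch._inductor.config as inductor_config

class ToyLinearDynamicFP16(torch.nn.Module):
    def __init__(self, in_features=16, out_features=32):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        # weight -> to_fp16 -> to_fp32, then linear (lowers to mm/addmm <- t)
        w = self.weight.to(torch.float16).to(torch.float32)
        return torch.nn.functional.linear(x, w)

mod = ToyLinearDynamicFP16().eval()
x = torch.randn(8, 16)

with torch.no_grad():
    inductor_config.freezing = True  # freezing folds the prepack op away
    compiled = torch.compile(mod)
    out = compiled(x)
```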
Stack from ghstack (oldest at bottom):
Description
For `linear_dynamic_fp16`, we insert `quantize` and `dequantize` between x/w and linear to have the following pattern:

    x
      |
    linear <- to_fp32 <- to_fp16 <- w

In Inductor, the pattern we finally see will be

    fp32 activation
      |
    (reshape)
      |
    mm/addmm <- t <- to_fp32 <- to_fp16 <- weight
      |
    (reshape)

Or

    fp32 activation
      |
    expand
      |
    bmm <- expand <- t <- to_fp32 <- to_fp16 <- weight
      |
    (add)

The second pattern is for x.ndim > 2 and x is not contiguous; the first pattern covers all other cases.

Fuse the pattern with weight prepack, and we get

    fp32 activation
      |
    onednn.linear_dynamic_fp16 <- onednn.linear_prepack_fp16 <- weight

After freezing, the prepack op is gone.

Test plan

    python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_dynamic_fp16
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov
Differential Revision: D66802159