[Quant][Inductor][X86] add fusion pass for linear_dynamic_fp16 by Xia-Weiwen · Pull Request #141549 · pytorch/pytorch · GitHub

Conversation

@Xia-Weiwen
Collaborator

@Xia-Weiwen Xia-Weiwen commented Nov 26, 2024

Stack from ghstack (oldest at bottom):

Description
For linear_dynamic_fp16, we insert quantize (to_fp16) and dequantize (to_fp32) ops between the weight and linear, giving the following pattern:

  x
  |
linear <- to_fp32 <- to_fp16 <- w

In Inductor, the pattern we finally see will be

fp32 activation
  |
(reshape)
  |
mm/addmm <- t <- to_fp32 <- to_fp16 <- weight
  |
(reshape)

Or

fp32 activation
  |
expand
  |
 bmm <- expand <- t <- to_fp32 <- to_fp16 <- weight
  |
(add)

The second pattern applies when x.ndim > 2 and x is not contiguous; the first pattern covers all other cases.
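
For reference, here is a minimal sketch of what the pre-fusion pattern computes with plain torch ops; the shapes are illustrative only, not taken from this PR:

```python
import torch

x = torch.randn(4, 8)              # fp32 activation
w = torch.randn(16, 8)             # fp32 weight (out_features, in_features)

w_fp16 = w.to(torch.float16)       # to_fp16: weight stored in half precision
w_deq = w_fp16.to(torch.float32)   # to_fp32: dequantize back before compute
y = torch.mm(x, w_deq.t())         # mm <- t <- to_fp32 <- to_fp16, as in the first diagram
```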

After fusing the pattern with weight prepack, we get

fp32 activation
  |
onednn.linear_dynamic_fp16 <- onednn.linear_prepack_fp16 <- weight

After freezing, the prepack op is gone.
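
As a hedged usage sketch: assuming a model that has already been quantized to linear_dynamic_fp16 upstream (the module and shapes below are illustrative), the fusion is exercised by compiling in eval mode with Inductor freezing enabled, which lets the prepack op constant-fold away:

```python
import torch

# illustrative stand-in for a model whose Linear has been quantized upstream
mod = torch.nn.Linear(8, 16).eval()
x = torch.randn(2, 8)

with torch.no_grad():
    torch._inductor.config.freezing = True  # enables constant-folding of the prepack op
    compiled = torch.compile(mod)
    out = compiled(x)
```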

Test plan

python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_dynamic_fp16

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

Differential Revision: D66802159

@pytorch-bot

pytorch-bot bot commented Nov 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141549

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 67037d6 with merge base 795f28a:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Xia-Weiwen Xia-Weiwen marked this pull request as ready for review December 2, 2024 10:56
@Xia-Weiwen
Collaborator Author

Hi @jerryzh168, could you please review? Thanks.

def linear_dynamic_fp16_weight_prepack(match: Match, *args, **kwargs):
    """
    Match the pattern:
Contributor

This is complicated. Is there any way we can get the pattern by tracing a higher-level pattern?

Collaborator Author

Thanks for the suggestion. Do you have any examples for reference? Thanks.

Contributor

For example, this is how we get the pattern for QAT:

match_pattern = _get_aten_graph_module_for_pattern(

Collaborator Author

Thanks. However, here we use the pattern matcher in Inductor, where the pattern is described with things like CallFunction, and we currently use hand-written patterns as defined in _generate_linear_dynamic_fp16_pattern above. So it looks like we cannot generate the pattern from tracing.
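
For context, a hedged sketch of what such a hand-written Inductor pattern looks like; the structure loosely mirrors the weight branch in the diagrams above, and the exact arguments are illustrative rather than this PR's actual code:

```python
import torch
from torch._inductor.pattern_matcher import CallFunction, KeywordArg

aten = torch.ops.aten
prims = torch.ops.prims

# weight -> to_fp16 -> to_fp32, feeding t and then mm (first diagram above)
dequantized_weight = CallFunction(
    prims.convert_element_type.default,
    CallFunction(prims.convert_element_type.default, KeywordArg("w"), torch.float16),
    torch.float32,
)
pattern = CallFunction(
    aten.mm.default,
    KeywordArg("x"),
    CallFunction(aten.t.default, dequantized_weight),
)
```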

Collaborator Author

Hi @jerryzh168, do you have more comments on this? Thanks.

Collaborator Author

@sanchitintel Thanks for the pointer. Is it done by tracing or written by hand?

Collaborator

@sanchitintel sanchitintel Dec 5, 2024

Individual patterns defined in fuse_attention.py with the torch API are traced to produce serialized patterns of the kind you were alluding to (a CallFunction that may be nested). This approach avoids manually writing nested CallFunctions.

Collaborator Author

How is the tracing done? Thanks.

Collaborator

@sanchitintel sanchitintel Dec 5, 2024

Please refer to this code-flow as an example -

def _sfdp_init():
    for key, register_replacement_kwargs in _get_sfdp_patterns():
        gen_register_replacement(key, **register_replacement_kwargs)

You could try playing around with this code; it may help you narrow down which underlying Inductor API you could use in your use case to trace nested CallFunctions corresponding to high-level patterns written with the torch API.

I'm guessing it's probably this method, but I haven't verified -

pattern = gen_pattern(search_fn, example_inputs, trace_fn, scalar_workaround)

Thanks!
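
For what it's worth, here is a minimal self-contained sketch of that flow using Inductor's register_replacement and fwd_only helpers; the pattern and replacement functions below are illustrative stand-ins, not the SDPA code:

```python
import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

def search_fn(x, w):
    # high-level pattern written with the torch API: the fp16 round-trip
    # on the weight, as in linear_dynamic_fp16
    return torch.mm(x, w.to(torch.float16).to(torch.float32).t())

def replace_fn(x, w):
    # stand-in replacement; a real pass would call the fused onednn op
    return torch.nn.functional.linear(x, w)

patterns = PatternMatcherPass()
# tracing search_fn with the example inputs generates the nested pattern
register_replacement(
    search_fn,
    replace_fn,
    [torch.randn(4, 8), torch.randn(16, 8)],  # example inputs used for tracing
    fwd_only,
    patterns,
)
```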

Collaborator Author

It looks quite complicated. I will try to understand it later. Thanks.

@jerryzh168
Contributor

Let me import this to check internal CI.

@jerryzh168
Contributor

@jerryzh168 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 5, 2024
@jerryzh168
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status

pytorchmergebot pushed a commit that referenced this pull request Dec 9, 2024
…elu (#141556)

**Description**
Fuse and prepack the weight for `linear_dynamic_fp16` with post-op relu. In Inductor, the pattern we see is
```
fp32 activation
  |
(reshape)
  |
mm/addmm <- t <- to_fp32 <- to_fp16 <- weight
  |
(reshape) <- relu
```
Or
```
fp32 activation
  |
expand
  |
 bmm <- expand <- t <- to_fp32 <- to_fp16 <- weight
  |
(add) <- relu
```
The second pattern applies when x.ndim > 2 and x is not contiguous; the first pattern covers all other cases.

After fusing the pattern with weight prepack, we get
```
fp32 activation
  |
onednn.linear_relu_dynamic_fp16 <- onednn.linear_prepack_fp16 <- weight
```
After freezing, the prepack op is gone.

**Test plan**
```
python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_relu_dynamic_fp16
```

Pull Request resolved: #141556
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: #141549
pytorch-bot bot pushed a commit that referenced this pull request Dec 9, 2024
Differential Revision: [D66802159](https://our.internmc.facebook.com/intern/diff/D66802159)
Pull Request resolved: #141549
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
AmdSampsa pushed a commit to AmdSampsa/pytorch that referenced this pull request Dec 9, 2024
…elu (pytorch#141556)

@github-actions github-actions bot deleted the gh/Xia-Weiwen/21/head branch January 7, 2025 02:05

Labels

ciflow/inductor
ciflow/trunk (Trigger trunk jobs on your pull request)
intel (This tag is for PR from Intel)
Merged
module: cpu (CPU specific problem (e.g., perf, algorithm))
module: inductor
open source
release notes: quantization (release notes category)

7 participants