[Quant][PT2E][X86] annotate and convert for linear_dynamic_fp16 by Xia-Weiwen · Pull Request #141480 · pytorch/pytorch · GitHub

Conversation

@Xia-Weiwen (Collaborator) commented Nov 25, 2024

Stack from ghstack (oldest at bottom):

Annotate the linear node for linear_dynamic_fp16 with X86InductorQuantizer.
After convert_pt2e, the pattern will be:

  x
  |
linear <- to_fp32 <- to_fp16 <- w

Test plan

pytest test/quantization/pt2e/test_x86inductor_quantizer.py -k test_linear_dynamic_fp16

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @ezyang @SherlockNoMad @EikanWang @wenzhe-nrv
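
For illustration, here is a minimal end-to-end sketch of the flow described above, using the standard PT2E export/prepare/convert APIs. The config helper name `get_x86_inductor_linear_dynamic_fp16_config` is an assumption based on this PR stack, not a documented API:

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x):
        return self.linear(x)


m = M().eval()
example_inputs = (torch.randn(2, 16),)

# Capture the model with the PT2E export entry point.
exported = torch.export.export_for_training(m, example_inputs).module()

# Configure the quantizer for linear_dynamic_fp16 (helper name assumed from this PR stack).
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_x86_inductor_linear_dynamic_fp16_config())

prepared = prepare_pt2e(exported, quantizer)
converted = convert_pt2e(prepared)
# After convert_pt2e, the weight feeding the linear carries the to_fp16 -> to_fp32 cast chain shown above.
```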

@pytorch-bot (bot) commented Nov 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141480

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1c827ef with merge base 2398e75:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: quantization release notes category label Nov 25, 2024
@Xia-Weiwen Xia-Weiwen marked this pull request as draft November 25, 2024 07:44
Xia-Weiwen added a commit that referenced this pull request Nov 25, 2024
@Xia-Weiwen Xia-Weiwen requested review from jgong5 and leslie-fang-intel and removed request for jgong5 and leslie-fang-intel November 26, 2024 01:12
@Xia-Weiwen Xia-Weiwen added the intel This tag is for PR from Intel label Nov 26, 2024
@Xia-Weiwen Xia-Weiwen changed the title [Quant][PT2E] annotate and convert for linear_dynamic_fp16 [Quant][PT2E][CPU] annotate and convert for linear_dynamic_fp16 Nov 26, 2024
@Xia-Weiwen Xia-Weiwen changed the title [Quant][PT2E][CPU] annotate and convert for linear_dynamic_fp16 [Quant][PT2E][X86] annotate and convert for linear_dynamic_fp16 Nov 26, 2024
    graph.erase_node(node)
elif dtype == torch.float16:
    raise NotImplementedError("decomposed to float16 op not implemented yet")
    quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
Collaborator:

It seems that for dtype == torch.float16, torch.ops.quantized_decomposed.quantize_per_tensor.default has the same semantics as to(dtype=torch.float16). Why not just use to(dtype=torch.float16)?

Collaborator Author (Xia-Weiwen):

Because to will be constant-folded and the pattern will be hard to match.

Contributor (@jerryzh168), Nov 27, 2024:

I also feel that using to might be better; that way we won't have multiple ops doing the same thing. Wondering what would be needed to use to here.

Collaborator Author (@Xia-Weiwen), Nov 27, 2024:

Unfortunately, to will be folded by Inductor. Here is the implementation with to:
https://github.com/pytorch/pytorch/blob/76ad4bb890b66098672e1f0349e83e28f2d6d85e/torch/ao/quantization/fx/convert.py#L345C1-L357C1
The pattern I got in Inductor is x/w -> linear; no to is seen.
If we insert quant/dequant instead, the dequant op is visible in Inductor, giving the pattern x / (w -> dequant) -> linear.
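
To make that concrete, here is a hypothetical fx-graph sketch of the idea (not the PR's actual convert code): explicit cast nodes are inserted on the weight instead of a `to` call, so a `w -> to_fp16 -> to_fp32 -> linear` chain stays visible for pattern matching. The PR itself ends up registering a dedicated non-foldable cast op for this purpose:

```python
import torch
from torch.fx import GraphModule, Node


def insert_weight_fp16_cast(gm: GraphModule, weight_node: Node) -> None:
    """Illustrative only: cast the weight to fp16 and back to fp32 with explicit
    graph nodes so the pattern survives into Inductor's pattern matcher."""
    graph = gm.graph
    users = list(weight_node.users)  # e.g. the linear node consuming the weight
    with graph.inserting_after(weight_node):
        to_fp16 = graph.call_function(
            torch.ops.prims.convert_element_type.default,  # the PR uses a non-foldable variant of this cast
            (weight_node, torch.float16),
        )
    with graph.inserting_after(to_fp16):
        to_fp32 = graph.call_function(
            torch.ops.prims.convert_element_type.default,
            (to_fp16, torch.float32),
        )
    # Rewire only the original consumers; the new cast chain still reads the raw weight.
    for user in users:
        user.replace_input_with(weight_node, to_fp32)
    gm.recompile()
```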

Contributor (@jerryzh168):

I'm wondering if we can create a variant of torch.ops.prims.convert_element_type.default, say torch.ops.prims.convert_element_type.no_fuse, and add it to the same list as torch.ops.quantized_decomposed.dequantize_per_channel.default so it's not fused by Inductor?

Collaborator Author (Xia-Weiwen):

@jerryzh168 Thanks for the suggestion. Do you know how to add a new variant?

Contributor (@jerryzh168):

It should be the same as adding a new custom op:

@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd")
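
For reference, a minimal sketch of registering such an op with torch.library, mirroring the registration pattern of the existing quantized_decomposed ops; the implementation body is only an illustration of the intended semantics (a plain dtype cast exposed as its own op):

```python
import torch
from torch.library import Library, impl

# The real library fragment lives in torch/ao/quantization/fx/_decomposed.py;
# a standalone fragment is created here only to keep the sketch self-contained.
quantized_decomposed_lib = Library("quantized_decomposed", "FRAGMENT")

quantized_decomposed_lib.define(
    "convert_element_type.no_fuse(Tensor input, ScalarType dtype) -> Tensor"
)


@impl(quantized_decomposed_lib, "convert_element_type.no_fuse", "CompositeExplicitAutograd")
def convert_element_type_no_fuse(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Same math as input.to(dtype); keeping it as a distinct op prevents it from
    # being treated like an ordinary `to` cast by downstream passes.
    return torch.ops.aten.to.dtype(input, dtype)
```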

Collaborator Author (Xia-Weiwen):

I see. So we are actually adding a new op in the quantized_decomposed namespace?

Contributor (@jerryzh168):

Yeah, you can do that. Ideally it would live in torch.ops.prims, but I'm not sure about the process there; I can check.

Collaborator Author (Xia-Weiwen):

Sure. I have updated the PR. Please take a look. Thanks.

warnings.warn(
    "Mixed dynamic and static quantization config is not supported."
)
need_skip = True
Collaborator:

Why do we need to remove this code?

Collaborator Author (Xia-Weiwen):

Because we always set is_dynamic=False for linear_dynamic_fp16, but it needs to work with dynamic quantization of other ops. Besides, I didn't see an issue when users mix static and dynamic quantization for different ops.
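
For context, a rough sketch of what the fp16 weight annotation can look like. The observer choice and field values are assumptions for illustration, not necessarily the exact spec this PR uses:

```python
import torch
from torch.ao.quantization.observer import PlaceholderObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# The weight is only cast to fp16 (no scale/zero_point), so a placeholder observer
# suffices and is_dynamic stays False even when other ops use dynamic quantization.
weight_spec = QuantizationSpec(
    dtype=torch.float16,
    observer_or_fake_quant_ctr=PlaceholderObserver.with_args(dtype=torch.float16),
    is_dynamic=False,
)
```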

Collaborator:

Then I feel this flag is not needed any more; could we just remove it? cc @yiliu30

Collaborator Author (Xia-Weiwen):

Thanks. I have reverted this change because it is not an issue for now.

Collaborator (@jgong5) left a comment:

I have the same comment as @leslie-fang-intel. It seems simpler to just insert to(dtype=torch.half) instead of "quantize" in the graph.

Collaborator Author (@Xia-Weiwen):

> I have the same comment as @leslie-fang-intel. It seems simpler to just insert to(dtype=torch.half) instead of "quantize" in the graph.

Unfortunately, to will be constant-folded, which makes the pattern difficult to match. So we need to insert quant & dequant nodes on the graph for pattern matching.

@Xia-Weiwen Xia-Weiwen requested a review from jgong5 November 27, 2024 02:42
@Xia-Weiwen Xia-Weiwen marked this pull request as ready for review November 27, 2024 08:43


quantized_decomposed_lib.define(
    "convert_element_type.no_fuse(Tensor input, ScalarType dtype) -> Tensor"
)
Contributor (@jerryzh168):

Do you need to add this op to the same list as torch.ops.quantized_decomposed.dequantize_per_channel.default so it's not constant-folded?
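
Conceptually, that follow-up amounts to extending the set of ops the constant folder skips. The names below are purely hypothetical placeholders for illustration, not the real Inductor code:

```python
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # registers the quantized_decomposed ops

# Hypothetical illustration only: the real skip list lives in Inductor's
# constant-folding pass under a different name.
DONT_CONSTANT_FOLD = {
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
    # Added in the follow-up so the fp16 weight cast is not folded into the weight
    # (assumes the no_fuse op from this PR is available):
    torch.ops.quantized_decomposed.convert_element_type.no_fuse,
}
```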

Collaborator Author (Xia-Weiwen):

I have added this in the next PR. Thanks.

Contributor (@jerryzh168) left a comment:

LG, thanks!

Contributor (@jerryzh168):

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 29, 2024
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team.

Advanced Debugging: check the merge workflow status here.

GeorgeWigley pushed a commit to graphcore/pytorch-fork that referenced this pull request Nov 29, 2024
[Quant][PT2E][X86] annotate and convert for linear_dynamic_fp16 (pytorch#141480)

Annotate linear node for `linear_dynamic_fp16` with `X86InductorQuantizer`
After `convert_pt2e`, the pattern will be
```
  x
  |
linear <- to_fp32 <- to_fp16 <- w
```

**Test plan**
```
pytest test/quantization/pt2e/test_x86inductor_quantizer.py -k test_linear_dynamic_fp16
```

Pull Request resolved: pytorch#141480
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Esquains pushed a commit to Esquains/study1 that referenced this pull request Dec 15, 2024
@github-actions github-actions bot deleted the gh/Xia-Weiwen/20/head branch December 30, 2024 02:07

Labels

ciflow/trunk · fx · intel · Merged · open source · release notes: quantization
