[quant][pt2] Fix custom dtype per channel weight in QAT #112612
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112612
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 7bccb52 with merge base 27e31ab.
```python
self.assertEqual(dq_dtype, torch.int32)
```

```python
def _get_conv_bn_getitem_nodes(model: torch.fx.GraphModule):
```
oh OK you copied this, then it's probably fine
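For context, `_get_conv_bn_getitem_nodes` walks the exported FX graph and collects the conv, batch-norm, and getitem nodes. A rough sketch of such a helper, assuming the standard aten targets produced by export (the actual helper in test_quantization.py may match different targets):

```python
import operator

import torch
from torch.fx import GraphModule, Node

def _get_conv_bn_getitem_nodes(model: GraphModule):
    """Return the (conv, bn, getitem) nodes from an exported conv-bn graph."""
    conv_node = bn_node = getitem_node = None
    for node in model.graph.nodes:
        if node.op != "call_function":
            continue
        if node.target == torch.ops.aten.convolution.default:
            conv_node = node
        elif node.target == torch.ops.aten._native_batch_norm_legit.default:
            bn_node = node
        elif node.target == operator.getitem:
            getitem_node = node
    assert None not in (conv_node, bn_node, getitem_node)
    return conv_node, bn_node, getitem_node
```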
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours).
Merge failed. Reason: 1 job failed: trunk / macos-12-py3-arm64 / test (default, 1, 3, macos-m1-12).
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours).
Stack from ghstack (oldest at bottom):
Summary: Previously we only copied over q/dq args for the per
tensor case. This was because the qparams for
`quantize_per_tensor` are literals while the qparams for
`quantize_per_channel` are `get_attr` nodes (tensors), which disappear from the original nodes in the graph after subgraph rewriting.
However, this is problematic because, in the per channel case,
not all q/dq args are tensors. In particular, the args after
the qparams (axis, qmin, qmax, dtype) are all literals. For
these literal args we simply used the hardcoded ones
(0, -127, 127, torch.int8 respectively), even if the user
explicitly specified to use a different weight dtype. This
commit fixes this by copying over these literal args for the
per channel case as well.
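A minimal sketch of the shape of the fix, assuming the decomposed per-channel q/dq arg layout `(input, scales, zero_points, axis, quant_min, quant_max, dtype)`; the function and variable names here are illustrative, not the actual code in the PR:

```python
from torch.fx import Node

def _copy_per_channel_literal_qdq_args(original: Node, replacement: Node) -> None:
    # scales/zero_points are get_attr tensor nodes and are handled
    # separately; only the trailing literal args are copied here, so a
    # user-specified dtype (e.g. torch.int32) survives subgraph rewriting
    # instead of being overwritten by the hardcoded (0, -127, 127, torch.int8).
    start = 3  # skip (input, scales, zero_points)
    replacement.args = replacement.args[:start] + original.args[start:]
```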
Test Plan:
python test/test_quantization.py TestQuantizePT2EQAT.test_qat_per_channel_weight_custom_dtype
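For reference, the kind of per-channel weight configuration the test exercises might look like the following; this is a hedged sketch, and the exact spec in the test may differ:

```python
import torch
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.quantizer import QuantizationSpec

# Per-channel weight spec with a custom dtype. With the fix, the dtype and
# quant_min/quant_max below propagate into the rewritten q/dq nodes rather
# than being replaced by the hardcoded int8 defaults.
weight_spec = QuantizationSpec(
    dtype=torch.int32,
    quant_min=-(2**31),
    quant_max=2**31 - 1,
    qscheme=torch.per_channel_affine,
    ch_axis=0,
    is_dynamic=False,
    observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize,
)
```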
Reviewers: jerryzh168, kimishpatel
Subscribers: jerryzh168, kimishpatel, supriyar