
Vectorized operation on quantized tensors returns wrong values (different rounding) #107030

@Flamefire

🐛 Describe the bug

The following code fails:

import numpy as np
import torch

# 65 elements: one full 64-element chunk plus a single remainder element.
X = torch.from_numpy(np.full(64 + 1, 514., dtype=np.float32))
(scale, zero_point, torch_type) = (1028.02, 255, torch.quint8)

assert X.is_contiguous(memory_format=torch.contiguous_format)
qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                               dtype=torch_type)

# Map sigmoid's output range [0, 1] onto the 256 quint8 levels.
f_min, f_max = 0.0, 1.0
q_min, q_max = torch.iinfo(torch_type).min, torch.iinfo(torch_type).max
output_scale = (f_max - f_min) / (q_max - q_min + 1.0)  # 1/256

qY = torch.ops.quantized.sigmoid(qX, output_scale=output_scale, output_zero_point=0)
print(qY)
# All inputs are identical, so all outputs should be identical too.
assert qY[0] == qY[-1]
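For reference, every element of X quantizes to the same quint8 value here (514 / 1028.02 ≈ 0.49999 rounds to 0, so int_repr is 255 everywhere), which means an elementwise op has no legitimate reason to produce two different outputs. The two values that do come out, 0.5039 and 0.5000, are the adjacent quantization levels 129/256 and 128/256. A quick sanity check, continuing the snippet above:

import numpy as np
import torch

X = torch.from_numpy(np.full(64 + 1, 514., dtype=np.float32))
qX = torch.quantize_per_tensor(X, scale=1028.02, zero_point=255,
                               dtype=torch.quint8)

# Every element shares one integer representation (255).
assert qX.int_repr().unique().numel() == 1

# The two observed outputs are neighbouring levels of the 1/256 output scale.
print(129 / 256)  # 0.50390625 -> printed as 0.5039 (vectorized chunk)
print(128 / 256)  # 0.5        -> printed as 0.5000 (scalar remainder)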

In particular, the first 64 values are "0.5039" while the remaining value is "0.5000". This happens for any tensor length that is not a multiple of 64, i.e. whenever a scalar tail is left over after the vectorized 64-element chunks; the probe below makes the boundary visible.
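Here is a small probe sketch (my own helper, not part of the failing test) that runs the same op at several lengths and reports which indices disagree with element 0:

import numpy as np
import torch

def tail_mismatches(n, scale=1028.02, zero_point=255):
    """Indices whose quantized sigmoid output differs from element 0."""
    X = torch.from_numpy(np.full(n, 514., dtype=np.float32))
    qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                   dtype=torch.quint8)
    qY = torch.ops.quantized.sigmoid(qX, output_scale=1 / 256,
                                     output_zero_point=0)
    y = qY.int_repr()
    return (y != y[0]).nonzero().flatten().tolist()

# Per the behaviour described above, only the tail past the last full
# 64-element chunk should differ: [] for 64 and 128, [64] for 65, ...
for n in (64, 65, 128, 130):
    print(n, tail_mismatches(n))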

This was found by reducing a failing example from a test in test_quantization:

======================================================================
FAIL: test_sigmoid (quantization.core.test_quantized_op.TestQuantizedOps)
----------------------------------------------------------------------
Traceback (most recent call last):
<snip>
AssertionError: Quantized tensor-likes are not close!

Mismatched elements: 63 / 75 (84.0%)
Greatest absolute difference: 0.00390625 at index (0, 0, 1) (up to 1e-05 allowed)
Greatest relative difference: 0.0078125 at index (0, 0, 1) (up to 1.3e-06 allowed) : sigmoid - quantized.sigmoid failed: (tensor([[[0.0000, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039]],

        [[0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039]],

        [[0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]]], size=(3, 5, 5),
       dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine,
       scale=0.00390625, zero_point=0) vs. tensor([[[0.0000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]],

        [[0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]],

        [[0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]]], size=(3, 5, 5),
       dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine,
       scale=0.00390625, zero_point=0))
Falsifying example: test_sigmoid(
    X=(array([[[-261630.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.]],

            [[    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.]],

            [[    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.]]],
           dtype=float32), (1028.0156862745098, 255, torch.quint8)),
    self=<quantization.core.test_quantized_op.TestQuantizedOps testMethod=test_sigmoid>,
)

----------------------------------------------------------------------
Ran 942 tests in 656.469s

FAILED (failures=2, errors=1, skipped=72)
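Note that the mismatch in the log is exactly one quantization step: the greatest absolute difference of 0.00390625 equals the output scale, and the greatest relative difference of 0.0078125 is that same step divided by 0.5, which is consistent with the two code paths rounding an intermediate value to adjacent integers. A quick check of the arithmetic:

output_scale = 1 / 256                    # 0.00390625, as in the tensor repr
step = 129 / 256 - 128 / 256              # distance between the two outputs
assert step == output_scale               # greatest absolute difference
assert step / (128 / 256) == 0.0078125    # greatest relative difference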

This seems to happen with all PyTorch versions so far and does not depend on the host CPU; I reproduced it even on ppc64le.

Versions

PyTorch version: 2.0.1+cu117
Is debug build: False

OS: CentOS Linux release 7.9.2009 (Core) (x86_64)
GCC version: (GCC) 11.3.0
Clang version: Could not collect
CMake version: version 3.27.1
Libc version: glibc-2.17

Python version: 3.10.4 (main, Oct 6 2022, 14:14:40) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.11.1.el7.x86_64-x86_64-with-glibc2.17

Is XNNPACK available: True

CPU:
Architecture: x86_64
Model name: AMD EPYC 7352 24-Core Processor
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc art rep_good nopl nonstop_tsc extd_apicid aperfmperf eagerfpu pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_l2 cpb cat_l3 cdp_l3 hw_pstate sme retpoline_amd ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip overflow_recov succor smca

Versions of relevant libraries:
[pip3] numpy==1.25.2
[pip3] torch==2.0.1
[pip3] triton==2.0.0
[conda] Could not collect

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel

Metadata

Labels: oncall: quantization, triaged

Status: Done