🚀 The feature, motivation and pitch
Intel-Extension-for-PyTorch (IPEX) offers an advanced int8-mixed-bf16 quantization path, which transforms the output of quantized Conv/GEMM operations into the BF16 data type if there is no subsequent quantized operator. This enhancement significantly improves the inference performance of models such as Bert/DistilBert, as the pointwise operators following GEMM will operate with BF16 data instead of FP32.
- Here is an example of how to use this feature with IPEX (a minimal sketch appears after this list).
- Please note that this feature may result in accuracy loss for certain models. With IPEX, we have verified its accuracy on models such as Bert, DistilBert, Stable Diffusion, and several LLMs. However, we have also observed accuracy issues in models like vision transformers.
- Similarly, we recently received a feature request for BFloat16 data type support in quantization: #111487.
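Below is a minimal sketch of the IPEX int8-mixed-bf16 usage referred to above, based on IPEX's static quantization API; `model`, `calibration_loader`, and the example input shape are placeholders, and the exact API may differ across IPEX versions.

```python
import torch
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert

# Placeholders: a float32 model in eval mode and a representative input.
model.eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

# Static int8 quantization with IPEX.
qconfig = ipex.quantization.default_static_qconfig
prepared_model = prepare(model, qconfig, example_inputs=example_inputs, inplace=False)
for data in calibration_loader:  # calibration
    prepared_model(data)
converted_model = convert(prepared_model)

# int8-mixed-bf16: trace and freeze the converted model under BF16 autocast,
# so pointwise ops following quantized Conv/GEMM run in BF16 instead of FP32.
with torch.no_grad(), torch.cpu.amp.autocast(dtype=torch.bfloat16):
    traced_model = torch.jit.trace(converted_model, example_inputs)
    traced_model = torch.jit.freeze(traced_model)
    output = traced_model(*example_inputs)
```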
Alternatives
We typically have two options to enable this feature.
Option 1: Use Autocast
Autocast is naturally employed for BF16 optimization in Inductor. Similarly, we can harness it for the PT2E int8-mixed-bf16 feature to generate a pattern like `q -> dq -> float32_to_bfloat16 -> conv -> bfloat16_to_fp32 -> q -> dq`.
- A `to_bfloat16` node should be inserted before conv when Autocast and `torch.compile` are used together, since conv is in the Autocast whitelist.
- As for inserting a `bfloat16_to_fp32` node after the conv node, we need to extend the implementation of `quantize_per_tensor` in `pytorch/torch/ao/quantization/fx/_decomposed.py` (lines 36 to 64 at 93a9b13, shown below):
```python
@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd")
def quantize_per_tensor(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype
) -> torch.Tensor:
    """ Affine quantization for the Tensor using the same quantization parameters to map
    from floating point to quantized values

    Args:
       input (torch.Tensor): original float32 Tensor
       scale (float): quantization parameter for affine quantization
       zero_point (int): quantization parameter for affine quantization
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
       are not stored in the Tensor, we are storing them in function arguments instead
    """
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    inv_scale = 1.0 / scale
    return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
```
The following lines would be added at the beginning of the function body:

```python
if input.dtype == torch.bfloat16:
    input = input.to(torch.float32)
```
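For clarity, here is a sketch of `quantize_per_tensor` with this addition applied (docstring elided; `impl`, `quantized_decomposed_lib`, and `_quant_min_max_bounds_check` are the existing definitions in `_decomposed.py`):

```python
@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd")
def quantize_per_tensor(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype
) -> torch.Tensor:
    # Proposed addition: up-cast BF16 activations produced under Autocast
    # before entering the existing float32-only path.
    if input.dtype == torch.bfloat16:
        input = input.to(torch.float32)

    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    inv_scale = 1.0 / scale
    return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
```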
Here's an example code snippet:
```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

# `model`, `example_inputs`, and `images` are user-provided placeholders.
exported_model = capture_pre_autograd_graph(
    model,
    example_inputs
)
# Create X86InductorQuantizer
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
# PT2E quantization flow
prepared_model = prepare_pt2e(exported_model, quantizer)
# Calibration (run representative inputs through prepared_model here)
converted_model = convert_pt2e(prepared_model)
torch.ao.quantization.move_exported_model_to_eval(converted_model)

enable_int8_mixed_bf16 = True  # toggle the int8-mixed-bf16 path
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=enable_int8_mixed_bf16), torch.no_grad():
    optimized_model = torch.compile(converted_model)
    # Int8-mixed-BF16 inference
    quant_output = optimized_model(images)
```
- Pros:
  - Utilize the existing int8-mixed-fp32 quantizer and PT2E flow implementation.
  - Make use of the existing Autocast operator list and mechanism.
- Cons:
  - The Autocast mechanism will convert every input, including the convolution's bias, to BF16. However, for X86InductorQuantizer and the associated Inductor optimization, we anticipate that keeping the bias in float32 may yield better accuracy (see the small demo below).
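A small, self-contained demo of this behavior (assuming conv is on the CPU Autocast BF16 op list, as noted above):

```python
import torch

x = torch.randn(1, 3, 8, 8)
weight = torch.randn(4, 3, 3, 3)
bias = torch.randn(4)  # float32 bias

with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # Autocast casts all floating-point inputs of conv2d, including the bias,
    # to bfloat16 before the op runs.
    out = torch.nn.functional.conv2d(x, weight, bias)

print(out.dtype)  # torch.bfloat16
```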
Option 2: Add BFloat16 as a quantization type in PT2E Flow (in QuantizationSpec)
Alternatively, we can introduce BFloat16 as a quantization type in PT2E Flow (within QuantizationSpec).
- We may need to extend the Observer implementation to annotate its use for int8-mixed-bf16, depending on the quantization recipe.
- During the convert phase, we will examine the observer information to determine if it has been annotated with int8-mixed-bf16.
  - If the input of a quantization node is of BFloat16 data type, an additional `to_float` node will be inserted before the quantization node.
  - Following the dequantization node, an additional `to_bf16` node will be inserted.
```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

# `model`, `example_inputs`, and `images` are user-provided placeholders.
exported_model = capture_pre_autograd_graph(
    model,
    example_inputs
)
# Create X86InductorQuantizer
quantizer = X86InductorQuantizer()
# `dtype=torch.bfloat16` is the proposed extension of the config API, not an existing argument.
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(dtype=torch.bfloat16))
# PT2E quantization flow
prepared_model = prepare_pt2e(exported_model, quantizer)
# Calibration (run representative inputs through prepared_model here)
converted_model = convert_pt2e(prepared_model)
torch.ao.quantization.move_exported_model_to_eval(converted_model)

with torch.no_grad():
    optimized_model = torch.compile(converted_model)
    # Int8-mixed-BF16 inference
    quant_output = optimized_model(images)
```
- Pros:
  - We can achieve more flexibility with a customized implementation for int8-mixed-bf16 quantization, allowing us to overcome certain limitations of Autocast, such as the bias conversion.
- Cons:
  - Non-trivial changes may be needed in the QuantizationSpec, Observer, Quantizer, and PT2E convert flow implementations.
We prefer option 1 as it requires fewer changes in the PT2E quantization flow and is clear and straightforward.
Additional context
Optimization Inside Inductor
- Conv/GEMM
  Here is the pattern we expect to see in Inductor after the quantization flow: `q -> dq -> float32_to_bfloat16 -> conv -> bfloat16_to_fp32 -> q -> dq` (a toy sketch of the two-step fusion is shown after this list).
  - Step 1: In the weight prepack phase, `dq -> float32_to_bfloat16 -> conv` will be matched first to generate a `qconv_bf16_output` node with `int8` input dtype and `bfloat16` output dtype.
  - Step 2: Furthermore, we will check whether a `bfloat16_to_fp32 -> q` pattern exists after this `qconv_bf16_output` node. If so, we will further merge `qconv_bf16_output -> bfloat16_to_fp32 -> q` into a `qconv` node with `int8` input dtype and `int8` output dtype.
- Non-Conv/GEMM
  Non-Conv/GEMM patterns will be lowered in the Inductor CPP backend for code generation.
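To make the two fusion steps concrete, here is a toy `torch.fx` sketch. The `q`/`dq`/cast helpers and the fused `qconv_bf16_output`/`qconv` targets are identity stand-ins marked as graph leaves with `fx.wrap`; this only illustrates the rewrite, not Inductor's actual pattern-matcher passes or the real oneDNN kernels.

```python
import torch
import torch.fx as fx

# Identity stand-ins for the real ops, kept as leaf calls in the traced graph.
@fx.wrap
def q(x): return x
@fx.wrap
def dq(x): return x
@fx.wrap
def to_bf16(x): return x
@fx.wrap
def to_fp32(x): return x
@fx.wrap
def conv(x): return x
@fx.wrap
def qconv_bf16_output(x): return x  # step-1 result: int8 input, bf16 output
@fx.wrap
def qconv(x): return x              # step-2 result: int8 input, int8 output

class M(torch.nn.Module):
    def forward(self, x):
        # q -> dq -> to_bf16 -> conv -> to_fp32 -> q -> dq
        return dq(q(to_fp32(conv(to_bf16(dq(q(x)))))))

def fuse(gm: fx.GraphModule) -> fx.GraphModule:
    g = gm.graph
    # Step 1: dq -> to_bf16 -> conv  ==>  qconv_bf16_output
    for node in list(g.nodes):
        if node.op == "call_function" and node.target is conv:
            cast = node.args[0]
            if cast.target is to_bf16 and cast.args[0].target is dq:
                int8_input = cast.args[0].args[0]
                with g.inserting_before(node):
                    fused = g.call_function(qconv_bf16_output, (int8_input,))
                node.replace_all_uses_with(fused)
    g.eliminate_dead_code()
    # Step 2: qconv_bf16_output -> to_fp32 -> q  ==>  qconv
    for node in list(g.nodes):
        if node.op == "call_function" and node.target is q:
            cast = node.args[0]
            if (isinstance(cast, fx.Node) and cast.op == "call_function"
                    and cast.target is to_fp32
                    and cast.args[0].target is qconv_bf16_output):
                with g.inserting_before(node):
                    fused = g.call_function(qconv, cast.args[0].args)
                node.replace_all_uses_with(fused)
    g.eliminate_dead_code()
    gm.recompile()
    return gm

print(fuse(fx.symbolic_trace(M())).graph)  # shows ... q -> qconv -> dq ...
```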
Enabling Plans
Here are some plans to follow up on Option 1:
- Make all `onednn.qconv1d_pointwise`/`linear_pointwise` operators support BF16 output.
- Remove the annotation of the output at conv/linear in `X86InductorQuantizer`.
- Extend the decomposed quant op to support BF16 input.
- Extend the weight prepack pattern matcher for `dequant -> to_bf16 -> conv/linear`.
- Extend the QConv/Linear int8-mixed-bf16 output pattern matcher for `dequant -> to_bf16 -> conv/linear -> to_fp32 -> quant`.
- Extend the post-op pattern matcher for Conv/Linear ReLU/Add/Add_ReLU fusions with FP32/BF16 output.
cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen