Source code for quantizeml.models.check_quantization

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2022 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

import tensorflow as tf

from ..layers import WeightQuantizer, OutputQuantizer, QFloatRecorder
from .record import record_quantization_variables


def _ideal_weight_scales(layer):
    """Returns the product of the float scales recorded by a layer's weight quantizers.

    Every attribute of the layer that is a WeightQuantizer whose ``qweights`` is a
    QFloatRecorder contributes its recorded scales. When several are found (eg.
    SeparableConvolution) their scales are multiplied together.

    Args:
        layer (keras.layers.Layer): the layer to inspect

    Returns:
        the combined scales, or None if the layer has no such WeightQuantizer
    """
    ideal_scales = None
    for candidate in layer.__dict__.values():
        if not isinstance(candidate, WeightQuantizer):
            continue
        if not isinstance(candidate.qweights, QFloatRecorder):
            continue
        scales = candidate.qweights.value.scales
        # Build a new product instead of using `*=`: if the scales are stored in a mutable
        # array type, an in-place multiply would alter the values recorded by the first
        # quantizer.
        ideal_scales = scales if ideal_scales is None else ideal_scales * scales
    return ideal_scales


def check_quantization(model):
    """Checks the specified model quantization.

    It looks for errors that can be fixed in the quantization configuration:

    - inaccurate weight scales quantization,
    - saturation in integer operations.

    It also reports output quantizers whose first weight is all ones as not calibrated.

    Args:
        model (keras.Model): the model to check

    Returns:
        list(str): the quantization issues that were detected
    """
    # First record the model variables to store qscales in output quantizers
    record_quantization_variables(model)

    messages = []
    for layer in model.layers:
        for quantizer in layer.__dict__.values():
            if not isinstance(quantizer, OutputQuantizer):
                continue

            range_max = quantizer.get_weights()[0]
            if tf.reduce_all(range_max == 1):
                messages.append(f"{layer.name}/{quantizer.name} is not calibrated.")
                # The remaining checks are skipped for an uncalibrated quantizer
                continue

            # Evaluate the maximum left shift for this quantizer
            maximum_shift = quantizer.value_bits - layer.buffer_bitwidth
            # There might be a saturation in integer operations if we reached the maximum shift
            output_shift = quantizer.shift.value
            if tf.reduce_any(output_shift <= maximum_shift):
                messages.append(f"Possible saturation detected in {layer.name}: "
                                "try to reduce weights bits and/or scale bits.")

            # Check inaccurate weight scales quantization.
            qscales = getattr(quantizer, 'qscales', None)
            if not qscales:
                continue

            # Retrieve the original scales. They are located in the WeightQuantizer of the
            # layer if there is one, or in all WeightQuantizer if they are several (eg.
            # SeparableConvolution)
            ideal_scales = _ideal_weight_scales(layer)
            if ideal_scales is None:
                continue

            # Relative error between the float scales and their quantized counterpart
            err = tf.abs(ideal_scales - qscales.value.to_float()) / tf.abs(ideal_scales)
            mean_err = tf.reduce_mean(err)
            if mean_err > 5e-2:
                # The trailing space keeps the advice below from being glued to the
                # error value.
                message = ("Scales quantization relative error is high in "
                           f"{layer.name}/{quantizer.name}: {mean_err:.4f}. ")
                if quantizer._axis == "per-tensor":
                    message += "Use a per-axis quantizer and/or increase scales bits."
                else:
                    message += "Try increasing scales bits."
                messages.append(message)
    return messages