Source code for quantizeml.layers.quantizers.weight_quantizer

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2022 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

__all__ = ["WeightQuantizer"]

import tensorflow as tf

from ...tensors import QTensor, QFloat, FixedPoint
from ..recorders import QFloatRecorder, FixedPointRecorder
from .quantizers import Quantizer


@tf.keras.utils.register_keras_serializable()
class WeightQuantizer(Quantizer):
    """Uniform quantizer turning a float Tensor into a quantized representation.

    The quantization proceeds as follows:

    - the per-axis (or per-tensor) maximum absolute values of the inputs are
      evaluated to obtain the quantization ranges,
    - by default, scales aligning the values on optimal ranges for FixedPoint
      quantization are derived from these ranges and the rescaled Tensor is
      returned as a QFloat,
    - when ``fp_quantizer`` is set, no scale is used and the inputs are directly
      quantized as a FixedPoint.

    Args:
        bitwidth (int, optional): the quantization bitwidth, defaults to 4.
        signed (bool, optional): whether the quantizer expects signed values or unsigned.
            Defaults to True.
        axis (int, optional): the quantization range is a scalar (None) or a vector
            corresponding to the given axis. Defaults to -1.
        fp_quantizer (bool, optional): True to produce a FixedPoint, False to produce a
            QFloat. Defaults to False.
    """

    def __init__(self, bitwidth=4, signed=True, axis=-1, fp_quantizer=False, **kwargs):
        super().__init__(bitwidth, signed, **kwargs)
        self.axis = axis
        self.fp_quantizer = fp_quantizer
        # The recorder type must match the type of the tensors produced by call()
        recorder_class = FixedPointRecorder if self.fp_quantizer else QFloatRecorder
        self.qweights = recorder_class()

    def build(self, input_shape):
        """Build the layer.

        Resolves ``axis`` into the list of axes that are reduced when evaluating
        the quantization ranges, i.e. every axis of the input except ``axis``.

        Args:
            input_shape (list): the shape of input tensor.

        Raises:
            ValueError: if ``axis`` does not designate a dimension of the input.
        """
        super().build(input_shape)

        # Per-tensor quantization: explicitly requested, or input with at most one dimension
        if self.axis is None or len(input_shape) <= 1:
            self._axis = None
            return

        rank = len(input_shape)
        # Negative axes are counted from the last dimension
        positive_axis = self.axis if self.axis >= 0 else rank + self.axis
        if not 0 <= positive_axis < rank:
            raise ValueError(f"Axis {self.axis} is not valid in {self.name}.")

        # Keep all the dimensions but the quantization axis
        reduction_axes = list(range(rank))
        del reduction_axes[positive_axis]
        self._axis = reduction_axes

    def call(self, inputs):
        """Quantize the float inputs.

        The quantization is done in two steps:

            1. Compute the quantization ranges,
            2. Quantize the inputs.

        Args:
            inputs (tf.Tensor): the inputs tensor.

        Returns:
            :obj:`QFloat` or :obj:`FixedPoint`: the quantized tensor, a FixedPoint when
            ``fp_quantizer`` is set and a QFloat otherwise.

        Raises:
            ValueError: if the inputs are already a QTensor.
        """
        if isinstance(inputs, QTensor):
            message = (f"{type(inputs)} input is not supported. WeightQuantizer only accepts float"
                       " inputs.")
            raise ValueError(message)

        # Step 1: the ranges are the maximum absolute values along the reduced axes
        magnitudes = tf.math.abs(inputs)
        ranges = tf.math.reduce_max(magnitudes, self._axis)
        if self.axis in (-2, 2) and inputs.shape[-1] == 1 and len(inputs.shape) > 2:
            # The ranges lost the trailing unit dimension of the inputs: restore it so
            # that they are broadcastable on the inputs
            ranges = tf.expand_dims(ranges, -1)

        # Step 2: quantize, with or without scales
        if not self.fp_quantizer:
            # Scales aligning the inputs on the optimal quantization ranges
            optimal_scales = QFloat.optimal_scales(ranges, self.value_bits)
            # Lower-bound the scales to avoid quantizing very tiny values: 1e-6 is low
            # enough to be considered as zero
            scales = tf.maximum(optimal_scales, 1e-6)
            # With the optimal quantization ranges [-int_max - 1, int_max], the inner
            # FixedPoint needs exactly zero fractional bits
            qweights = QFloat.quantize(inputs, self.value_bits, scales, 0.0)
        else:
            # Scales are disabled: build a FixedPoint whose fractional bits are derived
            # from the ranges (no gradient flows through them)
            max_frac_bits = FixedPoint.max_frac_bits(self.value_bits, ranges)
            frac_bits = tf.stop_gradient(max_frac_bits)
            qweights = FixedPoint.quantize(inputs, self.value_bits, frac_bits)

        # Hand the quantized weights to the recorder (it does nothing if recording is
        # disabled)
        self.qweights(qweights)
        return qweights

    def get_config(self):
        """Get the config of the layer.

        Returns:
            dict: the config of the parent class extended with ``axis`` and
            ``fp_quantizer``.
        """
        config = super().get_config()
        config["axis"] = self.axis
        config["fp_quantizer"] = self.fp_quantizer
        return config