Source code for quantizeml.layers.activations

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2022 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

__all__ = ["hard_silu", "QuantizedReLU", "QuantizedActivation"]

import os
import numpy as np
import tensorflow as tf
import keras

from .layers_base import (register_quantize_target, rescale_outputs, register_aligned_inputs,
                          tensor_inputs, apply_buffer_bitwidth, QuantizedLayer)
from .quantizers import AlignedWeightQuantizer, OutputQuantizer
from .recorders import NonTrackVariable
from ..tensors import FixedPoint, QFloat, QTensor

# Environment variable that, when set to "1", routes QuantizedActivation
# inference through its look-up-table (see the using_lut property below)
LUT_ENV = "LUT_ENABLED"


@keras.saving.register_keras_serializable()
def hard_silu(x):
    """Hard SiLU activation function, also known as Hard Swish.

    It is defined as:

    - `0` if `x < -3`
    - `x` if `x > 3`
    - `x * (x + 3) / 6` if `-3 <= x <= 3`

    It's a faster, piecewise approximation of the SiLU activation.

    Args:
        x (tf.Tensor): Input tensor.

    Returns:
        tf.Tensor: Output tensor.
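
    Example:

        >>> # A quick numerical check of the piecewise definition above:
        >>> # 0 below -3, identity above 3, x * (x + 3) / 6 in between
        >>> y = hard_silu(tf.constant([-4.0, -1.0, 0.0, 1.0, 4.0]))
        >>> assert np.allclose(y, [0.0, -1.0 / 3, 0.0, 2.0 / 3, 4.0])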

    Reference:
        [A Howard, 2019](https://arxiv.org/abs/1905.02244)
    """
    return x * tf.nn.relu6(x + 3) / 6


@register_quantize_target(keras.layers.ReLU)
@keras.saving.register_keras_serializable()
class QuantizedReLU(QuantizedLayer):
    """Quantized version of the ReLU activation layer, applicable on FixedPoint tensors.

    Args:
        max_value (float, optional): ReLU maximum value. Defaults to 6.
        quant_config (dict, optional): the serialized quantization configuration. Defaults to
            None.
    """
    arg_constraints = {'negative_slope': 0, 'threshold': 0}
    ignored_args = ['negative_slope', 'threshold']

    def __init__(self, *args, max_value=6, quant_config=None, **kwargs):
        super().__init__(*args, quant_config=quant_config, **kwargs)
        # Use quant_config to build quantizers
        out_quant_cfg = self.quant_config.get("output_quantizer", False)
        if out_quant_cfg:
            self.out_quantizer = OutputQuantizer(name="output_quantizer", **out_quant_cfg)
        else:
            self.out_quantizer = None
        self.buffer_bitwidth = apply_buffer_bitwidth(self.quant_config, signed=False)

        if max_value is not None:
            # Store max_value
            if isinstance(max_value, np.ndarray):
                max_value = max_value.item()
            max_value_quantizer_cfg = self.quant_config.get("max_value_quantizer", {})
            self.max_value_quantizer = AlignedWeightQuantizer(name="max_value_quantizer",
                                                              signed=False,
                                                              **max_value_quantizer_cfg)
        self.max_value = max_value

    @tensor_inputs([QTensor])
    @rescale_outputs
    def call(self, inputs):
        """ReLU activation function.

        In other words:

            1. clip the value between 0 and :attr:`max_value`,
            2. quantize the output if an output_quantizer is set.

        Args:
            inputs (:obj:`QFloat`): the inputs tensor.

        Returns:
            :obj:`FixedPoint`: QuantizedReLU outputs.
        """
        if isinstance(inputs, FixedPoint):
            # If inputs is a FixedPoint, create an equivalent QFloat with scale set to 1
            inputs = QFloat(inputs, tf.constant(1.))
        # Express zero as a QFloat aligned with the inputs because this is what the
        # dispatched operations expect.
        # The actual hardware implementation will simply use a zero integer.
        zero = QFloat(FixedPoint(tf.constant(0.), inputs.fp.value_bits, inputs.fp.frac_bits),
                      inputs.scales)
        if self.max_value is None:
            # Just remove negative values
            return tf.math.maximum(inputs, zero)
        # Quantize and align max_value with the inputs
        max_value = self.max_value_quantizer(tf.cast(self.max_value, tf.float32), inputs)
        # Clip the inputs
        return tf.clip_by_value(inputs, zero, max_value)

    def get_config(self):
        config = super().get_config()
        config.update({"max_value": self.max_value})
        return config
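

# Illustrative usage sketch (an addition to this page, not part of the original
# module): how QuantizedReLU is expected to clip a FixedPoint input, assuming the
# FixedPoint semantics from the docstrings above, where the stored integer equals
# the float value times 2 ** frac_bits. The helper name is hypothetical.
def _quantized_relu_example():
    # 7.5 and -2.0 encoded with 1 fractional bit, i.e. integers 15 and -4
    x = FixedPoint(tf.constant([15., -4.]), value_bits=8, frac_bits=1)
    y = QuantizedReLU(max_value=6)(x)
    # Negative values are zeroed and values above max_value should clip to 6
    assert np.allclose(y.to_float(), [6., 0.])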


@keras.saving.register_keras_serializable()
@register_quantize_target(keras.layers.Activation)
@register_aligned_inputs
class QuantizedActivation(QuantizedLayer, keras.layers.Activation):
    """Quantized version of the `keras.layers.Activation` layer, applicable on ``FixedPoint``
    tensors.

    The input values are mapped through a look-up-table that simulates the activation
    behavior.

    Example:

        >>> # Represent 2.5 as a FixedPoint
        >>> input = FixedPoint(5, value_bits=3, frac_bits=1)
        >>> # QuantizedActivation.call() maps `input` through the table to obtain
        >>> # an integer that represents the float value of tf.nn.gelu(2.5)
        >>> output = QuantizedActivation(activation="gelu")(input)
        >>> assert output.values == 80
        >>> # Or, equivalently, in the float domain
        >>> max_error = 2 ** -(output.frac_bits + 1)
        >>> assert tf.abs(output.to_float() - tf.nn.gelu(2.5)) < max_error

    Args:
        activation (callable or str): Activation function. It can be a callable, or the name
            of an activation from the keras.activations namespace.
        quant_config (dict, optional): the serialized quantization configuration. Defaults to
            None.
    """
    arg_constraints = {'activation': lambda: ["gelu", "swish", "Custom>hard_silu"]}

    DEFAULT_INPUT_BITWIDTH = 11
    DEFAULT_LUT_BITWIDTH = 14

    def __init__(self, activation, *args, quant_config=None, **kwargs):
        super().__init__(activation, *args, quant_config=quant_config, **kwargs)
        # Retrieve quantization parameters
        if "lut_bitwidth" not in self.quant_config:
            self.quant_config["lut_bitwidth"] = self.DEFAULT_LUT_BITWIDTH
        self.lut_bits = self.quant_config["lut_bitwidth"] - 1

        # Use quant_config to build quantizers
        out_quant_cfg = self.quant_config.get("output_quantizer", False)
        if out_quant_cfg:
            self.out_quantizer = OutputQuantizer(name="output_quantizer", **out_quant_cfg)
        else:
            self.out_quantizer = None

        # Create the dynamic table and a variable to save the output frac_bits
        self.values = tf.lookup.experimental.MutableHashTable(key_dtype=tf.int32,
                                                              value_dtype=tf.float32,
                                                              default_value=2**self.lut_bits + 1,
                                                              name="values_table")
        self.frac_bits = NonTrackVariable("frac_bits")

    @property
    def using_lut(self):
        """Flag specifying whether inference should go through the look-up-table.

        Returns:
            bool: True if the look-up-table is enabled, False otherwise.
        """
        value = os.environ.get(LUT_ENV, "0")
        return value == "1"

    @tf.function
    def record_values_in_table(self, value_bits, frac_bits):
        """Generate a set of inputs and outputs to record the look-up-table.

        Inputs are generated over the full range based on ``value_bits``.

        Args:
            value_bits (int): bits defining the range of values to be generated.
            frac_bits (tf.Tensor): frac_bits to convert the generated values into a FixedPoint.

        Returns:
            tf.Tensor: the expected frac_bits representing the values contained in the table.
        """
        # Generate the full range of values in [-(2 ** value_bits), 2 ** value_bits - 1]
        int_max = 2 ** value_bits
        values = tf.range(-int_max, int_max, dtype=tf.int32)
        inputs = FixedPoint(values, value_bits=value_bits, frac_bits=frac_bits)

        # Forward the float inputs through the activation
        x = inputs.to_float()
        y = self.activation(x)

        # Apply dynamic quantization to compute the output integer values
        range_max = tf.reduce_max(tf.abs(y))
        out_frac_bits = tf.stop_gradient(FixedPoint.max_frac_bits(self.lut_bits, range_max))
        outputs = FixedPoint.quantize(y, self.lut_bits, frac_bits=out_frac_bits)

        # Insert the values in the table
        self.values.insert(values, outputs.values)

        # Return the static output frac_bits
        return out_frac_bits

    @tensor_inputs([FixedPoint])
    @rescale_outputs
    def call(self, inputs):
        # Values stored in the table can only be calculated if the input is per-tensor
        if not inputs.per_tensor:
            raise TypeError(f"{self.__class__.__name__} only supports per-tensor inputs.")
        self.frac_bits.init_var(inputs.frac_bits)

        # Record the values in the table if it is empty (low computational cost)
        if self.values.size() == 0:
            out_frac_bits = self.record_values_in_table(inputs.value_bits, inputs.frac_bits)
            out_frac_bits = tf.stop_gradient(out_frac_bits)
            self.frac_bits.set_var(out_frac_bits)

        # The look-up-table has a high cost at inference time. Inference can be made
        # faster by avoiding it through a DeQuantization-reQuantization (DQ-Q) approach:
        if not self.using_lut:
            # 1. Dequantize the inputs
            x = inputs.to_float()
            # 2. Apply the activation in the float domain
            x = self.activation(x)
            # 3. Requantize following the static quantization approach
            outputs = FixedPoint.quantize(x, value_bits=self.lut_bits,
                                          frac_bits=self.frac_bits.var)
        else:
            # Forward inputs.values through the values table
            inputs = tf.cast(inputs.values, tf.int32)
            values = self.values.lookup(inputs)
            # MutableHashTable forgets the output shape, so set it explicitly
            values.set_shape(inputs.shape)
            # Build the output FixedPoint
            outputs = FixedPoint(values, value_bits=self.lut_bits, frac_bits=self.frac_bits.var)
        return outputs
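

# Illustrative usage sketch (an addition to this page, not part of the original
# module): the DQ-Q and look-up-table paths of QuantizedActivation are expected
# to agree, since the table is recorded from the same float activation at the
# same frac_bits; LUT_ENABLED selects the path. The helper name is hypothetical.
def _quantized_activation_example():
    # Represent 2.5 as a FixedPoint, as in the class docstring
    x = FixedPoint(5, value_bits=3, frac_bits=1)
    act = QuantizedActivation(activation="gelu")
    # Default path: dequantize, apply gelu in the float domain, requantize (DQ-Q)
    os.environ[LUT_ENV] = "0"
    y_dqq = act(x)
    # Same layer and recorded table, but served through the look-up-table
    os.environ[LUT_ENV] = "1"
    y_lut = act(x)
    assert np.array_equal(y_dqq.values, y_lut.values)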