Source code for quantizeml.layers.attention

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2022 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

__all__ = ["Attention", "string_to_softmax", "QuantizedAttention"]

import keras
import tensorflow as tf

from .layers_base import (register_quantize_target, register_aligned_inputs, apply_buffer_bitwidth,
                          init_quant_config)
from .reshaping import QuantizedReshape, QuantizedPermute
from .shiftmax import shiftmax, QuantizedShiftmax
from .recorders import TensorRecorder
from .quantizers import OutputQuantizer
from ..tensors import FixedPoint, round_log2, pow2


def string_to_softmax(s):
    """Convert a string to a softmax function.

    Available options are 'softmax' for standard softmax, 'shiftmax' for shiftmax.

    Args:
        s (str): string to convert.

    Returns:
        A softmax function.
    """
    if s == "softmax":
        return tf.nn.softmax
    if s in ("softmax2", "shiftmax"):
        return shiftmax
    raise NotImplementedError("softmax should be in ['softmax', 'shiftmax']"
                              f" but received {s}.")
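
# Usage sketch (illustrative comment, not part of the original module): the helper
# above only maps a configuration string to a callable, so the returned function is
# applied directly to a score tensor. The tensor values below are arbitrary.
#
#   softmax_fn = string_to_softmax("softmax")    # -> tf.nn.softmax
#   shiftmax_fn = string_to_softmax("shiftmax")  # -> shiftmax
#   scores = tf.constant([[1.0, 2.0, 3.0]])
#   probs = softmax_fn(scores, axis=-1)          # probabilities along the last axis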


@tf.keras.utils.register_keras_serializable()
class Attention(keras.layers.Layer):
    """Dot-product attention layer with configurable softmax.

    Inputs are a tuple of tensors:

    - a query tensor of shape [batch, tokens, hidden],
    - a key tensor of shape [batch, tokens, hidden],
    - a value tensor of shape [batch, tokens, hidden].

    The calculation follows the steps:

    1. Split query, key, value per attention heads

        q, k, v : [batch, tokens, hidden] -> [batch, token, num_heads, dim]

    2. Calculate cross-token scores as a query-key dot product:

        scores = tf.matmul(query, key, transpose_b=True)

        scores : [batch, num_heads, token, token]

    3. Rescale the scores by dividing them by the square root of dim.

    4. Use scores to calculate a mask

        mask = softmax(scores)

    5. Combine mask with value

        output = tf.matmul(mask, value)

        output: [batch, num_heads, token, dim]

    6. Merge heads to get back to 2D

        output: [batch, num_heads, token, dim] -> [batch, token, hidden]

    Args:
        num_heads (int): the number of attention heads
        softmax (str, optional): 'softmax' or 'shiftmax'. Defaults to 'softmax'.
    """

    def __init__(self, num_heads, softmax="softmax", **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.softmax = softmax
        self.softmax_op = string_to_softmax(self.softmax)

    def build(self, input_shape):
        super().build(input_shape)
        assert len(input_shape) == 3
        self.hidden_size = input_shape[0][-1]
        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"Embedding dimension = {self.hidden_size} should be divisible"
                f" by number of heads = {self.num_heads}"
            )
        self.dim = self.hidden_size // self.num_heads
        # Attention replaces the division of the score by the scale with a shift.
        # For that, we need to calculate scale_shift.
        scale = tf.math.sqrt(tf.cast(self.dim, dtype=tf.float32))
        self.scale_shift = round_log2(scale)

    def separate_heads(self, x):
        x = keras.layers.Reshape((-1, self.num_heads, self.dim))(x)
        return keras.layers.Permute((2, 1, 3))(x)

    def call(self, inputs):
        # Separate 2D embeddings per head to obtain 3D inputs
        query = self.separate_heads(inputs[0])
        key = self.separate_heads(inputs[1])
        value = self.separate_heads(inputs[2])
        # Dot product query and key for each head and pairs of tokens
        score = tf.matmul(query, key, transpose_b=True)
        # Rescale the corresponding score, dividing it by 2**scale_shift
        scaled_score = score * pow2(-self.scale_shift)
        # Apply the configurable softmax operation
        mask = self.softmax_op(scaled_score, axis=-1)
        # Combine each score with value to obtain new embeddings per token
        output = tf.matmul(mask, value)
        # Join heads to get back to 2D embeddings per token
        output = keras.layers.Permute((2, 1, 3))(output)
        output = keras.layers.Reshape((-1, self.hidden_size))(output)
        return output, mask

    def get_config(self):
        config = super().get_config()
        softmax = self.softmax
        if self.softmax == 'softmax2':
            # softmax2 is the legacy name, use shiftmax now
            softmax = 'shiftmax'
        config.update(
            {
                "num_heads": self.num_heads,
                "softmax": softmax
            }
        )
        return config
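
# Usage sketch (illustrative comment, not part of the original module): the layer
# expects a (query, key, value) tuple of [batch, tokens, hidden] tensors and returns
# the new embeddings together with the attention mask. Shapes below are arbitrary.
#
#   q = tf.random.normal((2, 16, 64))            # [batch=2, tokens=16, hidden=64]
#   k = tf.random.normal((2, 16, 64))
#   v = tf.random.normal((2, 16, 64))
#   attention = Attention(num_heads=8, softmax="softmax")
#   outputs, mask = attention((q, k, v))
#   # outputs: [2, 16, 64], mask: [2, 8, 16, 16]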


@register_quantize_target(Attention)
@register_aligned_inputs
@tf.keras.utils.register_keras_serializable()
class QuantizedAttention(Attention):
    """An Attention layer that operates on quantized inputs.

    Args:
        num_heads (int): the number of attention heads
        quant_config (dict, optional): the serialized quantization configuration.
            Defaults to None.
        softmax (str, optional): 'softmax' or 'shiftmax'. Defaults to 'shiftmax'.
    """

    def __init__(self, num_heads, quant_config=None, softmax='shiftmax', **kwargs):
        if softmax != 'shiftmax':
            raise ValueError("Only shiftmax is supported for quantized attention")
        super().__init__(num_heads=num_heads, softmax=softmax, **kwargs)
        self.quant_config = init_quant_config(quant_config)
        out_quant_cfg = self.quant_config.get("output_quantizer", False)
        if out_quant_cfg:
            self.out_quantizer = OutputQuantizer(name="output_quantizer", **out_quant_cfg)
        else:
            self.out_quantizer = None
        # Override softmax operation
        softmax_quant_conf = self.quant_config.get("softmax", None)
        self.softmax_op = QuantizedShiftmax(quant_config=softmax_quant_conf,
                                            name="quantized_softmax")
        self.buffer_bitwidth = apply_buffer_bitwidth(self.quant_config, signed=True)
        # Add objects that will store the shift values.
        self.values_shift = TensorRecorder(self.name + "/value_shift")

    def separate_heads(self, x):
        x = QuantizedReshape((-1, self.num_heads, self.dim))(x)
        return QuantizedPermute((2, 1, 3))(x)

    def call(self, inputs):
        if any(not isinstance(x, FixedPoint) for x in inputs):
            # If any of the inputs is not a FixedPoint, raise an error
            raise ValueError("QuantizedAttention only accepts FixedPoint inputs")
        # Separate 2D embeddings per head to obtain 3D inputs
        query = self.separate_heads(inputs[0])
        key = self.separate_heads(inputs[1])
        # Expand the values to a higher bitwidth to avoid saturation and align them
        value, vshift = inputs[2].expand(self.buffer_bitwidth)
        self.values_shift(vshift)
        value = self.separate_heads(value)
        # Promote query to avoid saturation
        query = query.promote(self.buffer_bitwidth)
        # Dot product query and key for each head and pairs of tokens
        score = tf.matmul(query, key, transpose_b=True)
        # Rescale the corresponding score, dividing it by 2**scale_shift
        scaled_score = score >> self.scale_shift
        # Apply the configurable softmax operation
        mask = self.softmax_op(scaled_score)
        # Promote mask to make sure we don't overflow
        mask = mask.promote(self.buffer_bitwidth)
        # Combine each score with value to obtain new embeddings per token
        output = tf.matmul(mask, value)
        # Join heads to get back to 2D embeddings per token
        output = QuantizedPermute((2, 1, 3))(output)
        output = QuantizedReshape((-1, self.hidden_size))(output)
        # Refine output bitwidth precision if needed
        if self.out_quantizer is not None:
            output = self.out_quantizer(output)
        return output, mask

    def get_config(self):
        config = super().get_config()
        config["quant_config"] = self.quant_config
        return config
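
# Usage sketch (illustrative comment, not part of the original module): a
# QuantizedAttention layer is normally obtained by quantizing a float model that
# contains an Attention layer rather than being built by hand. Its inputs must be
# FixedPoint tensors (typically produced by upstream quantized layers); float
# tensors raise a ValueError in call(). Passing quant_config=None selects the
# default configuration returned by init_quant_config.
#
#   layer = QuantizedAttention(num_heads=8, quant_config=None)
#   # outputs, mask = layer((q_fp, k_fp, v_fp))  # q_fp, k_fp, v_fp: FixedPoint inputs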