Source code for cnn2snn.quantizeml.compatibility_checks

# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""Functions to check model compatibility for CNN2SNN conversion.
"""
from keras import layers
from quantizeml import layers as qlayers

from .blocks import split_model_into_blocks
from ..transforms.sequential import _check_layers_data_format, _check_layer_inbounds
from ..akida_versions import get_akida_version, AkidaVersion
from .block_converter import (_V1_PATTERN_CONVERTERS, _V2_PATTERN_CONVERTERS,
                              _V1_INPUT_PATTERN_CONVERTERS, _V2_INPUT_PATTERN_CONVERTERS)

# Quantized convolution/dense layer types (including depthwise, separable and transposed
# variants).
neural_layers = (
    qlayers.QuantizedConv2D,
    qlayers.QuantizedSeparableConv2D,
    qlayers.QuantizedDense,
    qlayers.QuantizedDepthwiseConv2D,
    qlayers.QuantizedConv2DTranspose,
    qlayers.QuantizedDepthwiseConv2DTranspose,
)
# Keras layer types treated as skippable by _extract_skippable_layers (they are not part of
# the block pattern used to select a BlockConverter).
skippable_layers = (
    layers.InputLayer,
    layers.Rescaling,
    layers.Activation,
    layers.Softmax,
    layers.Dropout,
)
# Quantized pooling layer types.
pooling_layers = (
    qlayers.QuantizedGlobalAveragePooling2D,
    qlayers.QuantizedMaxPool2D,
)
# Normalization layer types.
norm_layers = (
    qlayers.LayerMadNormalization,
    qlayers.QuantizedBatchNormalization,
)
# Quantized activation layer types.
activation_layers = (
    qlayers.QuantizedReLU,
)
# Transformer-stem layer types (class token and position embeddings).
stem_layers = (
    qlayers.ClassToken,
    qlayers.AddPositionEmbs,
)
# Quantized reshaping layer types.
reshape_layers = (
    qlayers.QuantizedReshape,
    qlayers.QuantizedFlatten,
)


def _block_pattern(block):
    """Method that returns the pattern of a block of layers.

    Args:
        block (list): list of quantized quantizeml layers.

    Returns:
        tuple: list of layer types representing the block pattern.
    """
    return tuple([layer.__class__ for layer in block])


def _get_block_converter(block):
    """Look up the BlockConverter matching a block of layers.

    The block pattern (tuple of layer classes) is looked up in the Akida 1.0 or Akida 2.0
    pattern-converter table, depending on the current Akida version.

    Args:
        block (list): list of quantized quantizeml layers.

    Returns:
        (:obj:`BlockConverter`): the matching BlockConverter, or None when the pattern is not
        registered for the current Akida version.
    """
    layers_pattern = _block_pattern(block)
    is_akida_v1 = get_akida_version() == AkidaVersion.v1
    converters = _V1_PATTERN_CONVERTERS if is_akida_v1 else _V2_PATTERN_CONVERTERS
    return converters.get(layers_pattern)


def _get_input_block_converter(block):
    """Look up the BlockConverter matching an input block of layers.

    Same lookup as for regular blocks, but it uses the input-pattern converter tables
    (Akida 1.0 or Akida 2.0, depending on the current Akida version).

    Args:
        block (list): list of quantized quantizeml layers.

    Returns:
        (:obj:`BlockConverter`): the matching input BlockConverter, or None when the pattern
        is not registered for the current Akida version.
    """
    layers_pattern = _block_pattern(block)
    is_akida_v1 = get_akida_version() == AkidaVersion.v1
    converters = _V1_INPUT_PATTERN_CONVERTERS if is_akida_v1 else _V2_INPUT_PATTERN_CONVERTERS
    return converters.get(layers_pattern)


def check_model_compatibility(model):
    r"""Checks if a QuantizeML model is compatible for Akida conversion.

    This function does NOT:

    - convert the QuantizeML model to an Akida model,
    - check if the model is compatible with Akida hardware

    It ONLY checks if the model design is compatible with Akida.

    Args:
        model (:obj:`tf.keras.Model`): the model to check.

    Returns:
        list: one BlockConverter instance per model block that contains at least one
        non-skippable layer, in block order. Blocks made only of skippable layers produce
        no entry.

    Raises:
        RuntimeError: if the model inputs/outputs are not supported, or if a block pattern has
            no matching BlockConverter. The data-format and inbound checks may raise their own
            errors.
    """
    # Check general rules about the model in three steps:
    # 1. check that the model has only one input and one output,
    # 2. check the layers data format and
    # 3. on Akida 1.0, check that the model is sequential.
    _check_model_input_output(model)
    _check_layers_data_format(model)
    if get_akida_version() == AkidaVersion.v1:
        _check_layer_inbounds(model)

    # Split the model into its blocks
    blocks = split_model_into_blocks(model)

    # One BlockConverter instance per convertible block (an unconvertible block raises).
    straight_blocks = []

    # Evaluate block-by-block integrity
    for idx, block in enumerate(blocks):
        # Only the non-skippable layers define the block pattern
        _, non_skippable = _extract_skippable_layers(block)

        # Skip the block if it contains only skippable layers
        if not non_skippable:
            continue

        block_converter = None
        # The first block is a special case: when the model input has 1 or 3 channels it may
        # match an input pattern (targeting the HRC).
        if idx == 0 and model.input_shape[-1] in (1, 3):
            block_converter = _get_input_block_converter(non_skippable)

        # If no input pattern matched, follow the classical way
        if block_converter is None:
            block_converter = _get_block_converter(non_skippable)

        # One shouldn't get in here. If so the block pattern is unconvertible.
        if block_converter is None:
            raise RuntimeError(f"Invalid block pattern to conversion. Receives {non_skippable}")

        straight_blocks.append(block_converter(non_skippable))

    return straight_blocks
def _check_model_input_output(model):
    """Asserts that model inputs/outputs are supported for conversion.

    The Keras model must have exactly one input layer and one output layer. On Akida 1.0, the
    input shape must also be 2-D or 4-D (batch size + 1-D or 3-D tensors). Finally, the first
    layer of ``model.layers`` must have at most one outbound node.

    Args:
        model (tf.keras.Model): the Keras model to check.

    Raises:
        RuntimeError: if any of the above conditions is not met.
    """
    # Error if the model does not have exactly one input
    if len(model.input_names) != 1:
        raise RuntimeError("Model must have only one input layer. Receives "
                           f"inputs {model.input_names}.")

    # Error if the model does not have exactly one output
    if len(model.output_names) != 1:
        raise RuntimeError("Model must have only one output layer. Receives "
                           f"outputs {model.output_names}.")

    # On Akida 1.0 only: error if input shape is not 2D or 4D
    if len(model.input_shape) not in (2, 4) and get_akida_version() == AkidaVersion.v1:
        raise RuntimeError(
            "Input shape of model must be 2-D or 4-D (batch size + 1-D or 3-D "
            f"tensors). Receives input shape {model.input_shape}.")

    # In akida HW one can realise skip connection directly with the main layer, so the first
    # layer may only feed a single node.
    # NOTE(review): assumes model.layers[0] is the input layer; confirm for Sequential models.
    if len(model.layers[0].outbound_nodes) > 1:
        raise RuntimeError("The input model layer can only have one outbound node.")


def _extract_skippable_layers(block):
    """Split a block into skippable and non-skippable layers.

    A layer is skippable when it is an instance of one of ``skippable_layers``. It is also
    skippable when it is a ``QuantizedReshape`` whose input and output shapes are equal once
    all dimensions of size 1 are dropped, i.e. it only adds or removes singleton dimensions.

    Args:
        block (list): list of layers to split.

    Returns:
        tuple: two lists ``(skippable, non_skippable)`` of layers, each in block order.
    """
    skippable, non_skippable = [], []
    for layer in block:
        if isinstance(layer, skippable_layers):
            skippable.append(layer)
        elif isinstance(layer, qlayers.QuantizedReshape):
            # Compare shapes ignoring singleton dims: a reshape that only inserts/removes
            # size-1 dimensions is treated as skippable.
            in_dims = [x for x in layer.input_shape if x != 1]
            out_dims = [x for x in layer.output_shape if x != 1]
            if in_dims != out_dims:
                non_skippable.append(layer)
            else:
                skippable.append(layer)
        else:
            non_skippable.append(layer)
    return skippable, non_skippable