# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""Functions to check model compatibility for CNN2SNN conversion.
"""
from keras import layers
from quantizeml import layers as qlayers
from .blocks import split_model_into_blocks
from ..transforms.sequential import _check_layers_data_format, _check_layer_inbounds
from ..akida_versions import get_akida_version, AkidaVersion
from .block_converter import (_V1_PATTERN_CONVERTERS, _V2_PATTERN_CONVERTERS,
_V1_INPUT_PATTERN_CONVERTERS, _V2_INPUT_PATTERN_CONVERTERS)
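
# Layer type groups used by the compatibility checks in this module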
neural_layers = (qlayers.QuantizedConv2D, qlayers.QuantizedSeparableConv2D,
qlayers.QuantizedDense, qlayers.QuantizedDepthwiseConv2D,
qlayers.QuantizedConv2DTranspose, qlayers.QuantizedDepthwiseConv2DTranspose)
skippable_layers = (layers.InputLayer, layers.Rescaling, layers.Activation, layers.Softmax,
layers.Dropout)
pooling_layers = (qlayers.QuantizedGlobalAveragePooling2D, qlayers.QuantizedMaxPool2D)
norm_layers = (qlayers.LayerMadNormalization, qlayers.QuantizedBatchNormalization)
activation_layers = (qlayers.QuantizedReLU,)
stem_layers = (qlayers.ClassToken, qlayers.AddPositionEmbs)
reshape_layers = (qlayers.QuantizedReshape, qlayers.QuantizedFlatten)


def _block_pattern(block):
    """Returns the pattern of a block of layers.

    Args:
        block (list): list of quantized quantizeml layers.

    Returns:
        tuple: the layer types composing the block, in order.
    """
    return tuple(layer.__class__ for layer in block)
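

# For illustration (a sketch, not part of the conversion flow): a block made of
# [QuantizedConv2D, QuantizedBatchNormalization, QuantizedReLU] instances yields
# the pattern (QuantizedConv2D, QuantizedBatchNormalization, QuantizedReLU),
# which serves as the lookup key in the pattern-converter tables below.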


def _get_block_converter(block):
    """Helper to get the BlockConverter of a block of layers.

    Args:
        block (list): list of quantized quantizeml layers.

    Returns:
        (:obj:`BlockConverter`): the BlockConverter corresponding to the block of layers, or None.
    """
    pattern = _block_pattern(block)
    if get_akida_version() == AkidaVersion.v1:
        block_converter = _V1_PATTERN_CONVERTERS.get(pattern, None)
    else:
        block_converter = _V2_PATTERN_CONVERTERS.get(pattern, None)
    return block_converter


def _get_input_block_converter(block):
    """Helper to get the BlockConverter of an input block of layers.

    Args:
        block (list): list of quantized quantizeml layers.

    Returns:
        (:obj:`BlockConverter`): the BlockConverter corresponding to the block of layers, or None.
    """
    pattern = _block_pattern(block)
    if get_akida_version() == AkidaVersion.v1:
        return _V1_INPUT_PATTERN_CONVERTERS.get(pattern, None)
    return _V2_INPUT_PATTERN_CONVERTERS.get(pattern, None)
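

# Note: the input pattern tables are only consulted for the first block of a
# model, and only when the model input has 1 or 3 channels (see
# check_model_compatibility below).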


def check_model_compatibility(model):
    r"""Checks if a QuantizeML model is compatible for Akida conversion.

    This function does NOT:

    - convert the QuantizeML model to an Akida model,
    - check if the model is compatible with Akida hardware.

    It ONLY checks whether the model design is compatible with Akida.

    Args:
        model (:obj:`tf.keras.Model`): the model to check.

    Returns:
        list: the BlockConverter instances built from each block of
            non-skippable layers.
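
    Example:
        A minimal usage sketch, assuming ``quantized_model`` is a Keras model
        already quantized with QuantizeML:

        >>> blocks = check_model_compatibility(quantized_model)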
"""
    # Check general rules about the model in three steps:
    # 1. check that the model has only one input and one output,
    # 2. check that the data format is correct,
    # 3. on Akida 1.0, check that the model is sequential.
    _check_model_input_output(model)
    _check_layers_data_format(model)
    if get_akida_version() == AkidaVersion.v1:
        _check_layer_inbounds(model)

    # Split the model into its blocks
    blocks = split_model_into_blocks(model)

    # This list will contain one BlockConverter instance per non-skippable block
    straight_blocks = []

    # Evaluate block-by-block integrity
    for block_id, block in enumerate(blocks):
        # Initialize block_converter to None
        block_converter = None
        # Split the block into skippable and non-skippable layers
        _, non_skippable = _extract_skippable_layers(block)
        # Skip the block if it only contains skippable layers
        if len(non_skippable) == 0:
            continue
        # Get the BlockConverter matching the layer block, if available.
        # The first block is a special case: it may target the HRC if it
        # fulfills certain conditions.
        if block_id == 0 and model.input_shape[-1] in (1, 3):
            block_converter = _get_input_block_converter(non_skippable)
        # If the first block doesn't match any input pattern, fall back to the
        # standard lookup
        if not block_converter:
            block_converter = _get_block_converter(non_skippable)
        # At this point a converter must have been found; otherwise the block
        # pattern cannot be converted
        if not block_converter:
            raise RuntimeError(f"Invalid block pattern for conversion. Received {non_skippable}.")
        straight_blocks.append(block_converter(non_skippable))
    return straight_blocks


def _check_model_input_output(model):
    """Asserts that the model inputs/outputs are supported for conversion.

    The Keras model must have exactly one input layer and one output layer.
    On Akida 1.0, the input shape must be 2-D or 4-D, i.e. (N, C) or (N, H, W, C).

    Args:
        model (tf.keras.Model): the Keras model to check.
    """
    # Error if multiple inputs
    if len(model.input_names) > 1:
        raise RuntimeError("Model must have only one input layer. Received "
                           f"inputs {model.input_names}.")
    # Error if multiple outputs
    if len(model.output_names) != 1:
        raise RuntimeError("Model must have only one output layer. Received "
                           f"outputs {model.output_names}.")
    # Error if the input shape is not 2-D or 4-D
    if len(model.input_shape) not in (2, 4) and get_akida_version() == AkidaVersion.v1:
        raise RuntimeError(
            "Input shape of model must be 2-D or 4-D (batch size + 1-D or 3-D "
            f"tensors). Received input shape {model.input_shape}.")
    # In Akida hardware, skip connections are realized directly by the main
    # layer, so the input layer must have a single outbound node
    if len(model.layers[0].outbound_nodes) > 1:
        raise RuntimeError("The input model layer can only have one outbound node.")


def _extract_skippable_layers(block):
    """Splits a block into skippable and non-skippable layers.

    Args:
        block (list): the layers of the block to split.

    Returns:
        tuple: the lists of skippable and non-skippable layers.
    """
    skippable, non_skippable = [], []
    for layer in block:
        if isinstance(layer, skippable_layers):
            skippable.append(layer)
        elif isinstance(layer, qlayers.QuantizedReshape):
            # A QuantizedReshape is skippable only when it adds or removes unit
            # dimensions, i.e. the non-unit dimensions are left unchanged
            in_dims = [x for x in layer.input_shape if x != 1]
            out_dims = [x for x in layer.output_shape if x != 1]
            if in_dims != out_dims:
                non_skippable.append(layer)
            else:
                skippable.append(layer)
        else:
            non_skippable.append(layer)
    return skippable, non_skippable
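

# For illustration: a QuantizedReshape from (None, 1, 1, 10) to (None, 10) only
# removes unit dimensions and is therefore skippable, while a reshape from
# (None, 4, 5) to (None, 20) changes non-unit dimensions and is not.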