#!/usr/bin/env python
# ******************************************************************************
# Copyright 2022 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Rescaling transformation for quantized models.
"""
__all__ = ["align_rescaling"]
import numpy as np
from copy import deepcopy
from keras.layers import Rescaling, Conv2D, Dense
from keras.saving import serialize_keras_object
from .transforms_utils import get_layers_by_type, get_layers
from ..utils import apply_weights_to_model
from ...layers.convolution import PaddedConv2D
def _find_rescaling_fold_target(rescaling):
    """Locate the layer a Rescaling must be folded into, enforcing the supported cases.

    A Rescaling with a scalar scale and a null offset is considered aligned and needs no target.

    Args:
        rescaling (keras.layers.Layer): the Rescaling layer to inspect

    Returns:
        keras.layers.Layer: the layer fed by the Rescaling when folding is required, None when
        the Rescaling is already aligned.

    Raises:
        ValueError: if a non-aligned Rescaling has a number of outbound nodes different from
            one, or if the layer that follows it is neither a Conv2D nor a Dense.
    """
    # Layer types able to absorb the Rescaling parameters (matched on exact type below)
    supported_dst_layers = [Conv2D, Dense]

    scale = rescaling.scale
    offset = rescaling.offset

    # A scale given as a sequence with more than one value is a per-axis scale
    has_vector_scale = isinstance(scale, (list, tuple)) and len(scale) > 1
    # The offset is null when it is 0, or when every element of the sequence is 0
    null_offset = (all(value == 0 for value in offset)
                   if isinstance(offset, (list, tuple))
                   else offset == 0)

    needs_alignment = has_vector_scale or not null_offset
    if not needs_alignment:
        # Scalar scale and null offset: the layer is already aligned
        return None

    # Folding is only possible towards a single consumer
    outbounds = rescaling.outbound_nodes
    if len(outbounds) != 1:
        raise ValueError("Found a non-aligned Rescaling layer in the model with multiple outbounds "
                         "which is not supported.")

    # The consumer must be exactly one of the supported types
    candidate = outbounds[0].layer
    if type(candidate) not in supported_dst_layers:
        raise ValueError(f"Layer type {type(candidate)} after Rescaling not supported, must be in "
                         f"{supported_dst_layers}.")
    return candidate
def _adapt_padding(model, offset, dst_layer):
    """Rebuild the model with a biased destination layer whose padding accounts for the offset.

    The destination layer configuration is forced to use a bias. When the destination is a Conv2D
    with 'same' padding, its configuration is replaced by that of a PaddedConv2D whose padding
    value is the negated offset.

    Args:
        model (keras.Model): the original model
        offset (float, list, tuple): the offset to fold
        dst_layer (keras.layer): the layer where offset or scale will be folded

    Returns:
        keras.Model: the updated model
    """
    # All edits are made on a deep copy of the model configuration
    model_config = deepcopy(model.get_config())
    target_config = get_layers(model_config, [dst_layer.name])[0]

    # The folded offset lands in the bias, so the destination must have one
    target_config['config']['use_bias'] = True

    same_padded_conv = isinstance(dst_layer, Conv2D) and dst_layer.padding.lower() == 'same'
    if same_padded_conv:
        # Padding value is the opposite of the offset, element-wise when given as a sequence
        if not isinstance(offset, (list, tuple)):
            padding_value = float(-offset)
        else:
            n_chan = dst_layer.get_weights()[0].shape[2]
            assert len(offset) in (1, n_chan), "offset must be a scalar or of size of channels"
            padding_value = [float(-value) for value in offset]

        # Swap the layer entry for a serialized PaddedConv2D built from the biased configuration
        padded_layer = PaddedConv2D.from_config(target_config['config'])
        target_config.update(serialize_keras_object(padded_layer))
        # Re-read 'config' here: the update above may have replaced the nested dictionary
        target_config['config']['padding_value'] = padding_value

    # Instantiate the edited configuration and carry the original variables over
    rebuilt_model = model.from_config(model_config)
    source_variables = {variable.name: variable for variable in model.variables}
    apply_weights_to_model(rebuilt_model, source_variables, False)
    return rebuilt_model
def _fold_rescaling(rescaling_layer, dst_layer, had_bias):
    """Absorb the Rescaling parameters into the weights and bias of the destination layer.

    When the scale is a sequence of more than one value, it is folded into the kernel (and into
    the padding value of a PaddedConv2D) and the Rescaling scale is reset to 1. Any other scale
    is left in the Rescaling layer. When the destination layer uses a bias, the offset is folded
    into it and the Rescaling offset is reset to 0.

    Args:
        rescaling_layer (keras.layers.Layer): the Rescaling layer
        dst_layer (keras.layers.Layer): the layer where Rescaling is folded
        had_bias (bool): whether the original layer had a bias or not
    """
    assert isinstance(dst_layer, (Conv2D, Dense))

    original_weights = dst_layer.get_weights()
    kernel = original_weights[0].copy()
    n_filters = kernel.shape[-1]

    scale = rescaling_layer.scale
    per_channel_scale = isinstance(scale, (list, tuple)) and len(scale) > 1
    if per_channel_scale:
        # The Rescaling keeps a neutral scale, the kernel takes over the per-channel factors
        rescaling_layer.scale = 1
        for idx in range(n_filters):
            # Filter by filter so that scale broadcasts on the last axis of each slice
            kernel[..., idx] *= scale
        if isinstance(dst_layer, PaddedConv2D):
            # The padding value must follow the same change of scale
            current_padding = np.array(dst_layer._padding_value, dtype=np.float32)
            dst_layer._padding_value = list(current_padding / scale)

    folded_weights = [kernel]
    if dst_layer.use_bias:
        # Start from the existing bias, or from zeros when the original layer had none
        if had_bias:
            bias = original_weights[1].copy()
        else:
            bias = np.zeros(n_filters)
        offset = rescaling_layer.offset
        for idx in range(n_filters):
            # Uses the unscaled kernel; per-filter slices let a per-channel offset broadcast
            bias[idx] += np.sum(original_weights[0][..., idx] * offset)
        folded_weights.append(bias)
        rescaling_layer.offset = 0

    dst_layer.set_weights(folded_weights)
# Public API
def align_rescaling(model):
    """Aligns the Rescaling layer of the model to make it quantization ready.

    The offset of the Rescaling is folded into the bias of the layer that follows it, and a
    per-axis scale is folded into the weights of that layer. The resulting Rescaling is
    therefore compatible with a quantization to a QuantizedRescaling.

    If the source model does not contain a Rescaling or if its Rescaling is already
    aligned, then the original model is returned.

    Args:
        model (keras.Model): the source Keras model

    Returns:
        keras.Model: the original model or a new model with Rescaling layer aligned
    """
    # Without any Rescaling layer there is nothing to align
    rescaling_layers = get_layers_by_type(model, Rescaling)
    if not rescaling_layers:
        return model

    # Only the first Rescaling is handled (a model should only have one)
    rescaling = rescaling_layers[0]

    # Look for the layer to fold into; None means the Rescaling is already aligned
    target = _find_rescaling_fold_target(rescaling)
    if target is None:
        return model

    # Rebuild the model so that the target has a bias and a padding matching the offset
    aligned_model = _adapt_padding(model, rescaling.offset, target)

    # Fold the Rescaling parameters into the layers of the rebuilt model
    new_rescaling = aligned_model.get_layer(rescaling.name)
    new_target = aligned_model.get_layer(target.name)
    _fold_rescaling(new_rescaling, new_target, target.use_bias)
    return aligned_model