Advanced ONNX models quantization
Akida, like any specialized hardware accelerator, sacrifices very generalized computational ability in favor of highly optimized implementations of a subset of key operations. While we strive to make sure that Akida directly supports the most important models, it isn’t feasible to support all possibilities. You may thus occasionally find yourself with a model which is very nearly compatible with Akida, but which fails to convert due to just a few incompatibilities. In this example, we will look at some simple workarounds and how to implement them. The goal is to successfully convert the model to Akida without having to retrain.
Preparing a model for Akida requires two steps: quantization, followed by conversion for a specific target hardware device. We try to catch as many incompatibilities as possible at the quantization step. However, some constraints depend on the specific target device, and can only be caught at the conversion step. To illustrate, we will simply walk through the process of preparing ResNet50 for acceleration on Akida - we’ll run into several incompatibilities at different points in that process, and see how to resolve them.
This example assumes a moderate level of experience with deep learning, and good familiarity with the operations typically encountered in these types of models. For example, here we’ll use the following workarounds:
to avoid some incompatible sequences of operations we’ll insert layers with “identity” convolution kernels,
in order to avoid an unusual kernel-size 1/stride 2 convolution, we’ll substitute those kernels with equivalent size 3 kernels.
1. Get model and data
Before diving into the model incompatibilities and how to resolve them, we’ll need to acquire some sample data to test on, plus the pretrained model.
1.1 Data
Given that the reference model was trained on the ImageNet dataset (which is not publicly available), this tutorial uses a set of 10 copyright-free images. See data preparation for more details.
import os
import csv
import numpy as np
from tensorflow.io import read_file
from tensorflow.image import decode_jpeg
from tensorflow.keras.utils import get_file
from akida_models.imagenet import preprocessing
from akida_models.imagenet.imagenet_utils import IMAGENET_MEAN, IMAGENET_STD
# Model specification and hyperparameters
NUM_CHANNELS = 3
IMAGE_SIZE = 224
num_images = 10
# Retrieve dataset file from Brainchip data server
file_path = get_file(
    "imagenet_like.zip",
    "https://data.brainchip.com/dataset-mirror/imagenet_like/imagenet_like.zip",
    cache_subdir='datasets/imagenet_like',
    extract=True)
data_folder = os.path.dirname(file_path)
# Load images for test set
x_test_files = []
x_test_raw = np.zeros((num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)).astype('uint8')
for id in range(num_images):
    test_file = 'image_' + str(id + 1).zfill(2) + '.jpg'
    x_test_files.append(test_file)
    img_path = os.path.join(data_folder, test_file)
    base_image = read_file(img_path)
    image = decode_jpeg(base_image, channels=NUM_CHANNELS)
    image = preprocessing.preprocess_image(image, (IMAGE_SIZE, IMAGE_SIZE))
    x_test_raw[id, :, :, :] = np.expand_dims(image, axis=0)
# Parse labels file
fname = os.path.join(data_folder, 'labels_validation.txt')
validation_labels = dict()
with open(fname, newline='') as csvfile:
    reader = csv.reader(csvfile, delimiter=' ')
    for row in reader:
        validation_labels[row[0]] = row[1]
# Get labels for the test set by index
labels_test = np.zeros(num_images)
for i in range(num_images):
    labels_test[i] = int(validation_labels[x_test_files[i]])
# Normalize images as the model expects
imagenet_mean_255 = np.array(IMAGENET_MEAN, dtype="float32") * 255.0
imagenet_std_255 = np.array(IMAGENET_STD, dtype="float32") * 255.0
x_test = ((x_test_raw - imagenet_mean_255) / imagenet_std_255)
# Transpose the channels to the first axis as per the default for ONNX models
x_test = np.transpose(x_test, (0, 3, 1, 2))
print(f'{num_images} images loaded and preprocessed.')
10 images loaded and preprocessed.
1.2 Download the model
We download ResNet50 from the ONNX Model Zoo:
import onnx
import onnx.hub
from onnxruntime import InferenceSession
onnx_model = onnx.hub.load("ResNet50")
Downloading ResNet50 to local path /root/.cache/onnx/hub/archive/vision/classification/resnet/model/af16a04a6ec48ac494065d4439fe9dea590d337b9ca6dc328160ccf04a217b9c_resnet50-v1-7.onnx
1.3 Evaluate model performance
The ONNXRuntime package is a cross-platform accelerator capable of loading and running models described in ONNX format. We use this framework to evaluate the performance of the loaded ResNet50 model.
Note
For example purposes, we only compute accuracy on 10 images. Accuracy on the full ImageNet validation set is reported at the end.
def evaluate_onnx_model(model):
    sess = InferenceSession(model.SerializeToString())
    # Calculate outputs by running images through the session
    outputs = sess.run(None, {model.graph.input[0].name: x_test})
    # The class with the highest score is what we choose as prediction
    predicted = np.squeeze(np.argmax(outputs[0], 1))
    # Compute the model accuracy
    accuracy = (predicted == labels_test).sum() / num_images
    return accuracy
# Evaluate over test dataset
accuracy_floating = evaluate_onnx_model(onnx_model)
print(f'Floating point model accuracy: {100 * accuracy_floating:.2f} %')
Floating point model accuracy: 100.00 %
2. Quantize
Akida processes integer inputs, activations and weights. Therefore, the first step in preparing a floating point model to run on Akida is to quantize it using QuantizeML quantize().
Note
Please refer to the QuantizeML toolkit user guide and the Advanced QuantizeML tutorial for further details. In particular here, for simplicity, we pass only the small number of samples we already prepared for calibration. Typically, you will want to use many more samples for calibration, say 1000 if you have them available, and not drawn from your test data. The akida_models package provides a helper function, extract_samples(), which may be helpful in preparing those.
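One simple way to keep calibration samples separate from test samples is to shuffle the indices of the available pool once and slice the result. The sketch below uses only NumPy; the function name `split_indices`, the pool size and the sample counts are illustrative assumptions, not part of QuantizeML or akida_models.

```python
import numpy as np


def split_indices(pool_size, num_calibration, num_test, seed=0):
    """Return two disjoint index arrays drawn from range(pool_size)."""
    if num_calibration + num_test > pool_size:
        raise ValueError("Not enough samples for disjoint calibration and test sets")
    # A single permutation guarantees that no index appears in both slices
    shuffled = np.random.default_rng(seed).permutation(pool_size)
    calibration = shuffled[:num_calibration]
    test = shuffled[num_calibration:num_calibration + num_test]
    return calibration, test


# Hypothetical pool of 5000 images: 1000 for calibration, 10 for testing
calibration_idx, test_idx = split_indices(pool_size=5000, num_calibration=1000, num_test=10)
```

The resulting index arrays can then be used to select rows from whatever array or file list holds the samples.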
from quantizeml.models import quantize
model_quantized = quantize(onnx_model, samples=x_test)
Calibrating with 10/10.0 samples
/usr/local/lib/python3.8/dist-packages/quantizeml/onnx_support/quantization/quantize.py:208: UserWarning: The following nodes were not quantized because their pattern was not found in the scope: ['resnetv17_stage1_activation0 (Relu)', 'resnetv17_stage1_conv4_fwd (Conv)', 'resnetv17_stage1_relu2_fwd (Relu)', 'resnetv17_stage1_conv5_fwd (Conv)', 'resnetv17_stage1_relu3_fwd (Relu)', 'resnetv17_stage1_conv6_fwd (Conv)', 'resnetv17_stage1__plus1 (Add)', 'resnetv17_stage1_activation1 (Relu)', 'resnetv17_stage1_conv7_fwd (Conv)', 'resnetv17_stage1_relu4_fwd (Relu)', 'resnetv17_stage1_conv8_fwd (Conv)', 'resnetv17_stage1_relu5_fwd (Relu)', 'resnetv17_stage1_conv9_fwd (Conv)', 'resnetv17_stage1__plus2 (Add)', 'resnetv17_stage1_activation2 (Relu)', 'resnetv17_stage2_conv3_fwd (Conv)', 'resnetv17_stage2_conv0_fwd (Conv)', 'resnetv17_stage2_relu0_fwd (Relu)', 'resnetv17_stage2_conv1_fwd (Conv)', 'resnetv17_stage2_relu1_fwd (Relu)', 'resnetv17_stage2_conv2_fwd (Conv)', 'resnetv17_stage2__plus0 (Add)', 'resnetv17_stage2_activation0 (Relu)', 'resnetv17_stage2_conv4_fwd (Conv)', 'resnetv17_stage2_relu2_fwd (Relu)', 'resnetv17_stage2_conv5_fwd (Conv)', 'resnetv17_stage2_relu3_fwd (Relu)', 'resnetv17_stage2_conv6_fwd (Conv)', 'resnetv17_stage2__plus1 (Add)', 'resnetv17_stage2_activation1 (Relu)', 'resnetv17_stage2_conv7_fwd (Conv)', 'resnetv17_stage2_relu4_fwd (Relu)', 'resnetv17_stage2_conv8_fwd (Conv)', 'resnetv17_stage2_relu5_fwd (Relu)', 'resnetv17_stage2_conv9_fwd (Conv)', 'resnetv17_stage2__plus2 (Add)', 'resnetv17_stage2_activation2 (Relu)', 'resnetv17_stage2_conv10_fwd (Conv)', 'resnetv17_stage2_relu6_fwd (Relu)', 'resnetv17_stage2_conv11_fwd (Conv)', 'resnetv17_stage2_relu7_fwd (Relu)', 'resnetv17_stage2_conv12_fwd (Conv)', 'resnetv17_stage2__plus3 (Add)', 'resnetv17_stage2_activation3 (Relu)', 'resnetv17_stage3_conv3_fwd (Conv)', 'resnetv17_stage3_conv0_fwd (Conv)', 'resnetv17_stage3_relu0_fwd (Relu)', 'resnetv17_stage3_conv1_fwd (Conv)', 'resnetv17_stage3_relu1_fwd 
(Relu)', 'resnetv17_stage3_conv2_fwd (Conv)', 'resnetv17_stage3__plus0 (Add)', 'resnetv17_stage3_activation0 (Relu)', 'resnetv17_stage3_conv4_fwd (Conv)', 'resnetv17_stage3_relu2_fwd (Relu)', 'resnetv17_stage3_conv5_fwd (Conv)', 'resnetv17_stage3_relu3_fwd (Relu)', 'resnetv17_stage3_conv6_fwd (Conv)', 'resnetv17_stage3__plus1 (Add)', 'resnetv17_stage3_activation1 (Relu)', 'resnetv17_stage3_conv7_fwd (Conv)', 'resnetv17_stage3_relu4_fwd (Relu)', 'resnetv17_stage3_conv8_fwd (Conv)', 'resnetv17_stage3_relu5_fwd (Relu)', 'resnetv17_stage3_conv9_fwd (Conv)', 'resnetv17_stage3__plus2 (Add)', 'resnetv17_stage3_activation2 (Relu)', 'resnetv17_stage3_conv10_fwd (Conv)', 'resnetv17_stage3_relu6_fwd (Relu)', 'resnetv17_stage3_conv11_fwd (Conv)', 'resnetv17_stage3_relu7_fwd (Relu)', 'resnetv17_stage3_conv12_fwd (Conv)', 'resnetv17_stage3__plus3 (Add)', 'resnetv17_stage3_activation3 (Relu)', 'resnetv17_stage3_conv13_fwd (Conv)', 'resnetv17_stage3_relu8_fwd (Relu)', 'resnetv17_stage3_conv14_fwd (Conv)', 'resnetv17_stage3_relu9_fwd (Relu)', 'resnetv17_stage3_conv15_fwd (Conv)', 'resnetv17_stage3__plus4 (Add)', 'resnetv17_stage3_activation4 (Relu)', 'resnetv17_stage3_conv16_fwd (Conv)', 'resnetv17_stage3_relu10_fwd (Relu)', 'resnetv17_stage3_conv17_fwd (Conv)', 'resnetv17_stage3_relu11_fwd (Relu)', 'resnetv17_stage3_conv18_fwd (Conv)', 'resnetv17_stage3__plus5 (Add)', 'resnetv17_stage3_activation5 (Relu)', 'resnetv17_stage4_conv3_fwd (Conv)', 'resnetv17_stage4_conv0_fwd (Conv)', 'resnetv17_stage4_relu0_fwd (Relu)', 'resnetv17_stage4_conv1_fwd (Conv)', 'resnetv17_stage4_relu1_fwd (Relu)', 'resnetv17_stage4_conv2_fwd (Conv)', 'resnetv17_stage4__plus0 (Add)', 'resnetv17_stage4_activation0 (Relu)', 'resnetv17_stage4_conv4_fwd (Conv)', 'resnetv17_stage4_relu2_fwd (Relu)', 'resnetv17_stage4_conv5_fwd (Conv)', 'resnetv17_stage4_relu3_fwd (Relu)', 'resnetv17_stage4_conv6_fwd (Conv)', 'resnetv17_stage4__plus1 (Add)', 'resnetv17_stage4_activation1 (Relu)', 'resnetv17_stage4_conv7_fwd 
(Conv)', 'resnetv17_stage4_relu4_fwd (Relu)', 'resnetv17_stage4_conv8_fwd (Conv)', 'resnetv17_stage4_relu5_fwd (Relu)', 'resnetv17_stage4_conv9_fwd (Conv)', 'resnetv17_stage4__plus2 (Add)', 'resnetv17_stage4_activation2 (Relu)', 'resnetv17_pool1_fwd (GlobalAveragePool)', 'flatten_473 (Flatten)', 'resnetv17_dense0_fwd (Gemm)'].
warnings.warn("The following nodes were not quantized because their pattern was not found "
We can see that the model is not fully quantized, stopping at the first unrecognized pattern (node resnetv17_stage1_activation0 (Relu)). We know that Akida can definitely handle ReLU activation functions, so we have to look more closely to understand the problem. Analyzing the model, the ReLU immediately follows an Add operator. It is this sequence of operations which is not supported by Akida.
2.1 About Patterns
For efficiency, Akida hardware actually groups certain commonly occurring operations together. For example, ReLU activation functions, where present, are almost always applied on the outputs of the hard-working computational layers (Convolutions, Depthwise Convolutions, Dense layers etc.). So the ReLU on Akida is tied to those operations. While efficient, this does mean that some sequences of operations will not by default be considered Akida-compatible, even though the individual operations are known to be handled. That’s the cause of the problem encountered here.
To properly see what’s going on, and to resolve the problem, we’ll need to understand the concept of “patterns”. These are the objects that QuantizeML uses to map ONNX models to their Akida equivalents. A pattern is a sequence of contiguous ONNX operators in a graph that can be converted to an Akida V2 layer. For example, the following model would be converted to an akida.InputConv2D layer:
The sequence of operators [Conv, Clip, MaxPool] is one valid pattern for conversion towards InputConv2D.
Crucially, we can check the list of the currently supported patterns:
from quantizeml.onnx_support.quantization.register_patterns import PATTERNS_MAP
print(*PATTERNS_MAP, sep='\n')
QuantizerPattern(pattern=('Conv', 'Relu', 'GlobalAveragePool'), f=<function get_qconv at 0x7fcaa3aee040>)
QuantizerPattern(pattern=('Conv', 'Relu', 'MaxPool'), f=<function get_qconv at 0x7fcaa3aee040>)
QuantizerPattern(pattern=('Conv', 'GlobalAveragePool'), f=<function get_qconv at 0x7fcaa3aee040>)
QuantizerPattern(pattern=('Conv', 'Relu'), f=<function get_qconv at 0x7fcaa3aee040>)
QuantizerPattern(pattern=('Conv',), f=<function get_qconv at 0x7fcaa3aee040>)
QuantizerPattern(pattern=('DepthwiseConv', 'Relu'), f=<function get_qdepthwise at 0x7fcaa3a834c0>)
QuantizerPattern(pattern=('DepthwiseConv',), f=<function get_qdepthwise at 0x7fcaa3a834c0>)
QuantizerPattern(pattern=('Flatten', 'Gemm', 'Relu'), f=<function get_qgemm at 0x7fcaa3a83820>)
QuantizerPattern(pattern=('Flatten', 'Gemm'), f=<function get_qgemm at 0x7fcaa3a83820>)
QuantizerPattern(pattern=('Gemm', 'Relu'), f=<function get_qgemm at 0x7fcaa3a83820>)
QuantizerPattern(pattern=('Gemm',), f=<function get_qgemm at 0x7fcaa3a83820>)
QuantizerPattern(pattern=('Add',), f=<function get_qadd at 0x7fcaa3a83c10>)
Looking at that list, it should be apparent that a ReLU operation on its own or following an Add is not considered a compatible pattern.
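To make the idea concrete, here is a toy sketch of how a list of operator types could be split into known patterns, stopping at the first sequence that matches nothing. It is an illustrative simplification, not the QuantizeML implementation: the real graph is not a flat list (ResNet50 has residual branches), and `split_into_patterns` and the shortened `PATTERNS` list are hypothetical names introduced here.

```python
# A subset of the patterns printed above, as tuples of operator types
PATTERNS = [
    ("Conv", "Relu", "GlobalAveragePool"),
    ("Conv", "Relu", "MaxPool"),
    ("Conv", "GlobalAveragePool"),
    ("Conv", "Relu"),
    ("Conv",),
    ("Flatten", "Gemm"),
    ("Gemm",),
    ("Add",),
]


def split_into_patterns(op_types, patterns):
    """Greedily match the longest pattern at each position.

    Returns the matched patterns and the operators left over once
    no pattern matches any more.
    """
    ordered = sorted(patterns, key=len, reverse=True)
    matched, i = [], 0
    while i < len(op_types):
        for pattern in ordered:
            if tuple(op_types[i:i + len(pattern)]) == pattern:
                matched.append(pattern)
                i += len(pattern)
                break
        else:
            # Nothing matches at position i: stop here
            return matched, op_types[i:]
    return matched, []


# Simplified, linearized start of a residual network
ops = ["Conv", "Relu", "MaxPool", "Conv", "Relu", "Conv", "Add", "Relu", "Conv"]
matched, remaining = split_into_patterns(ops, PATTERNS)
print(matched)
print(remaining)
```

In this toy sequence the Add is matched on its own, and matching then stops at the Relu that follows it, which mirrors the warning emitted by quantize() above.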
Note
Before the conversion, the following changes are applied automatically so that the QuantizeML toolkit sees an ONNX graph suitable for quantization:
the following operators are transformed for general purposes:
Conv -> DepthwiseConv when kernel size is 1 x Kx x Ky and group is required,
Clip -> Relu (if min = 0.0),
Graph Optimizations in ONNX Runtime are used to optimize the graph (e.g. fuse BatchNorm into convolutions).
2.2. Custom quantization patterns
The existing patterns won’t allow us to map an isolated ReLU operation. But, for example, the ReLU operation can be mapped when following a Conv layer, and we can easily implement a Conv layer that performs an identity operation on its inputs, just by setting the kernel weights appropriately. We can implement this workaround by using custom quantization patterns to extend PATTERNS_MAP.
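The identity convolution idea can be checked in isolation with NumPy. The sketch below builds a 1x1 kernel of shape (channels, channels, 1, 1) from an identity matrix and verifies that convolving with it leaves the input unchanged, so that a ReLU applied afterwards gives the same result as a ReLU on the original tensor. The helper `conv1x1` is a hypothetical stand-in for a real convolution operator.

```python
import numpy as np


def conv1x1(x, kernel):
    """1x1 convolution. x: (N, C_in, H, W), kernel: (C_out, C_in, 1, 1)."""
    return np.einsum("nchw,oc->nohw", x, kernel[:, :, 0, 0])


channels = 4
# Output channel k copies input channel k and ignores all the others
identity_kernel = np.identity(channels, dtype="float32")[..., None, None]

x = np.random.default_rng(0).standard_normal((2, channels, 5, 5)).astype("float32")
y = conv1x1(x, identity_kernel)

# Conv(identity) followed by ReLU behaves like the isolated ReLU
relu_after_identity = np.maximum(y, 0)
print(np.allclose(relu_after_identity, np.maximum(x, 0)))
```

This check is in floating point only; once quantized, the inserted layer adds its own quantization step, so the accuracy of the quantized model is still worth re-checking.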
Every pattern includes an ONNX layer that stores the ONNX graph information for the matching sequence of nodes. QuantizeML also allows for a function to create a compatible layer from an initially incompatible pattern. This pattern function has two input parameters: the graph and the pattern-matched sequence of nodes extracted from it.
Once a pattern function is defined for an unsupported pattern, both can be appended in the quantization context through the custom_pattern_scope function.
from quantizeml.onnx_support import layers
from quantizeml.onnx_support.quantization import custom_pattern_scope
class IdentityQuantizedConv2D(layers.QuantizedConv2D):
    def __build__(self, input_ts, downscale=True):
        # Produces a kernel such that the convolution does not modify the input.
        identity_kernel = np.identity(input_ts.shape[1], "float32")[..., None, None]
        self.set_weight("kernel", identity_kernel)
        return super().__build__(input_ts, downscale)
def relu_pattern_fn(block_nodes, graph):
    """Convert the incompatible pattern ['Relu'] into an IdentityQuantizedConv2D.
    """
    # Note that as 'relu_pattern_map' is written, this function expects to receive
    # only the isolated ('Relu') that matches in the graph.
    block_ops = [x.op_type for x in block_nodes]
    if block_ops == ['Relu']:
        return IdentityQuantizedConv2D(activation=True)
    else:
        raise Exception(f"Unrecognized pattern: {block_ops}")
# Define a custom patterns map, as a new pattern and associated replacement function.
relu_pattern_map = {
    "Relu": relu_pattern_fn,
}
# Include relu_pattern_map in the quantization context
with custom_pattern_scope(relu_pattern_map):
    model_quantized = quantize(onnx_model, samples=x_test)
Calibrating with 10/10.0 samples
/usr/local/lib/python3.8/dist-packages/quantizeml/onnx_support/quantization/quantize.py:208: UserWarning: The following nodes were not quantized because their pattern was not found in the scope: ['resnetv17_pool1_fwd (GlobalAveragePool)', 'flatten_473 (Flatten)', 'resnetv17_dense0_fwd (Gemm)'].
warnings.warn("The following nodes were not quantized because their pattern was not found "
With the isolated ReLU fixed, we managed to quantize much more of the model, but we hit a new problem, node resnetv17_pool1_fwd (GlobalAveragePool). Looking back at the list of compatible patterns, we can see that, like the ReLU, a GlobalAveragePooling (GAP) operation cannot be handled in isolation, but is compatible when it follows Conv or Conv + ReLU operations. The second of those will suit us better here, that way we can combine it with our solution for the ReLU operation (because the GAP here does indeed follow one of the isolated ReLU ops).
def activation_pattern_fn(block_nodes, graph):
    """Convert the incompatible patterns ['Relu'] and ['Relu', 'GlobalAveragePool'] into
    an IdentityQuantizedConv2D.
    """
    # Note that as 'quantization_pattern_map' is written, this function expects to receive
    # only the sequences ('Relu') or ('Relu', 'GlobalAveragePool').
    block_ops = [x.op_type for x in block_nodes]
    if block_ops == ['Relu']:
        return IdentityQuantizedConv2D(activation=True)
    elif block_ops == ['Relu', 'GlobalAveragePool']:
        return IdentityQuantizedConv2D(activation=True, pool_type="gap")
    else:
        raise Exception(f"Unrecognized pattern: {block_ops}")
# Define quantization custom patterns map, as a set of patterns and associated replacement function.
# activation_pattern_fn was designed to handle two similar incompatibilities present in ResNet50.
quantization_pattern_map = {
    ("Relu", "GlobalAveragePool"): activation_pattern_fn,
    "Relu": activation_pattern_fn,
}
# Include quantization_pattern_map in the quantization context
with custom_pattern_scope(quantization_pattern_map):
    model_quantized = quantize(onnx_model, samples=x_test)
Calibrating with 10/10.0 samples
The full model is now quantized successfully. At this point we can re-check its accuracy:
accuracy = evaluate_onnx_model(model_quantized)
print(f'Quantized model accuracy: {100 * accuracy:.2f} %')
Quantized model accuracy: 100.00 %
3. Conversion
3.1. Incompatibility at Conversion
As indicated above, while most incompatibilities will be picked up at the quantization step, some constraints are specific to the target hardware device, and can only be applied at the conversion step. We can detect these either with the check_model_compatibility tool, or by trying to convert the model into Akida.
from cnn2snn import convert
try:
    akida_model = convert(model_quantized)
except Exception as e:
    print(f"ResNet50 is not fully accelerated by Akida. Reason: {str(e)}")
ResNet50 is not fully accelerated by Akida. Reason: Expect pads [2, 2, 3, 3] (found [3, 3, 3, 3]) in resnetv17_conv0_fwd.
This error is raised because the ResNet50 padding scheme is very specific and differs from the Keras/Akida standard.
Ideally, we should aim to swap incompatible operations with mathematically equivalent replacements. For issues of convolution kernel size or padding, we can often achieve that by putting the kernel weights within a larger kernel, placed eccentrically to compensate for any padding issues etc. More on that below - but we can’t use that strategy here, because the kernel size for this layer (7x7) is already the maximum supported by the Akida input layer. In this case, we’ll have to try simply modifying the padding to be Akida-compatible. Because this is the input layer, we could actually negate that change by padding the input image along two edges before passing to Akida. However, precisely because this is the very start of the network, and the consequence is only a single pixel of spatial offset, we might expect that the impact on model performance will be negligible, and that’s precisely what we find on testing. So let’s keep things simple in this case: simply replace the incompatible values with compatible ones.
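The size arithmetic behind this choice can be checked with the usual convolution output-size formula. The sketch below assumes the pads in the error message are ordered [top, left, bottom, right], and compares the original padding (3 before, 3 after) with the Akida-compatible one (2 before, 3 after) for a 7x7 / stride 2 convolution on a 224-pixel axis. The function name is a hypothetical helper, and the sketch only covers sizes: which pixel value to use for an extra row and column is not addressed here.

```python
def conv_output_size(size, kernel, stride, pad_begin, pad_end):
    """Output length of a convolution along one spatial axis."""
    return (size + pad_begin + pad_end - kernel) // stride + 1


# Original ResNet50 first convolution: 3 pixels of padding on each side
original_size = conv_output_size(224, kernel=7, stride=2, pad_begin=3, pad_end=3)

# Akida-compatible pads: 2 pixels before, 3 pixels after
akida_size = conv_output_size(224, kernel=7, stride=2, pad_begin=2, pad_end=3)

# Pre-padding the image by one pixel on the leading edge (224 -> 225)
# restores a total of 3 pixels before and 3 after
prepadded_size = conv_output_size(225, kernel=7, stride=2, pad_begin=2, pad_end=3)

print(original_size, akida_size, prepadded_size)
```

All three cases give the same output size, so changing the pads does not alter any downstream shape; it only shifts the sampling grid by one pixel.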
To achieve this, we’ll again customize the pattern functions to modify the model before quantization. Rather than try to provide a general solution, we’ll hard code this for the problem layer:
from quantizeml.onnx_support import graph_tools
def align_input_conv_with_akida(block_nodes, graph):
    """Pattern function that handles convolutions incompatible with Akida
    """
    # Recover initial ONNXLayer from block nodes and graph
    qconv = layers.get_qconv(block_nodes, graph)
    # Force the pads in first convolution to Akida compatible values
    if qconv.name == 'resnetv17_conv0_fwd':
        print("Setting Akida pads in first convolution...")
        # Note: pads in convolution include spatial dimension
        qconv.set_weight("pads", np.array([0, 0, 2, 2, 0, 0, 3, 3]))
        graph_tools.replace_field(qconv, "pool_pads", [0, 0, 1, 1])
    return qconv
# Infer intermediate shape: This is required for some custom pattern functions
onnx_model_temp = onnx.shape_inference.infer_shapes(onnx_model)
# Quantize model with custom patterns
quantization_pattern_map = {
    ("Conv", "Relu", "MaxPool"): align_input_conv_with_akida,
    ("Conv", "Relu"): align_input_conv_with_akida,
    ("Conv",): align_input_conv_with_akida,
    ("Relu", "GlobalAveragePool"): activation_pattern_fn,
    "Relu": activation_pattern_fn,
}
with custom_pattern_scope(quantization_pattern_map):
    model_quantized = quantize(onnx_model_temp, samples=x_test)
Calibrating with 10/10.0 samples
Setting Akida pads in first convolution...
Let’s try to convert again:
try:
    akida_model = convert(model_quantized)
except Exception as e:
    print(f"ResNet50 is not fully accelerated by Akida. Reason: {str(e)}")
ResNet50 is not fully accelerated by Akida. Reason: Stride 2 is only supported with kernel size 3.
The error message now indicates that there is a problem with a stride 2 operation, because the required kernel size is not supported. Looking at the ResNet50 definition, we can see that there’s a very unusual kernel-size 1 / stride 2 conv operation applied within the stride-2 blocks at the beginning of each stage. That’s a good candidate for the workaround mentioned earlier: we can simply swap in a compatible kernel-size 3 / stride 2 convolution, placing the weights from the original size 1 kernel within (otherwise zero-valued) size 3 kernels. In Akida (and Keras), the kernel-size 3/ stride 2 conv operation has padding [0, 0, 1, 1], so to make the replacement operation equivalent we need to place the smaller kernel weights eccentrically within the larger kernels.
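The equivalence claimed here can be verified numerically. The sketch below implements a naive convolution in NumPy (cross-correlation, as in ONNX Conv), embeds a 1x1 kernel in the top-left corner of a zero-valued 3x3 kernel, and compares the two stride-2 results. The helper `conv2d` and its (top, left, bottom, right) pads argument are assumptions made for this illustration, and the check uses an even input size, as is the case for the feature maps concerned in ResNet50.

```python
import numpy as np


def conv2d(x, kernel, stride, pads):
    """Naive convolution on one sample.

    x: (C_in, H, W), kernel: (C_out, C_in, kH, kW),
    pads: (top, left, bottom, right) zero padding.
    """
    top, left, bottom, right = pads
    x = np.pad(x, ((0, 0), (top, bottom), (left, right)))
    c_out, _, kh, kw = kernel.shape
    out_h = (x.shape[1] - kh) // stride + 1
    out_w = (x.shape[2] - kw) // stride + 1
    out = np.zeros((c_out, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            patch = x[:, i * stride:i * stride + kh, j * stride:j * stride + kw]
            # Sum over input channels and kernel positions
            out[:, i, j] = np.tensordot(kernel, patch, axes=([1, 2, 3], [0, 1, 2]))
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8, 8))
kernel_1x1 = rng.standard_normal((5, 3, 1, 1))

# Place the original weight at position [0, 0] of an otherwise zero 3x3 kernel
kernel_3x3 = np.pad(kernel_1x1, ((0, 0), (0, 0), (0, 2), (0, 2)))

reference = conv2d(x, kernel_1x1, stride=2, pads=(0, 0, 0, 0))
replacement = conv2d(x, kernel_3x3, stride=2, pads=(0, 0, 1, 1))
print(np.allclose(reference, replacement))
```

Because the padding is applied only on the bottom and right edges, the top-left weight of the 3x3 kernel lands on exactly the pixels sampled by the original 1x1 / stride 2 convolution.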
We’ll combine that with the previous padding fix within a single function.
def align_conv_with_akida(block_nodes, graph):
    """Pattern function that handles convolutions incompatible with Akida
    """
    # Recover initial ONNXLayer from block nodes and graph
    qconv = layers.get_qconv(block_nodes, graph)
    # Force the pads in first convolution to Akida compatible values
    if qconv.name == 'resnetv17_conv0_fwd':
        print("Setting Akida pads in first convolution...")
        # Note: pads in convolution include spatial dimension
        qconv.set_weight("pads", np.array([0, 0, 2, 2, 0, 0, 3, 3]))
        graph_tools.replace_field(qconv, "pool_pads", [0, 0, 1, 1])
    # Replace 1x1 kernel with strides 2x2 by a padded 3x3 one
    kernel = qconv.weights["kernel"]
    strides = graph_tools.get_field(qconv, "strides")
    if kernel.shape[-2:] == (1, 1) and strides == [2, 2]:
        # Only spatial dimensions are padded
        # The original weights are placed eccentrically in the new kernel
        new_kernel = np.pad(kernel, ((0, 0), (0, 0), (0, 2), (0, 2)))
        # Set new kernel in the layer
        # This operation requires a specific padding pattern for Akida compatibility.
        qconv.set_weight("pads", np.array([0, 0, 0, 0, 0, 0, 1, 1]))
        qconv.set_weight("kernel", new_kernel)
        print(f"Kernel updated in {qconv.name} from 1x1 to 3x3.")
    return qconv
# Infer intermediate shape: This is required for some custom pattern functions
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
# Quantize model with custom patterns
quantization_pattern_map = {
    ("Conv", "Relu", "MaxPool"): align_conv_with_akida,
    ("Conv", "Relu"): align_conv_with_akida,
    ("Conv",): align_conv_with_akida,
    ("Relu", "GlobalAveragePool"): activation_pattern_fn,
    "Relu": activation_pattern_fn,
}
with custom_pattern_scope(quantization_pattern_map):
    model_quantized = quantize(onnx_model, samples=x_test)
# Evaluate quantized model performance
accuracy = evaluate_onnx_model(model_quantized)
print(f'Quantized model accuracy: {100 * accuracy:.2f} %')
Calibrating with 10/10.0 samples
Setting Akida pads in first convolution...
Kernel updated in resnetv17_stage2_conv3_fwd from 1x1 to 3x3.
Kernel updated in resnetv17_stage2_conv0_fwd from 1x1 to 3x3.
Kernel updated in resnetv17_stage3_conv3_fwd from 1x1 to 3x3.
Kernel updated in resnetv17_stage3_conv0_fwd from 1x1 to 3x3.
Kernel updated in resnetv17_stage4_conv3_fwd from 1x1 to 3x3.
Kernel updated in resnetv17_stage4_conv0_fwd from 1x1 to 3x3.
Quantized model accuracy: 100.00 %
3.2. Successful Conversion
Time to check conversion again:
akida_model = convert(model_quantized)
Great - the model is now successfully quantized and can be entirely converted for acceleration on Akida. To check its performance, we need to bear in mind that
images must be numpy-raw, with an 8-bit unsigned integer data type and
the channel dimension must be in the last dimension.
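A small guard can catch the most common mistake, which is passing the normalized, channels-first array prepared for ONNX Runtime instead of the raw images. The helper below is a hypothetical sketch (it is not part of the akida package) that only checks the data type and the channel position.

```python
import numpy as np


def check_akida_input(images, num_channels=3):
    """Raise ValueError unless images look like raw uint8, channels-last data."""
    if images.dtype != np.uint8:
        raise ValueError(f"Expected uint8 images, got {images.dtype}")
    if images.ndim != 4 or images.shape[-1] != num_channels:
        raise ValueError(f"Expected shape (N, H, W, {num_channels}), got {images.shape}")
    return images


# Raw images: uint8, channels last
raw_nhwc = np.zeros((10, 224, 224, 3), dtype=np.uint8)
check_akida_input(raw_nhwc)

# Float, channels-first data as fed to ONNX Runtime would be rejected
normalized_nchw = np.transpose(raw_nhwc.astype("float32"), (0, 3, 1, 2))
try:
    check_akida_input(normalized_nchw)
except ValueError as err:
    print(f"Rejected: {err}")
```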
# Evaluate performance
akida_accuracy = akida_model.evaluate(x_test_raw, labels_test)
print(f'Akida model accuracy: {100 * akida_accuracy:.2f} %')
Akida model accuracy: 100.00 %
3.3. Performance on the full ImageNet validation set
The table below summarizes the obtained accuracy at the various stages using the full ImageNet dataset. Note that forcing pads on the first layer decreases the accuracy of the model by about 0.45 percentage points (74.368 to 73.918) - as noted, that change could be rendered lossless by padding the input image before sending it to Akida instead.
Float accuracy (before Akida adaptation) | Float accuracy | Quantized accuracy | Akida accuracy |
---|---|---|---|
74.368 | 73.918 | 73.590 | 73.620 |
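The change in accuracy between consecutive stages can be read off the table with a few lines of arithmetic. The values below are copied from the table; the dictionary keys are shortened labels chosen for this sketch.

```python
# Accuracy (%) on the full ImageNet validation set, from the table above
accuracies = {
    "float (before Akida adaptation)": 74.368,
    "float": 73.918,
    "quantized": 73.590,
    "akida": 73.620,
}

stages = list(accuracies)
# Difference between each stage and the one before it, in percentage points
deltas = {
    f"{prev} -> {curr}": round(accuracies[curr] - accuracies[prev], 3)
    for prev, curr in zip(stages, stages[1:])
}

for step, delta in deltas.items():
    print(f"{step}: {delta:+.3f}")
```

Most of the total drop comes from the padding change and from quantization; conversion to Akida itself changes the accuracy only marginally.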
Note
The images shown in this tutorial are produced through Netron.
Total running time of the script: (0 minutes 38.194 seconds)