Note
Go to the end to download the full example code
Segmentation tutorial
This example demonstrates image segmentation with an Akida-compatible model, illustrated through person segmentation on the Portrait128 dataset.
Using pre-trained models for quick runtime, this example shows the evolution of model performance for a trained Keras floating point model, a Keras quantized and Quantization Aware Trained (QAT) model, and an Akida-converted model. Notice that the performance of the original Keras floating point model is maintained throughout the model conversion flow.
1. Load the dataset
import os
import numpy as np
from akida_models import fetch_file
# Download the validation set from the Brainchip data server; it contains 10% of
# the original dataset
data_path = fetch_file(fname="val.tar.gz",
                       origin="https://data.brainchip.com/dataset-mirror/portrait128/val.tar.gz",
                       cache_subdir=os.path.join("datasets", "portrait128"),
                       extract=True)
# The arrays are read from the "val" directory located next to the downloaded archive
data_dir = os.path.join(os.path.dirname(data_path), "val")

# Images and their segmentation masks (masks are cast to uint8)
x_val = np.load(os.path.join(data_dir, "val_img.npy"))
y_val = np.load(os.path.join(data_dir, "val_msk.npy")).astype('uint8')

batch_size = 32
# Number of full batches in the validation set; derived from batch_size so the
# two values cannot drift apart if batch_size is edited
steps = x_val.shape[0] // batch_size
# Visualize some data
import matplotlib.pyplot as plt

# Pick the first of three consecutive samples. The (exclusive) upper bound leaves
# room for the two following samples, so id + 2 never goes past the last index.
# NOTE: this name shadows the builtin id(); it is kept because the script reuses
# it further down when segmenting a single image.
id = np.random.randint(0, x_val.shape[0] - 2)

# One column per sample: image (row 0), mask (row 1), masked image (row 2)
fig, axs = plt.subplots(3, 3, constrained_layout=True)
for col in range(3):
    axs[0, col].imshow(x_val[id + col] / 255.)
    axs[0, col].axis('off')
    # The mask is inverted before display with the 'Greys' colormap
    axs[1, col].imshow(1 - y_val[id + col], cmap='Greys')
    axs[1, col].axis('off')
    axs[2, col].imshow(x_val[id + col] / 255. * y_val[id + col])
    axs[2, col].axis('off')

fig.suptitle('Image, mask and masked image', fontsize=10)
plt.show()
Downloading data from https://data.brainchip.com/dataset-mirror/portrait128/val.tar.gz.
0/267313385 [..............................] - ETA: 0s
122880/267313385 [..............................] - ETA: 1:51
720896/267313385 [..............................] - ETA: 37s
1556480/267313385 [..............................] - ETA: 26s
2531328/267313385 [..............................] - ETA: 21s
3489792/267313385 [..............................] - ETA: 19s
4628480/267313385 [..............................] - ETA: 17s
5890048/267313385 [..............................] - ETA: 15s
7143424/267313385 [..............................] - ETA: 14s
8544256/267313385 [..............................] - ETA: 13s
10035200/267313385 [>.............................] - ETA: 12s
11468800/267313385 [>.............................] - ETA: 12s
13107200/267313385 [>.............................] - ETA: 11s
14680064/267313385 [>.............................] - ETA: 11s
16211968/267313385 [>.............................] - ETA: 10s
17735680/267313385 [>.............................] - ETA: 10s
19259392/267313385 [=>............................] - ETA: 10s
20725760/267313385 [=>............................] - ETA: 10s
22183936/267313385 [=>............................] - ETA: 10s
23748608/267313385 [=>............................] - ETA: 9s
25288704/267313385 [=>............................] - ETA: 9s
26828800/267313385 [==>...........................] - ETA: 9s
28434432/267313385 [==>...........................] - ETA: 9s
30031872/267313385 [==>...........................] - ETA: 9s
31612928/267313385 [==>...........................] - ETA: 9s
33153024/267313385 [==>...........................] - ETA: 8s
34725888/267313385 [==>...........................] - ETA: 8s
36331520/267313385 [===>..........................] - ETA: 8s
37847040/267313385 [===>..........................] - ETA: 8s
39550976/267313385 [===>..........................] - ETA: 8s
41353216/267313385 [===>..........................] - ETA: 8s
42893312/267313385 [===>..........................] - ETA: 8s
44638208/267313385 [====>.........................] - ETA: 8s
46309376/267313385 [====>.........................] - ETA: 7s
47816704/267313385 [====>.........................] - ETA: 7s
49602560/267313385 [====>.........................] - ETA: 7s
50962432/267313385 [====>.........................] - ETA: 7s
52682752/267313385 [====>.........................] - ETA: 7s
53952512/267313385 [=====>........................] - ETA: 7s
55574528/267313385 [=====>........................] - ETA: 7s
56918016/267313385 [=====>........................] - ETA: 7s
58359808/267313385 [=====>........................] - ETA: 7s
59736064/267313385 [=====>........................] - ETA: 7s
61308928/267313385 [=====>........................] - ETA: 7s
62840832/267313385 [======>.......................] - ETA: 7s
64331776/267313385 [======>.......................] - ETA: 7s
65855488/267313385 [======>.......................] - ETA: 7s
67330048/267313385 [======>.......................] - ETA: 7s
68902912/267313385 [======>.......................] - ETA: 6s
70344704/267313385 [======>.......................] - ETA: 6s
72032256/267313385 [=======>......................] - ETA: 6s
73523200/267313385 [=======>......................] - ETA: 6s
75210752/267313385 [=======>......................] - ETA: 6s
76603392/267313385 [=======>......................] - ETA: 6s
78159872/267313385 [=======>......................] - ETA: 6s
79437824/267313385 [=======>......................] - ETA: 6s
80896000/267313385 [========>.....................] - ETA: 6s
82239488/267313385 [========>.....................] - ETA: 6s
83599360/267313385 [========>.....................] - ETA: 6s
85106688/267313385 [========>.....................] - ETA: 6s
86351872/267313385 [========>.....................] - ETA: 6s
87842816/267313385 [========>.....................] - ETA: 6s
89088000/267313385 [========>.....................] - ETA: 6s
90611712/267313385 [=========>....................] - ETA: 6s
91971584/267313385 [=========>....................] - ETA: 6s
93364224/267313385 [=========>....................] - ETA: 6s
94789632/267313385 [=========>....................] - ETA: 6s
96264192/267313385 [=========>....................] - ETA: 6s
97837056/267313385 [=========>....................] - ETA: 5s
99278848/267313385 [==========>...................] - ETA: 5s
100950016/267313385 [==========>...................] - ETA: 5s
102391808/267313385 [==========>...................] - ETA: 5s
103997440/267313385 [==========>...................] - ETA: 5s
105308160/267313385 [==========>...................] - ETA: 5s
106831872/267313385 [==========>...................] - ETA: 5s
108175360/267313385 [===========>..................] - ETA: 5s
109699072/267313385 [===========>..................] - ETA: 5s
111140864/267313385 [===========>..................] - ETA: 5s
112533504/267313385 [===========>..................] - ETA: 5s
114040832/267313385 [===========>..................] - ETA: 5s
115499008/267313385 [===========>..................] - ETA: 5s
117006336/267313385 [============>.................] - ETA: 5s
118431744/267313385 [============>.................] - ETA: 5s
120020992/267313385 [============>.................] - ETA: 5s
121454592/267313385 [============>.................] - ETA: 5s
122937344/267313385 [============>.................] - ETA: 5s
124379136/267313385 [============>.................] - ETA: 4s
125870080/267313385 [=============>................] - ETA: 4s
127295488/267313385 [=============>................] - ETA: 4s
128753664/267313385 [=============>................] - ETA: 4s
130277376/267313385 [=============>................] - ETA: 4s
131620864/267313385 [=============>................] - ETA: 4s
133128192/267313385 [=============>................] - ETA: 4s
134553600/267313385 [==============>...............] - ETA: 4s
136028160/267313385 [==============>...............] - ETA: 4s
137453568/267313385 [==============>...............] - ETA: 4s
138944512/267313385 [==============>...............] - ETA: 4s
140386304/267313385 [==============>...............] - ETA: 4s
141844480/267313385 [==============>...............] - ETA: 4s
143286272/267313385 [===============>..............] - ETA: 4s
144809984/267313385 [===============>..............] - ETA: 4s
146268160/267313385 [===============>..............] - ETA: 4s
147824640/267313385 [===============>..............] - ETA: 4s
149233664/267313385 [===============>..............] - ETA: 4s
150691840/267313385 [===============>..............] - ETA: 4s
152199168/267313385 [================>.............] - ETA: 4s
153739264/267313385 [================>.............] - ETA: 3s
155148288/267313385 [================>.............] - ETA: 3s
156737536/267313385 [================>.............] - ETA: 3s
158146560/267313385 [================>.............] - ETA: 3s
159703040/267313385 [================>.............] - ETA: 3s
161193984/267313385 [=================>............] - ETA: 3s
162799616/267313385 [=================>............] - ETA: 3s
164306944/267313385 [=================>............] - ETA: 3s
165879808/267313385 [=================>............] - ETA: 3s
167256064/267313385 [=================>............] - ETA: 3s
168927232/267313385 [=================>............] - ETA: 3s
170336256/267313385 [==================>...........] - ETA: 3s
171827200/267313385 [==================>...........] - ETA: 3s
173195264/267313385 [==================>...........] - ETA: 3s
174612480/267313385 [==================>...........] - ETA: 3s
176070656/267313385 [==================>...........] - ETA: 3s
177446912/267313385 [==================>...........] - ETA: 3s
178987008/267313385 [===================>..........] - ETA: 3s
180330496/267313385 [===================>..........] - ETA: 3s
181837824/267313385 [===================>..........] - ETA: 2s
183230464/267313385 [===================>..........] - ETA: 2s
184688640/267313385 [===================>..........] - ETA: 2s
186048512/267313385 [===================>..........] - ETA: 2s
187506688/267313385 [====================>.........] - ETA: 2s
188964864/267313385 [====================>.........] - ETA: 2s
190308352/267313385 [====================>.........] - ETA: 2s
191873024/267313385 [====================>.........] - ETA: 2s
193306624/267313385 [====================>.........] - ETA: 2s
194912256/267313385 [====================>.........] - ETA: 2s
196255744/267313385 [=====================>........] - ETA: 2s
196878336/267313385 [=====================>........] - ETA: 2s
197828608/267313385 [=====================>........] - ETA: 2s
199188480/267313385 [=====================>........] - ETA: 2s
200826880/267313385 [=====================>........] - ETA: 2s
202596352/267313385 [=====================>........] - ETA: 2s
203726848/267313385 [=====================>........] - ETA: 2s
205447168/267313385 [======================>.......] - ETA: 2s
206495744/267313385 [======================>.......] - ETA: 2s
208052224/267313385 [======================>.......] - ETA: 2s
209346560/267313385 [======================>.......] - ETA: 2s
210821120/267313385 [======================>.......] - ETA: 1s
212312064/267313385 [======================>.......] - ETA: 1s
213590016/267313385 [======================>.......] - ETA: 1s
215343104/267313385 [=======================>......] - ETA: 1s
216555520/267313385 [=======================>......] - ETA: 1s
218243072/267313385 [=======================>......] - ETA: 1s
219488256/267313385 [=======================>......] - ETA: 1s
221126656/267313385 [=======================>......] - ETA: 1s
222453760/267313385 [=======================>......] - ETA: 1s
223764480/267313385 [========================>.....] - ETA: 1s
225779712/267313385 [========================>.....] - ETA: 1s
228745216/267313385 [========================>.....] - ETA: 1s
230645760/267313385 [========================>.....] - ETA: 1s
234364928/267313385 [=========================>....] - ETA: 1s
237527040/267313385 [=========================>....] - ETA: 1s
238362624/267313385 [=========================>....] - ETA: 0s
242081792/267313385 [==========================>...] - ETA: 0s
245014528/267313385 [==========================>...] - ETA: 0s
246931456/267313385 [==========================>...] - ETA: 0s
247996416/267313385 [==========================>...] - ETA: 0s
249552896/267313385 [===========================>..] - ETA: 0s
250904576/267313385 [===========================>..] - ETA: 0s
252452864/267313385 [===========================>..] - ETA: 0s
253730816/267313385 [===========================>..] - ETA: 0s
255172608/267313385 [===========================>..] - ETA: 0s
256696320/267313385 [===========================>..] - ETA: 0s
258007040/267313385 [===========================>..] - ETA: 0s
259448832/267313385 [============================>.] - ETA: 0s
260726784/267313385 [============================>.] - ETA: 0s
262332416/267313385 [============================>.] - ETA: 0s
263921664/267313385 [============================>.] - ETA: 0s
265314304/267313385 [============================>.] - ETA: 0s
266870784/267313385 [============================>.] - ETA: 0s
267313385/267313385 [==============================] - 9s 0us/step
Download complete.
2. Load a pre-trained native Keras model
The model used in this example is AkidaUNet. It has an AkidaNet (0.5) backbone to extract features combined with a succession of separable transposed convolutional blocks to build an image segmentation map. A pre-trained floating point keras model is downloaded to save training time.
Note
The “transposed” convolutional feature is new in Akida 2.0.
The “separable transposed” operation is realized through the combination of a QuantizeML custom DepthwiseConv2DTranspose layer with a standard pointwise convolution.
The performance of the model is evaluated using both pixel accuracy and Binary IoU. The pixel accuracy describes how well the model can predict the segmentation mask pixel by pixel, while the Binary IoU (Intersection over Union) measures how much the predicted mask overlaps with the ground truth.
from akida_models.model_io import load_model
# Fetch the pre-trained floating point model file from the Brainchip data server
model_file = fetch_file(
    origin="https://data.brainchip.com/models/AkidaV2/akida_unet/akida_unet_portrait128.h5",
    fname="akida_unet_portrait128.h5",
    cache_subdir='models',
)

# Build the native Keras model from the downloaded file and print its layers
model_keras = load_model(model_file)
model_keras.summary()
Downloading data from https://data.brainchip.com/models/AkidaV2/akida_unet/akida_unet_portrait128.h5.
0/4493968 [..............................] - ETA: 0s
114688/4493968 [..............................] - ETA: 1s
737280/4493968 [===>..........................] - ETA: 0s
1441792/4493968 [========>.....................] - ETA: 0s
2326528/4493968 [==============>...............] - ETA: 0s
3268608/4493968 [====================>.........] - ETA: 0s
4194304/4493968 [==========================>...] - ETA: 0s
4493968/4493968 [==============================] - 0s 0us/step
Download complete.
Model: "akida_unet"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input (InputLayer) [(None, 128, 128, 3)] 0
rescaling (Rescaling) (None, 128, 128, 3) 0
conv_0 (Conv2D) (None, 64, 64, 16) 432
conv_0/BN (BatchNormalizati (None, 64, 64, 16) 64
on)
conv_0/relu (ReLU) (None, 64, 64, 16) 0
conv_1 (Conv2D) (None, 64, 64, 32) 4608
conv_1/BN (BatchNormalizati (None, 64, 64, 32) 128
on)
conv_1/relu (ReLU) (None, 64, 64, 32) 0
conv_2 (Conv2D) (None, 32, 32, 64) 18432
conv_2/BN (BatchNormalizati (None, 32, 32, 64) 256
on)
conv_2/relu (ReLU) (None, 32, 32, 64) 0
conv_3 (Conv2D) (None, 32, 32, 64) 36864
conv_3/BN (BatchNormalizati (None, 32, 32, 64) 256
on)
conv_3/relu (ReLU) (None, 32, 32, 64) 0
dw_separable_4 (DepthwiseCo (None, 16, 16, 64) 576
nv2D)
pw_separable_4 (Conv2D) (None, 16, 16, 128) 8192
pw_separable_4/BN (BatchNor (None, 16, 16, 128) 512
malization)
pw_separable_4/relu (ReLU) (None, 16, 16, 128) 0
dw_separable_5 (DepthwiseCo (None, 16, 16, 128) 1152
nv2D)
pw_separable_5 (Conv2D) (None, 16, 16, 128) 16384
pw_separable_5/BN (BatchNor (None, 16, 16, 128) 512
malization)
pw_separable_5/relu (ReLU) (None, 16, 16, 128) 0
dw_separable_6 (DepthwiseCo (None, 8, 8, 128) 1152
nv2D)
pw_separable_6 (Conv2D) (None, 8, 8, 256) 32768
pw_separable_6/BN (BatchNor (None, 8, 8, 256) 1024
malization)
pw_separable_6/relu (ReLU) (None, 8, 8, 256) 0
dw_separable_7 (DepthwiseCo (None, 8, 8, 256) 2304
nv2D)
pw_separable_7 (Conv2D) (None, 8, 8, 256) 65536
pw_separable_7/BN (BatchNor (None, 8, 8, 256) 1024
malization)
pw_separable_7/relu (ReLU) (None, 8, 8, 256) 0
dw_separable_8 (DepthwiseCo (None, 8, 8, 256) 2304
nv2D)
pw_separable_8 (Conv2D) (None, 8, 8, 256) 65536
pw_separable_8/BN (BatchNor (None, 8, 8, 256) 1024
malization)
pw_separable_8/relu (ReLU) (None, 8, 8, 256) 0
dw_separable_9 (DepthwiseCo (None, 8, 8, 256) 2304
nv2D)
pw_separable_9 (Conv2D) (None, 8, 8, 256) 65536
pw_separable_9/BN (BatchNor (None, 8, 8, 256) 1024
malization)
pw_separable_9/relu (ReLU) (None, 8, 8, 256) 0
dw_separable_10 (DepthwiseC (None, 8, 8, 256) 2304
onv2D)
pw_separable_10 (Conv2D) (None, 8, 8, 256) 65536
pw_separable_10/BN (BatchNo (None, 8, 8, 256) 1024
rmalization)
pw_separable_10/relu (ReLU) (None, 8, 8, 256) 0
dw_separable_11 (DepthwiseC (None, 8, 8, 256) 2304
onv2D)
pw_separable_11 (Conv2D) (None, 8, 8, 256) 65536
pw_separable_11/BN (BatchNo (None, 8, 8, 256) 1024
rmalization)
pw_separable_11/relu (ReLU) (None, 8, 8, 256) 0
dw_separable_12 (DepthwiseC (None, 4, 4, 256) 2304
onv2D)
pw_separable_12 (Conv2D) (None, 4, 4, 512) 131072
pw_separable_12/BN (BatchNo (None, 4, 4, 512) 2048
rmalization)
pw_separable_12/relu (ReLU) (None, 4, 4, 512) 0
dw_separable_13 (DepthwiseC (None, 4, 4, 512) 4608
onv2D)
pw_separable_13 (Conv2D) (None, 4, 4, 512) 262144
pw_separable_13/BN (BatchNo (None, 4, 4, 512) 2048
rmalization)
pw_separable_13/relu (ReLU) (None, 4, 4, 512) 0
dw_sepconv_t_0 (DepthwiseCo (None, 8, 8, 512) 5120
nv2DTranspose)
pw_sepconv_t_0 (Conv2D) (None, 8, 8, 256) 131328
pw_sepconv_t_0/BN (BatchNor (None, 8, 8, 256) 1024
malization)
pw_sepconv_t_0/relu (ReLU) (None, 8, 8, 256) 0
dropout (Dropout) (None, 8, 8, 256) 0
dw_sepconv_t_1 (DepthwiseCo (None, 16, 16, 256) 2560
nv2DTranspose)
pw_sepconv_t_1 (Conv2D) (None, 16, 16, 128) 32896
pw_sepconv_t_1/BN (BatchNor (None, 16, 16, 128) 512
malization)
pw_sepconv_t_1/relu (ReLU) (None, 16, 16, 128) 0
dropout_1 (Dropout) (None, 16, 16, 128) 0
dw_sepconv_t_2 (DepthwiseCo (None, 32, 32, 128) 1280
nv2DTranspose)
pw_sepconv_t_2 (Conv2D) (None, 32, 32, 64) 8256
pw_sepconv_t_2/BN (BatchNor (None, 32, 32, 64) 256
malization)
pw_sepconv_t_2/relu (ReLU) (None, 32, 32, 64) 0
dropout_2 (Dropout) (None, 32, 32, 64) 0
dw_sepconv_t_3 (DepthwiseCo (None, 64, 64, 64) 640
nv2DTranspose)
pw_sepconv_t_3 (Conv2D) (None, 64, 64, 32) 2080
pw_sepconv_t_3/BN (BatchNor (None, 64, 64, 32) 128
malization)
pw_sepconv_t_3/relu (ReLU) (None, 64, 64, 32) 0
dropout_3 (Dropout) (None, 64, 64, 32) 0
dw_sepconv_t_4 (DepthwiseCo (None, 128, 128, 32) 320
nv2DTranspose)
pw_sepconv_t_4 (Conv2D) (None, 128, 128, 16) 528
pw_sepconv_t_4/BN (BatchNor (None, 128, 128, 16) 64
malization)
pw_sepconv_t_4/relu (ReLU) (None, 128, 128, 16) 0
dropout_4 (Dropout) (None, 128, 128, 16) 0
head (Conv2D) (None, 128, 128, 1) 17
sigmoid_act (Activation) (None, 128, 128, 1) 0
=================================================================
Total params: 1,058,865
Trainable params: 1,051,889
Non-trainable params: 6,976
_________________________________________________________________
from keras.metrics import BinaryIoU
# The native Keras model has to be compiled before its metrics can be evaluated
model_keras.compile(
    metrics=[BinaryIoU(), 'accuracy'],
    loss='binary_crossentropy',
)

# Measure the float model: the loss is discarded, only the two metrics are reported
_, biou, acc = model_keras.evaluate(x_val, y_val, verbose=0, steps=steps)
print(f"Keras binary IoU / pixel accuracy: {biou:.4f} / {100*acc:.2f}%")
Keras binary IoU / pixel accuracy: 0.9324 / 96.62%
3. Load a pre-trained quantized Keras model
The next step is to quantize and potentially perform Quantization Aware Training (QAT) on the Keras model from the previous step. After the Keras model is quantized to 8-bits for all weights and activations, QAT is used to maintain the performance of the quantized model. Again, a pre-trained model is downloaded to save runtime.
from akida_models import akida_unet_portrait128_pretrained
# Load the pre-trained quantized model
# (the helper downloads the quantized weights, as shown by the log that follows)
model_quantized_keras = akida_unet_portrait128_pretrained()
# Print the layer-by-layer description of the quantized model
model_quantized_keras.summary()
Downloading data from https://data.brainchip.com/models/AkidaV2/akida_unet/akida_unet_portrait128_i8_w8_a8.h5.
0/4520400 [..............................] - ETA: 0s
114688/4520400 [..............................] - ETA: 2s
811008/4520400 [====>.........................] - ETA: 0s
2080768/4520400 [============>.................] - ETA: 0s
3620864/4520400 [=======================>......] - ETA: 0s
4014080/4520400 [=========================>....] - ETA: 0s
4520400/4520400 [==============================] - 0s 0us/step
Download complete.
Model: "akida_unet"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input (InputLayer) [(None, 128, 128, 3)] 0
rescaling (QuantizedRescali (None, 128, 128, 3) 0
ng)
conv_0 (QuantizedConv2D) (None, 64, 64, 16) 448
conv_0/relu (QuantizedReLU) (None, 64, 64, 16) 32
conv_1 (QuantizedConv2D) (None, 64, 64, 32) 4640
conv_1/relu (QuantizedReLU) (None, 64, 64, 32) 64
conv_2 (QuantizedConv2D) (None, 32, 32, 64) 18496
conv_2/relu (QuantizedReLU) (None, 32, 32, 64) 128
conv_3 (QuantizedConv2D) (None, 32, 32, 64) 36928
conv_3/relu (QuantizedReLU) (None, 32, 32, 64) 128
dw_separable_4 (QuantizedDe (None, 16, 16, 64) 704
pthwiseConv2D)
pw_separable_4 (QuantizedCo (None, 16, 16, 128) 8320
nv2D)
pw_separable_4/relu (Quanti (None, 16, 16, 128) 256
zedReLU)
dw_separable_5 (QuantizedDe (None, 16, 16, 128) 1408
pthwiseConv2D)
pw_separable_5 (QuantizedCo (None, 16, 16, 128) 16512
nv2D)
pw_separable_5/relu (Quanti (None, 16, 16, 128) 256
zedReLU)
dw_separable_6 (QuantizedDe (None, 8, 8, 128) 1408
pthwiseConv2D)
pw_separable_6 (QuantizedCo (None, 8, 8, 256) 33024
nv2D)
pw_separable_6/relu (Quanti (None, 8, 8, 256) 512
zedReLU)
dw_separable_7 (QuantizedDe (None, 8, 8, 256) 2816
pthwiseConv2D)
pw_separable_7 (QuantizedCo (None, 8, 8, 256) 65792
nv2D)
pw_separable_7/relu (Quanti (None, 8, 8, 256) 512
zedReLU)
dw_separable_8 (QuantizedDe (None, 8, 8, 256) 2816
pthwiseConv2D)
pw_separable_8 (QuantizedCo (None, 8, 8, 256) 65792
nv2D)
pw_separable_8/relu (Quanti (None, 8, 8, 256) 512
zedReLU)
dw_separable_9 (QuantizedDe (None, 8, 8, 256) 2816
pthwiseConv2D)
pw_separable_9 (QuantizedCo (None, 8, 8, 256) 65792
nv2D)
pw_separable_9/relu (Quanti (None, 8, 8, 256) 512
zedReLU)
dw_separable_10 (QuantizedD (None, 8, 8, 256) 2816
epthwiseConv2D)
pw_separable_10 (QuantizedC (None, 8, 8, 256) 65792
onv2D)
pw_separable_10/relu (Quant (None, 8, 8, 256) 512
izedReLU)
dw_separable_11 (QuantizedD (None, 8, 8, 256) 2816
epthwiseConv2D)
pw_separable_11 (QuantizedC (None, 8, 8, 256) 65792
onv2D)
pw_separable_11/relu (Quant (None, 8, 8, 256) 512
izedReLU)
dw_separable_12 (QuantizedD (None, 4, 4, 256) 2816
epthwiseConv2D)
pw_separable_12 (QuantizedC (None, 4, 4, 512) 131584
onv2D)
pw_separable_12/relu (Quant (None, 4, 4, 512) 1024
izedReLU)
dw_separable_13 (QuantizedD (None, 4, 4, 512) 5632
epthwiseConv2D)
pw_separable_13 (QuantizedC (None, 4, 4, 512) 262656
onv2D)
pw_separable_13/relu (Quant (None, 4, 4, 512) 1024
izedReLU)
dw_sepconv_t_0 (QuantizedDe (None, 8, 8, 512) 6144
pthwiseConv2DTranspose)
pw_sepconv_t_0 (QuantizedCo (None, 8, 8, 256) 131328
nv2D)
pw_sepconv_t_0/relu (Quanti (None, 8, 8, 256) 512
zedReLU)
dropout (QuantizedDropout) (None, 8, 8, 256) 0
dw_sepconv_t_1 (QuantizedDe (None, 16, 16, 256) 3072
pthwiseConv2DTranspose)
pw_sepconv_t_1 (QuantizedCo (None, 16, 16, 128) 32896
nv2D)
pw_sepconv_t_1/relu (Quanti (None, 16, 16, 128) 256
zedReLU)
dropout_1 (QuantizedDropout (None, 16, 16, 128) 0
)
dw_sepconv_t_2 (QuantizedDe (None, 32, 32, 128) 1536
pthwiseConv2DTranspose)
pw_sepconv_t_2 (QuantizedCo (None, 32, 32, 64) 8256
nv2D)
pw_sepconv_t_2/relu (Quanti (None, 32, 32, 64) 128
zedReLU)
dropout_2 (QuantizedDropout (None, 32, 32, 64) 0
)
dw_sepconv_t_3 (QuantizedDe (None, 64, 64, 64) 768
pthwiseConv2DTranspose)
pw_sepconv_t_3 (QuantizedCo (None, 64, 64, 32) 2080
nv2D)
pw_sepconv_t_3/relu (Quanti (None, 64, 64, 32) 64
zedReLU)
dropout_3 (QuantizedDropout (None, 64, 64, 32) 0
)
dw_sepconv_t_4 (QuantizedDe (None, 128, 128, 32) 384
pthwiseConv2DTranspose)
pw_sepconv_t_4 (QuantizedCo (None, 128, 128, 16) 528
nv2D)
pw_sepconv_t_4/relu (Quanti (None, 128, 128, 16) 32
zedReLU)
dropout_4 (QuantizedDropout (None, 128, 128, 16) 0
)
head (QuantizedConv2D) (None, 128, 128, 1) 17
head/dequantizer (Dequantiz (None, 128, 128, 1) 0
er)
sigmoid_act (Activation) (None, 128, 128, 1) 0
=================================================================
Total params: 1,061,601
Trainable params: 1,047,905
Non-trainable params: 13,696
_________________________________________________________________
# The quantized Keras model has to be compiled before its metrics can be evaluated
model_quantized_keras.compile(
    metrics=[BinaryIoU(), 'accuracy'],
    loss='binary_crossentropy',
)

# Measure the quantized model: the loss is discarded, only the two metrics are reported
_, biou, acc = model_quantized_keras.evaluate(x_val, y_val, verbose=0, steps=steps)
print(f"Keras quantized binary IoU / pixel accuracy: {biou:.4f} / {100*acc:.2f}%")
Keras quantized binary IoU / pixel accuracy: 0.9319 / 96.59%
4. Conversion to Akida
Finally, the quantized Keras model from the previous step is converted into an Akida model and its performance is evaluated. Note that the original performance of the keras floating point model is maintained throughout the conversion process in this example.
from cnn2snn import convert
# Convert the model
# convert() takes the quantized Keras model and returns the corresponding Akida model
model_akida = convert(model_quantized_keras)
# Print the Akida model summary (input/output shapes and layer list)
model_akida.summary()
/usr/local/lib/python3.8/dist-packages/cnn2snn/quantizeml/blocks.py:160: UserWarning: Conversion stops at layer head because of a dequantizer. The end of the model is ignored:
___________________________________________________
Layer (type)
===================================================
sigmoid_act (Activation)
===================================================
warnings.warn("Conversion stops" + stop_layer_msg + " because of a dequantizer. "
Model Summary
_________________________________________________
Input shape Output shape Sequences Layers
=================================================
[128, 128, 3] [128, 128, 1] 1 36
_________________________________________________
_____________________________________________________________________________
Layer (type) Output shape Kernel shape
=================== SW/conv_0-head/dequantizer (Software) ===================
conv_0 (InputConv2D) [64, 64, 16] (3, 3, 3, 16)
_____________________________________________________________________________
conv_1 (Conv2D) [64, 64, 32] (3, 3, 16, 32)
_____________________________________________________________________________
conv_2 (Conv2D) [32, 32, 64] (3, 3, 32, 64)
_____________________________________________________________________________
conv_3 (Conv2D) [32, 32, 64] (3, 3, 64, 64)
_____________________________________________________________________________
dw_separable_4 (DepthwiseConv2D) [16, 16, 64] (3, 3, 64, 1)
_____________________________________________________________________________
pw_separable_4 (Conv2D) [16, 16, 128] (1, 1, 64, 128)
_____________________________________________________________________________
dw_separable_5 (DepthwiseConv2D) [16, 16, 128] (3, 3, 128, 1)
_____________________________________________________________________________
pw_separable_5 (Conv2D) [16, 16, 128] (1, 1, 128, 128)
_____________________________________________________________________________
dw_separable_6 (DepthwiseConv2D) [8, 8, 128] (3, 3, 128, 1)
_____________________________________________________________________________
pw_separable_6 (Conv2D) [8, 8, 256] (1, 1, 128, 256)
_____________________________________________________________________________
dw_separable_7 (DepthwiseConv2D) [8, 8, 256] (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_7 (Conv2D) [8, 8, 256] (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_8 (DepthwiseConv2D) [8, 8, 256] (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_8 (Conv2D) [8, 8, 256] (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_9 (DepthwiseConv2D) [8, 8, 256] (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_9 (Conv2D) [8, 8, 256] (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_10 (DepthwiseConv2D) [8, 8, 256] (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_10 (Conv2D) [8, 8, 256] (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_11 (DepthwiseConv2D) [8, 8, 256] (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_11 (Conv2D) [8, 8, 256] (1, 1, 256, 256)
_____________________________________________________________________________
dw_separable_12 (DepthwiseConv2D) [4, 4, 256] (3, 3, 256, 1)
_____________________________________________________________________________
pw_separable_12 (Conv2D) [4, 4, 512] (1, 1, 256, 512)
_____________________________________________________________________________
dw_separable_13 (DepthwiseConv2D) [4, 4, 512] (3, 3, 512, 1)
_____________________________________________________________________________
pw_separable_13 (Conv2D) [4, 4, 512] (1, 1, 512, 512)
_____________________________________________________________________________
dw_sepconv_t_0 (DepthwiseConv2DTranspose) [8, 8, 512] (3, 3, 512, 1)
_____________________________________________________________________________
pw_sepconv_t_0 (Conv2D) [8, 8, 256] (1, 1, 512, 256)
_____________________________________________________________________________
dw_sepconv_t_1 (DepthwiseConv2DTranspose) [16, 16, 256] (3, 3, 256, 1)
_____________________________________________________________________________
pw_sepconv_t_1 (Conv2D) [16, 16, 128] (1, 1, 256, 128)
_____________________________________________________________________________
dw_sepconv_t_2 (DepthwiseConv2DTranspose) [32, 32, 128] (3, 3, 128, 1)
_____________________________________________________________________________
pw_sepconv_t_2 (Conv2D) [32, 32, 64] (1, 1, 128, 64)
_____________________________________________________________________________
dw_sepconv_t_3 (DepthwiseConv2DTranspose) [64, 64, 64] (3, 3, 64, 1)
_____________________________________________________________________________
pw_sepconv_t_3 (Conv2D) [64, 64, 32] (1, 1, 64, 32)
_____________________________________________________________________________
dw_sepconv_t_4 (DepthwiseConv2DTranspose) [128, 128, 32] (3, 3, 32, 1)
_____________________________________________________________________________
pw_sepconv_t_4 (Conv2D) [128, 128, 16] (1, 1, 32, 16)
_____________________________________________________________________________
head (Conv2D) [128, 128, 1] (1, 1, 16, 1)
_____________________________________________________________________________
head/dequantizer (Dequantizer) [128, 128, 1] N/A
_____________________________________________________________________________
import tensorflow as tf
# Check Akida model performance
# Run inference batch by batch, keeping each batch of output potentials in a list
# so that a single concatenation is done at the end (concatenating inside the
# loop re-copies everything accumulated so far at every step).
pots_batches = []
for s in range(steps):
    batch = x_val[s * batch_size: (s + 1) * batch_size, :]
    pots_batches.append(model_akida.predict(batch.astype('uint8')))
pots = np.concatenate(pots_batches)

# Ground truth masks for exactly the samples that were processed above
labels = y_val[:steps * batch_size]

# The converted model stops before the final sigmoid activation (see the
# conversion warning), so apply it here to turn potentials into probabilities
preds = tf.keras.activations.sigmoid(pots)

# Binary IoU over both classes, with predictions thresholded at 0.5
m_binary_iou = tf.keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5)
m_binary_iou.update_state(labels, preds)
binary_iou = m_binary_iou.result().numpy()

# Pixel accuracy: compare the labels with the thresholded predictions
m_accuracy = tf.keras.metrics.Accuracy()
m_accuracy.update_state(labels, preds > 0.5)
accuracy = m_accuracy.result().numpy()
print(f"Akida binary IoU / pixel accuracy: {binary_iou:.4f} / {100*accuracy:.2f}%")

# For non-regression purpose
assert binary_iou > 0.9
Akida binary IoU / pixel accuracy: 0.9308 / 96.59%
5. Segment a single image
For visualization of the person segmentation performed by the Akida model, display the segmentation of a single image produced by the Akida model alongside the one produced by the original floating point model and the ground truth segmentation.
import matplotlib.pyplot as plt
# Segment a single random image and display Keras and Akida outputs
# Add a leading batch dimension to the selected sample
sample = np.expand_dims(x_val[id, :], 0)
keras_out = model_keras(sample)
# The Akida model output is passed through a sigmoid before being displayed
akida_out = tf.keras.activations.sigmoid(model_akida.forward(sample.astype('uint8')))
# Three panels side by side: Keras, Akida and expected segmentation
fig, axs = plt.subplots(1, 3, constrained_layout=True)
# Each panel shows the input image multiplied by the corresponding mask
axs[0].imshow(keras_out[0] * sample[0] / 255.)
axs[0].set_title('Keras segmentation', fontsize=10)
axs[0].axis('off')
axs[1].imshow(akida_out[0] * sample[0] / 255.)
axs[1].set_title('Akida segmentation', fontsize=10)
axs[1].axis('off')
# Ground truth mask applied to the same image
axs[2].imshow(y_val[id] * sample[0] / 255.)
axs[2].set_title('Expected segmentation', fontsize=10)
axs[2].axis('off')
plt.show()
Total running time of the script: (1 minutes 46.168 seconds)