Age estimation (regression) example

This tutorial shows that an Akida-compatible model reaches accuracy comparable to a traditional Keras model on an age estimation task.

It uses the UTKFace dataset, which contains face images labeled with age, to show how well the Akida-compatible model predicts a person's age from their facial features.

1. Load the UTKFace Dataset

The UTKFace dataset contains more than 20,000 diverse face images of people aged 0 to 116 years, annotated with age, gender, and ethnicity. It is useful for tasks such as age estimation and face detection.

Load the dataset from the BrainChip data server using the load_data helper, which decodes the JPEG images and loads the associated labels.

from akida_models.utk_face.preprocessing import load_data

# Load the dataset
x_train, y_train, x_test, y_test = load_data()
Downloading data from https://data.brainchip.com/dataset-mirror/utk_face/UTKFace_preprocessed.tar.gz.

Akida models accept only uint8 tensors as inputs, so the Akida performance evaluation uses the raw data cast to uint8.

# For Akida inference, use uint8 raw data
x_test_akida = x_test.astype('uint8')
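Note that astype('uint8') does not clip: integer values outside [0, 255] wrap modulo 256 and fractional values are truncated. The sketch below is a defensive check before casting, assuming the raw test images already hold pixel values in [0, 255]; to_uint8_checked is a hypothetical helper, not part of akida_models.

```python
import numpy as np

def to_uint8_checked(images):
    """Cast raw pixel data to uint8, refusing values that would wrap around.

    Hypothetical helper (not part of akida_models); assumes the input holds
    raw pixel values in [0, 255].
    """
    images = np.asarray(images)
    if images.size and (images.min() < 0 or images.max() > 255):
        raise ValueError("values outside [0, 255] would wrap when cast to uint8")
    return images.astype(np.uint8)

# Integer-to-uint8 casts are modulo 256: 256 becomes 0 and -1 becomes 255.
wrapped = np.array([256, -1], dtype=np.int32).astype(np.uint8)

# In-range data passes through unchanged.
x_u8 = to_uint8_checked(np.array([[0.0, 127.0, 255.0]]))
```

With the tutorial arrays, x_test_akida = to_uint8_checked(x_test) would behave like the plain cast when the data is in range and fail loudly otherwise.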

2. Load a pre-trained native Keras model

The model is a simplified architecture inspired by VGG. It consists of a succession of convolutional and pooling layers and ends with two dense layers that output a single value: the estimated age.

The performance of the model is evaluated using the Mean Absolute Error (MAE). The MAE, a common metric for regression problems, is the average of the absolute differences between the target values and the predictions. It is a linear score: all individual differences are weighted equally in the average.
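The MAE definition above can be written in a few lines of NumPy. This is a minimal sketch for illustration; mean_absolute_error is a hypothetical helper name, and the ages below are made-up values, not UTKFace samples.

```python
import numpy as np

def mean_absolute_error(y_true, y_pred):
    """MAE = (1/N) * sum(|y_true - y_pred|); every sample weighs the same."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must contain the same number of values")
    return float(np.mean(np.abs(y_true - y_pred)))

# Three made-up ages and predictions: the absolute errors are 2, 3 and 0 years.
ages = [30, 40, 50]
predicted = [32, 37, 50]
mae = mean_absolute_error(ages, predicted)  # (2 + 3 + 0) / 3, about 1.667 years
```

Because the MAE is expressed in the unit of the target, an MAE of about 5.9 on UTKFace means the predictions are off by roughly 5.9 years on average.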

from tensorflow.keras.utils import get_file
from tensorflow.keras.models import load_model

# Retrieve the model file from the BrainChip data server
model_file = get_file("vgg_utk_face.h5",
                      "https://data.brainchip.com/models/AkidaV2/vgg/vgg_utk_face.h5",
                      cache_subdir='models')

# Load the native Keras pre-trained model
model_keras = load_model(model_file)
model_keras.summary()
Downloading data from https://data.brainchip.com/models/AkidaV2/vgg/vgg_utk_face.h5

Model: "vgg_utk_face"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 32, 32, 3)]       0

 rescaling (Rescaling)       (None, 32, 32, 3)         0

 conv_0 (Conv2D)             (None, 30, 30, 32)        864

 conv_0/BN (BatchNormalizati  (None, 30, 30, 32)       128
 on)

 conv_0/relu (ReLU)          (None, 30, 30, 32)        0

 conv_1 (Conv2D)             (None, 30, 30, 32)        9216

 conv_1/maxpool (MaxPooling2  (None, 15, 15, 32)       0
 D)

 conv_1/BN (BatchNormalizati  (None, 15, 15, 32)       128
 on)

 conv_1/relu (ReLU)          (None, 15, 15, 32)        0

 dropout_3 (Dropout)         (None, 15, 15, 32)        0

 conv_2 (Conv2D)             (None, 15, 15, 64)        18432

 conv_2/BN (BatchNormalizati  (None, 15, 15, 64)       256
 on)

 conv_2/relu (ReLU)          (None, 15, 15, 64)        0

 conv_3 (Conv2D)             (None, 15, 15, 64)        36864

 conv_3/maxpool (MaxPooling2  (None, 8, 8, 64)         0
 D)

 conv_3/BN (BatchNormalizati  (None, 8, 8, 64)         256
 on)

 conv_3/relu (ReLU)          (None, 8, 8, 64)          0

 dropout_4 (Dropout)         (None, 8, 8, 64)          0

 conv_4 (Conv2D)             (None, 8, 8, 84)          48384

 conv_4/BN (BatchNormalizati  (None, 8, 8, 84)         336
 on)

 conv_4/relu (ReLU)          (None, 8, 8, 84)          0

 dropout_5 (Dropout)         (None, 8, 8, 84)          0

 flatten (Flatten)           (None, 5376)              0

 dense_1 (Dense)             (None, 64)                344064

 dense_1/BN (BatchNormalizat  (None, 64)               256
 ion)

 dense_1/relu (ReLU)         (None, 64)                0

 dense_2 (Dense)             (None, 1)                 65

=================================================================
Total params: 459,249
Trainable params: 458,569
Non-trainable params: 680
_________________________________________________________________
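The "Param #" column of the summary above can be cross-checked by hand. The sketch below assumes, based only on the summary, that the convolutional and dense layers followed by BatchNormalization have no bias, that dense_2 has a bias, and that each BatchNormalization layer stores four values per channel (trainable gamma and beta, non-trainable moving mean and variance).

```python
def conv_params(k, c_in, c_out, bias=False):
    # k x k kernel for every (input channel, output channel) pair, plus optional bias
    return k * k * c_in * c_out + (c_out if bias else 0)

def dense_params(n_in, n_out, bias=False):
    return n_in * n_out + (n_out if bias else 0)

def bn_params(channels):
    # gamma, beta, moving mean, moving variance
    return 4 * channels

# (layer name, parameter count) in the order listed by model_keras.summary()
layers = [
    ("conv_0", conv_params(3, 3, 32)), ("conv_0/BN", bn_params(32)),
    ("conv_1", conv_params(3, 32, 32)), ("conv_1/BN", bn_params(32)),
    ("conv_2", conv_params(3, 32, 64)), ("conv_2/BN", bn_params(64)),
    ("conv_3", conv_params(3, 64, 64)), ("conv_3/BN", bn_params(64)),
    ("conv_4", conv_params(3, 64, 84)), ("conv_4/BN", bn_params(84)),
    # 8 * 8 * 84 = 5376 flattened features feed dense_1
    ("dense_1", dense_params(8 * 8 * 84, 64)), ("dense_1/BN", bn_params(64)),
    ("dense_2", dense_params(64, 1, bias=True)),
]

total = sum(count for _, count in layers)
# Moving mean and variance are the non-trainable half of each BN layer.
non_trainable = sum(count // 2 for name, count in layers if name.endswith("/BN"))
trainable = total - non_trainable
```

The dense_1 layer alone accounts for about 75% of the parameters, which is typical when a flattened feature map feeds a dense layer.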
# Compile the native Keras model (required to evaluate the MAE)
model_keras.compile(optimizer='Adam', loss='mae')

# Check Keras model performance
mae_keras = model_keras.evaluate(x_test, y_test, verbose=0)

print("Keras MAE: {0:.4f}".format(mae_keras))
Keras MAE: 5.8671

3. Load a pre-trained quantized Keras model

The native Keras model above has been quantized and then fine-tuned with quantization-aware training (QAT). The first convolutional layer uses 8-bit weights, the other layers use 4-bit weights, and all activations are quantized to 4 bits.
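To give an intuition of what 8-bit versus 4-bit weights mean, the sketch below applies a generic per-tensor symmetric uniform quantization to a random kernel. This is an assumed, textbook scheme for illustration only; the scheme actually used to produce the pre-trained model (scales, rounding, per-axis handling) may differ, and quantize_symmetric is a hypothetical helper.

```python
import numpy as np

def quantize_symmetric(weights, bits):
    """Generic per-tensor symmetric uniform quantization (assumed scheme).

    Returns integer codes in [-(2**(bits-1) - 1), 2**(bits-1) - 1] and the
    float scale such that weights are approximately codes * scale.
    """
    w = np.asarray(weights, dtype=np.float64)
    q_max = 2 ** (bits - 1) - 1  # 7 for 4-bit, 127 for 8-bit
    max_abs = np.max(np.abs(w))
    scale = max_abs / q_max if max_abs > 0 else 1.0
    codes = np.clip(np.round(w / scale), -q_max, q_max).astype(np.int32)
    return codes, scale

rng = np.random.default_rng(0)
w = rng.normal(size=(3, 3, 32, 32))  # random stand-in for a conv kernel
codes4, scale4 = quantize_symmetric(w, bits=4)
codes8, scale8 = quantize_symmetric(w, bits=8)

# Worst-case reconstruction error is at most half a quantization step.
err4 = np.max(np.abs(w - codes4 * scale4))
err8 = np.max(np.abs(w - codes8 * scale8))
```

With only 15 distinct levels, 4-bit weights carry a much larger rounding error than 8-bit ones, which is why fine-tuning after quantization helps recover accuracy.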

from akida_models import vgg_utk_face_pretrained

# Load the pre-trained quantized model
model_quantized_keras = vgg_utk_face_pretrained()
model_quantized_keras.summary()
Downloading data from https://data.brainchip.com/models/AkidaV2/vgg/vgg_utk_face_i8_w4_a4.h5.

Model: "vgg_utk_face"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input (InputLayer)          [(None, 32, 32, 3)]       0

 rescaling (QuantizedRescali  (None, 32, 32, 3)        0
 ng)

 conv_0 (QuantizedConv2D)    (None, 30, 30, 32)        896

 conv_0/relu (QuantizedReLU)  (None, 30, 30, 32)       64

 conv_1 (QuantizedConv2D)    (None, 30, 30, 32)        9248

 conv_1/maxpool (QuantizedMa  (None, 15, 15, 32)       0
 xPool2D)

 conv_1/relu (QuantizedReLU)  (None, 15, 15, 32)       64

 dropout_3 (QuantizedDropout  (None, 15, 15, 32)       0
 )

 conv_2 (QuantizedConv2D)    (None, 15, 15, 64)        18496

 conv_2/relu (QuantizedReLU)  (None, 15, 15, 64)       128

 conv_3 (QuantizedConv2D)    (None, 15, 15, 64)        36928

 conv_3/maxpool (QuantizedMa  (None, 8, 8, 64)         0
 xPool2D)

 conv_3/relu (QuantizedReLU)  (None, 8, 8, 64)         128

 dropout_4 (QuantizedDropout  (None, 8, 8, 64)         0
 )

 conv_4 (QuantizedConv2D)    (None, 8, 8, 84)          48468

 conv_4/relu (QuantizedReLU)  (None, 8, 8, 84)         168

 dropout_5 (QuantizedDropout  (None, 8, 8, 84)         0
 )

 flatten (QuantizedFlatten)  (None, 5376)              0

 dense_1 (QuantizedDense)    (None, 64)                344128

 dense_1/relu (QuantizedReLU  (None, 64)               2
 )

 dense_2 (QuantizedDense)    (None, 1)                 65

 dequantizer (Dequantizer)   (None, 1)                 0

=================================================================
Total params: 458,783
Trainable params: 458,229
Non-trainable params: 554
_________________________________________________________________
# Compile the quantized Keras model (required to evaluate the MAE)
model_quantized_keras.compile(optimizer='Adam', loss='mae')

# Check quantized Keras model performance
mae_quant = model_quantized_keras.evaluate(x_test, y_test, verbose=0)

print("Keras MAE: {0:.4f}".format(mae_quant))
Keras MAE: 5.8975

4. Conversion to Akida

The quantized Keras model is now converted into an Akida model. After conversion, we evaluate the performance on the UTKFace dataset.

from cnn2snn import convert

# Convert the model
model_akida = convert(model_quantized_keras)
model_akida.summary()
                Model Summary
______________________________________________
Input shape  Output shape  Sequences  Layers
==============================================
[32, 32, 3]  [1, 1, 1]     1          8
______________________________________________

_________________________________________________________
Layer (type)               Output shape  Kernel shape

============ SW/conv_0-dequantizer (Software) ===========

conv_0 (InputConv2D)       [30, 30, 32]  (3, 3, 3, 32)
_________________________________________________________
conv_1 (Conv2D)            [15, 15, 32]  (3, 3, 32, 32)
_________________________________________________________
conv_2 (Conv2D)            [15, 15, 64]  (3, 3, 32, 64)
_________________________________________________________
conv_3 (Conv2D)            [8, 8, 64]    (3, 3, 64, 64)
_________________________________________________________
conv_4 (Conv2D)            [8, 8, 84]    (3, 3, 64, 84)
_________________________________________________________
dense_1 (Dense2D)          [1, 1, 64]    (5376, 64)
_________________________________________________________
dense_2 (Dense2D)          [1, 1, 1]     (64, 1)
_________________________________________________________
dequantizer (Dequantizer)  [1, 1, 1]     N/A
_________________________________________________________
import numpy as np

# Check Akida model performance
y_akida = model_akida.predict(x_test_akida)

# Compute and display the MAE
mae_akida = np.sum(np.abs(y_test.squeeze() - y_akida.squeeze())) / len(y_test)
print("Akida MAE: {0:.4f}".format(mae_akida))

# Sanity check: the Akida MAE should stay close to the native Keras MAE
assert abs(mae_keras - mae_akida) < 0.5
Akida MAE: 6.0405
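Comparing two MAE values only checks aggregate accuracy. A stricter check compares the two models' outputs sample by sample. The sketch below uses a hypothetical helper, compare_predictions, on synthetic arrays; with the tutorial objects one would pass model_keras.predict(x_test) and y_akida instead.

```python
import numpy as np

def compare_predictions(y_ref, y_other):
    """Summarize how far one model's outputs drift from a reference model's."""
    diff = np.abs(np.asarray(y_ref, dtype=np.float64).ravel()
                  - np.asarray(y_other, dtype=np.float64).ravel())
    return {
        "mean_abs_diff": float(diff.mean()),
        "max_abs_diff": float(diff.max()),
        "within_1_year": float(np.mean(diff <= 1.0)),  # fraction of samples
    }

# Synthetic stand-ins for Keras and Akida age predictions, shape (N, 1).
keras_like = np.array([[30.0], [20.0], [58.0], [4.5]])
akida_like = np.array([[30.5], [21.5], [58.0], [4.0]])
stats = compare_predictions(keras_like, akida_like)
```

A small mean difference with an occasional large maximum would point to a few inputs where quantization noticeably changes the output, something the aggregate MAE alone cannot reveal.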

5. Estimate age on a single image

Select a random image from the test set for age estimation. Print the Keras model's age prediction obtained with model_keras.predict(), the Akida model's estimated age, and the actual age associated with the image.

import matplotlib.pyplot as plt

# Estimate age on a random single image and display Keras and Akida outputs
# randint excludes its upper bound, so len(y_test) keeps the index in range
idx = np.random.randint(0, len(y_test))
age_keras = model_keras.predict(x_test[idx:idx + 1])

plt.imshow(x_test_akida[idx], interpolation='bicubic')
plt.xticks([]), plt.yticks([])
plt.show()

print("Keras estimated age: {0:.1f}".format(age_keras.squeeze()))
print("Akida estimated age: {0:.1f}".format(y_akida[idx].squeeze()))
print(f"Actual age: {y_test[idx].squeeze()}")
1/1 [==============================] - 0s 98ms/step
Keras estimated age: 36.7
Akida estimated age: 36.5
Actual age: 35
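To inspect several test images instead of one, and get the same selection on every run, a seeded NumPy Generator can draw distinct indices. This is a small sketch; pick_test_indices is a hypothetical helper, and the sizes below are arbitrary.

```python
import numpy as np

def pick_test_indices(n_samples, k, seed=None):
    """Draw k distinct indices in [0, n_samples) from a seeded generator."""
    rng = np.random.default_rng(seed)
    # choice() with an int samples from np.arange(n_samples), so every index is valid.
    return rng.choice(n_samples, size=k, replace=False)

indices = pick_test_indices(n_samples=100, k=5, seed=42)
```

With the tutorial arrays, looping over pick_test_indices(len(y_test), 5, seed=0) and printing y_akida[i].squeeze() next to y_test[i].squeeze() gives a reproducible side-by-side view.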

Total running time of the script: (0 minutes 26.775 seconds)