Source code for akida_models.transformers.model_deit

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2022 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
DeiT model definition.
Inspired from https://github.com/faustomorales/vit-keras/blob/master/vit_keras/vit.py.
"""
import keras

from quantizeml.layers import AddPositionEmbs, ClassToken, Add, ExtractToken
from quantizeml.models import load_model

from .model_vit import CONFIG_TI, CONFIG_S, CONFIG_B
from ..imagenet.imagenet_utils import IMAGENET_MEAN, IMAGENET_STD
from ..layer_blocks import norm_to_layer, transformer_block
from ..utils import fetch_file
from ..model_io import get_model_path


def deit_imagenet(input_shape, num_blocks, hidden_size, num_heads, name, mlp_dim, patch_size=16,
                  classes=1000, dropout=0.1, include_top=True, distilled=False, norm='LN',
                  last_norm='LN', softmax='softmax', act="GeLU"):
    """Build a DeiT model.

    The graph is: input -> per-channel rescaling -> patch embedding (strided convolution) ->
    flattening of the patch grid -> optional distillation token -> class token -> position
    embedding -> ``num_blocks`` transformer blocks -> optional classification head(s).

    Args:
        input_shape (tuple): image shape tuple
        num_blocks (int): the number of transformer blocks to use.
        hidden_size (int): the number of filters to use
        num_heads (int): the number of transformer heads
        name (str): the model name
        mlp_dim (int): the number of dimensions for the MLP output in the transformers.
        patch_size (int, optional): the size of each patch (must fit evenly in image size).
            Defaults to 16.
        classes (int, optional): number of classes to classify images into, only to be
            specified if `include_top` is True. Defaults to 1000.
        dropout (float, optional): fraction of the units to drop for dense layers.
            Defaults to 0.1.
        include_top (bool, optional): whether to include the final classifier head.
            If False, the output will correspond to that of the transformer.
            Defaults to True.
        distilled (bool, optional): build the model with an additional distillation token
            (and, when `include_top` is True, a second head averaged with the class head).
            Defaults to False.
        norm (str, optional): string with values in ['LN', 'GN1', 'BN', 'LMN'] that allows to
            choose from LayerNormalization, GroupNormalization(groups=1, ...),
            BatchNormalization or LayerMadNormalization layers respectively in the model.
            Defaults to 'LN'.
        last_norm (str, optional): string with values in ['LN', 'BN'] that allows to choose
            from LayerNormalization or BatchNormalization in the classifier network.
            Defaults to 'LN'.
        softmax (str, optional): string with values in ['softmax', 'softmax2'] that allows to
            choose between softmax and softmax2 in MHA. Defaults to 'softmax'.
        act (str, optional): string with values in ['GeLU', 'ReLUx', 'swish'] that allows to
            choose from GeLU, ReLUx or swish activation in MLP block. Defaults to 'GeLU'.

    Returns:
        keras.Model: the requested model, carrying an extra `isdistilled` attribute set to
        the value of `distilled`.

    Raises:
        AssertionError: if the two first dimensions of `input_shape` are not multiples of
            `patch_size`.
        NotImplementedError: if `last_norm` is not one of 'LN' or 'BN'.
    """
    assert ((input_shape[0] % patch_size == 0) and (input_shape[1] % patch_size == 0)), \
        "image size must be a multiple of patch_size"

    if last_norm not in ('LN', 'BN'):
        # Report the offending argument itself (last_norm, not norm).
        raise NotImplementedError("last_norm should be in ['LN', 'BN'] "
                                  f"but received {last_norm}.")

    # Normalize image adding rescaling layer: y = x * scale + offset is the per-channel
    # equivalent of (x / 255 - mean) / std.
    x = keras.layers.Input(shape=input_shape, name="input")
    scale = [1.0 / 255 / std for std in IMAGENET_STD]
    offset = [-mean / std for mean, std in zip(IMAGENET_MEAN, IMAGENET_STD)]
    y = keras.layers.Rescaling(scale=scale, offset=offset, name="Rescale")(x)

    # Build model: non-overlapping patches are embedded by a convolution whose kernel and
    # stride both equal patch_size.
    y = keras.layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding="valid",
        name="Embedding",
        kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
        bias_initializer="zeros",
    )(y)
    # Flatten the 2D grid of patches into a sequence of hidden_size vectors.
    y = keras.layers.Reshape((y.shape[1] * y.shape[2], hidden_size))(y)

    # The distillation token is added before the class token; the heads below read the class
    # token at index 0 and the distillation token at index 1.
    if distilled:
        y = ClassToken(name="DistToken")(y)
    y = ClassToken(name="ClassToken")(y)
    y = AddPositionEmbs(name="Transformer/PosEmbed")(y)

    for n in range(num_blocks):
        # Only the first value returned by transformer_block is propagated.
        y, _ = transformer_block(
            y,
            num_heads=num_heads,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            dropout=dropout,
            name=f"Transformer/EncoderBlock_{n}",
            norm=norm,
            softmax=softmax,
            mlp_act=act,
        )

    # Include classification head
    if include_top:
        yt = norm_to_layer(last_norm)(epsilon=1e-6, name="Transformer/EncoderNorm")(y)
        y = ExtractToken(token=0, name="ExtractToken")(yt)
        y = keras.layers.Dense(classes, name="Head")(y)
        if distilled:
            yd = ExtractToken(token=1, name="ExtractToken_Dist")(yt)
            yd = keras.layers.Dense(classes, name="DistHead")(yd)
            # Combine class and distillation predictions (averaged).
            y = Add(name="Add", average=True)([y, yd])

    # Add distilled flag
    model = keras.models.Model(inputs=x, outputs=y, name=name)
    model.isdistilled = distilled
    return model


def deit_ti16(input_shape=(224, 224, 3), classes=1000, distilled=False, norm='LN',
              last_norm='LN', softmax='softmax', act='GeLU', include_top=True):
    """Build DeiT-Tiny.

    Forwards the given options to `deit_imagenet`, completed with the `CONFIG_TI`
    architecture settings, under the model name "deit-tiny".

    Args:
        input_shape (tuple, optional): input shape. Defaults to (224, 224, 3).
        classes (int, optional): number of classes. Defaults to 1000.
        distilled (bool, optional): build model appending a distilled token.
            Defaults to False.
        norm (str, optional): string with values in ['LN', 'GN1', 'BN', 'LMN'] that allows to
            choose from LayerNormalization, GroupNormalization(groups=1, ...),
            BatchNormalization or LayerMadNormalization layers respectively in the model.
            Defaults to 'LN'.
        last_norm (str, optional): string with values in ['LN', 'BN'] that allows to choose
            from LayerNormalization or BatchNormalization in the classifier network.
            Defaults to 'LN'.
        softmax (str, optional): string with values in ['softmax', 'softmax2'] that allows to
            choose between softmax and softmax2 in attention block. Defaults to 'softmax'.
        act (str, optional): string with values in ['GeLU', 'ReLUx', 'swish'] that allows to
            choose from GeLU, ReLUx or swish activation inside MLP. Defaults to 'GeLU'.
        include_top (bool, optional): whether to include the final classifier network.
            Defaults to True.

    Returns:
        keras.Model: the requested model
    """
    # Caller-tunable options, kept apart from the fixed architecture configuration.
    options = dict(
        name="deit-tiny",
        input_shape=input_shape,
        classes=classes,
        distilled=distilled,
        include_top=include_top,
        norm=norm,
        last_norm=last_norm,
        softmax=softmax,
        act=act,
    )
    return deit_imagenet(**options, **CONFIG_TI)


def bc_deit_ti16(input_shape=(224, 224, 3), classes=1000, distilled=False, include_top=True,
                 num_blocks=12):
    """Build DeiT-Tiny, changing all LN by LMN, using softmax2 and ReLU8.

    The classifier normalization is set to 'BN', and the `num_blocks` entry of a copy of
    `CONFIG_TI` is overridden with the requested value.

    Args:
        input_shape (tuple, optional): input shape. Defaults to (224, 224, 3).
        classes (int, optional): number of classes. Defaults to 1000.
        distilled (bool, optional): build model appending a distilled token.
            Defaults to False.
        include_top (bool, optional): whether to include the final classifier network.
            Defaults to True.
        num_blocks (int, optional): the number of transformer blocks to use.
            Defaults to 12.

    Returns:
        keras.Model: the requested model
    """
    # Work on a private copy so the shared CONFIG_TI is never modified.
    architecture = CONFIG_TI.copy()
    architecture.update(num_blocks=num_blocks)

    # Fixed layer choices for this variant.
    variant = dict(norm="LMN", last_norm="BN", softmax="softmax2", act="ReLU8")

    return deit_imagenet(
        name="deit-tiny",
        input_shape=input_shape,
        classes=classes,
        distilled=distilled,
        include_top=include_top,
        **variant,
        **architecture,
    )


def bc_deit_dist_ti16_imagenet_pretrained():
    """Helper method to retrieve a `bc_deit_dist_ti16` model that was trained on
    ImageNet dataset.

    The model location, file name and hash are obtained from `get_model_path`, the file is
    retrieved with `fetch_file` (using the 'models' cache sub-directory) and then loaded
    with `load_model`.

    Returns:
        keras.Model: a Keras Model instance
    """
    remote_path, fname, checksum = get_model_path(
        "deit",
        model_name_v2='bc_deit_dist_ti16_224_i8_w8_a8.h5',
        file_hash_v2='4e40e36beb56a926b18cab4c5880aa3cbe768e0e40c9c91abca6063a72763044',
    )
    local_path = fetch_file(remote_path, fname=fname, file_hash=checksum,
                            cache_subdir='models')
    return load_model(local_path)


def deit_s16(input_shape=(224, 224, 3), classes=1000, distilled=False, include_top=True):
    """Build DeiT-Small.

    Forwards the given options to `deit_imagenet`, completed with the `CONFIG_S`
    architecture settings, under the model name "deit-small".

    Args:
        input_shape (tuple, optional): input shape. Defaults to (224, 224, 3).
        classes (int, optional): number of classes. Defaults to 1000.
        distilled (bool, optional): build model appending a distilled token.
            Defaults to False.
        include_top (bool, optional): whether to include the final classifier network.
            Defaults to True.

    Returns:
        keras.Model: the requested model
    """
    options = dict(
        name="deit-small",
        input_shape=input_shape,
        classes=classes,
        distilled=distilled,
        include_top=include_top,
    )
    return deit_imagenet(**options, **CONFIG_S)


def deit_b16(input_shape=(224, 224, 3), classes=1000, distilled=False, include_top=True):
    """Build DeiT-B16.

    Forwards the given options to `deit_imagenet`, completed with the `CONFIG_B`
    architecture settings, under the model name "deit-base".

    Args:
        input_shape (tuple, optional): input shape. Defaults to (224, 224, 3).
        classes (int, optional): number of classes. Defaults to 1000.
        distilled (bool, optional): build model appending a distilled token.
            Defaults to False.
        include_top (bool, optional): whether to include the final classifier network.
            Defaults to True.

    Returns:
        keras.Model: the requested model
    """
    options = dict(
        name="deit-base",
        input_shape=input_shape,
        classes=classes,
        distilled=distilled,
        include_top=include_top,
    )
    return deit_imagenet(**options, **CONFIG_B)