Source code for akida_models.model_io

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Input/output on models.
"""

import warnings
import os
import numpy as np
from pathlib import Path
from posixpath import join as urljoin

from onnx import load_model as load_onnx_model

from cnn2snn import load_quantized_model, get_akida_version, AkidaVersion

from quantizeml.models.utils import apply_weights_to_model
import akida


def load_model(model_path, custom_layers=None, compile_model=True):
    """ Loads a model from a file, picking the loader from the file extension.

    The extension is matched case-insensitively:

    - ``.h5`` files are loaded with ``cnn2snn.load_quantized_model``,
    - ``.onnx`` files are loaded with ``onnx.load_model``,
    - ``.fbz`` files are loaded with ``akida.Model``.

    Args:
        model_path (str or os.PathLike): path of the model to load.
        custom_layers (dict, optional): custom layers to add to the Keras model. Only used for
            '.h5' files. Defaults to None.
        compile_model (bool, optional): whether to compile the Keras model. Only used for '.h5'
            files. Defaults to True.

    Returns:
        keras.models.Model or onnx.ModelProto or akida.Model: Loaded model.

    Raises:
        ValueError: if the extension of `model_path` is not one of the supported ones.
    """
    # Normalize path-like objects (e.g. pathlib.Path) to a plain string so that the extension
    # detection below and the underlying loaders all receive the same type.
    model_path = os.fspath(model_path)

    # Lower-case only for the comparison: the loaders get the path with its original case.
    _, model_extension = os.path.splitext(model_path.lower())

    if model_extension == '.h5':
        return load_quantized_model(model_path,
                                    custom_objects=custom_layers,
                                    compile_model=compile_model)
    if model_extension == '.onnx':
        return load_onnx_model(model_path)
    if model_extension == '.fbz':
        return akida.Model(model_path)

    raise ValueError(
        f"Unsupported model extension: '{model_extension}'. "
        f"Expected model with extension(s): {['h5', 'onnx', 'fbz']}"
    )
def load_weights(model, weights_path):
    """Loads weights from a npz file and applies them to a model.

    The npz archive is opened with ``numpy.load`` and handed over to
    ``quantizeml.models.utils.apply_weights_to_model`` together with the model. The archive
    is closed once the weights have been applied.

    Args:
        model (keras.Model): the model to update
        weights_path (str): the path of the npz file to load

    Raises:
        ValueError: if `weights_path` does not point to an existing file.
    """
    # Check the npz file validity
    if not Path(weights_path).is_file():
        raise ValueError(f"File `{weights_path}` not found.")

    # np.load returns an NpzFile that keeps the underlying archive open: use it as a context
    # manager so that the file handle is released even if applying the weights fails.
    with np.load(weights_path) as weights_dict:
        apply_weights_to_model(model, weights_dict)
def save_weights(model, weights_path):
    """Saves the variables of a model into an npz file.

    Every entry of ``model.variables`` is stored in the archive under the variable name.

    Args:
        model (keras.Model): the model to save its weights
        weights_path (str): the path of the npz file to save
    """
    # Map each variable name to the variable itself; np.savez takes care of the conversion.
    named_variables = {variable.name: variable for variable in model.variables}
    np.savez(weights_path, **named_variables)
def get_model_path(subdir="", model_name_v1=None, file_hash_v1=None, model_name_v2=None,
                   file_hash_v2=None):
    """Picks the server location, name and hash of a model for the current AkidaVersion.

    The v1 set of parameters is used when ``get_akida_version()`` returns ``AkidaVersion.v1``,
    the v2 set otherwise. A warning is emitted whenever a v1 model is selected.

    Args:
        subdir (str, optional): the subdirectory where the model is on the data server.
            Defaults to "".
        model_name_v1 (str, optional): the model v1 name. Defaults to None.
        file_hash_v1 (str, optional): the model file v1 hash. Defaults to None.
        model_name_v2 (str, optional): the model v2 name. Defaults to None.
        file_hash_v2 (str, optional): the model file v2 hash. Defaults to None.

    Returns:
        str, str, str: the model path, model name and file hash.

    Raises:
        ValueError: if no model name was given for the current Akida version.
    """
    assert get_akida_version() in (AkidaVersion.v1, AkidaVersion.v2)

    # Guard against partial usage: for a given version, name and hash must either both be
    # provided or both be left to None.
    assert type(model_name_v1) is type(file_hash_v1), "All v1 parameters should be used"
    assert type(model_name_v2) is type(file_hash_v2), "All v2 parameters should be used"

    # Select the version-dependent values first, then run a single availability check.
    targets_v1 = get_akida_version() == AkidaVersion.v1
    if targets_v1:
        server_root = 'https://data.brainchip.com/models/AkidaV1/'
        chosen_name, chosen_hash = model_name_v1, file_hash_v1
        unavailable_msg = 'Requested model is not available for Akida v1.'
    else:
        server_root = 'https://data.brainchip.com/models/AkidaV2/'
        chosen_name, chosen_hash = model_name_v2, file_hash_v2
        unavailable_msg = 'Requested model is not available for Akida v2.'

    if not chosen_name:
        raise ValueError(unavailable_msg)

    if targets_v1:
        warnings.warn(f'Model {model_name_v1} has been trained with akida_models 1.1.10 which is '
                      'the last version supporting 1.0 models training')

    # build the full path
    return urljoin(server_root, subdir, chosen_name), chosen_name, chosen_hash