Source code for quantizeml.model_io

# Names exported by ``from <this module> import *``: the two public entry points.
__all__ = ["load_model", "save_model"]

import os
import keras
import onnx


def load_model(model_path, custom_layers=None, compile_model=True):
    """Loads an Onnx or Keras model.

    The loader is chosen from the file extension, compared case-insensitively:
    ``.h5`` files are loaded with Keras and ``.onnx`` files with ONNX. An error
    is raised if the provided model extension is not supported.

    Args:
        model_path (str or os.PathLike): path of the model to load.
        custom_layers (dict, optional): custom layers to add to the Keras model.
            Defaults to None.
        compile_model (bool, optional): whether to compile the Keras model.
            Defaults to True.

    Returns:
        keras.models.Model or onnx.ModelProto: Loaded model.

    Raises:
        ValueError: if the model extension is neither '.h5' nor '.onnx'.
    """
    # Normalise path-like objects (e.g. pathlib.Path) to their plain path form so
    # that the extension check below does not fail on a missing ``lower`` method.
    model_path = os.fspath(model_path)

    # Only the copy used for dispatch is lower-cased; the loaders receive the
    # path with its original casing.
    _, model_extension = os.path.splitext(model_path.lower())

    if model_extension == '.h5':
        model = keras.models.load_model(model_path,
                                        custom_objects=custom_layers,
                                        compile=compile_model)
    elif model_extension == '.onnx':
        model = onnx.load_model(model_path)
    else:
        raise ValueError(
            f"Unsupported model extension: '{model_extension}'. "
            f"Expected model with extension(s): {['h5', 'onnx']}"
        )
    return model
def save_model(model, path):
    """Serialize a Keras or ONNX model to disk.

    Any extension present in ``path`` is discarded and replaced by the one
    matching the model type: ``.h5`` for Keras models, ``.onnx`` for ONNX models.

    Args:
        model (keras.Model, keras.Sequential or onnx.ModelProto): model to serialize.
        path (str): path to save the model.

    Returns:
        str: the path where the model was saved.

    Raises:
        ValueError: if the model to save is not a Keras or ONNX model.
    """
    # Strip whatever extension the caller supplied; the model type decides it.
    stem = os.path.splitext(path)[0]

    if isinstance(model, (keras.Model, keras.Sequential)):
        destination = stem + ".h5"
        model.save(destination, include_optimizer=False)
        return destination

    if isinstance(model, onnx.ModelProto):
        destination = stem + ".onnx"
        onnx.save_model(model, destination)
        return destination

    raise ValueError(f"Unrecognized {type(model)} model type. Expected a keras or ONNX model.")