#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
__all__ = ["custom_pattern_scope"]
from collections import namedtuple
from inspect import signature
from contextlib import contextmanager
from .. import layers as onnx_qlayers
# Define named tuples for QuantizerPattern
QuantizePattern = namedtuple('QuantizerPattern', ['pattern', 'f'])
# List of supported patterns, together with matching function
CUSTOM_PATTERNS_MAP = []
PATTERNS_MAP = [
QuantizePattern(("Conv", "Relu", "GlobalAveragePool"), onnx_qlayers.get_qconv),
QuantizePattern(("Conv", "Relu", "MaxPool"), onnx_qlayers.get_qconv),
QuantizePattern(("Conv", "GlobalAveragePool"), onnx_qlayers.get_qconv),
QuantizePattern(("Conv", "Relu"), onnx_qlayers.get_qconv),
QuantizePattern(("Conv",), onnx_qlayers.get_qconv),
QuantizePattern(("DepthwiseConv", "Relu"), onnx_qlayers.get_qdepthwise),
QuantizePattern(("DepthwiseConv",), onnx_qlayers.get_qdepthwise),
QuantizePattern(("Flatten", "Gemm", "Relu"), onnx_qlayers.get_qgemm),
QuantizePattern(("Flatten", "Gemm"), onnx_qlayers.get_qgemm),
QuantizePattern(("Gemm", "Relu"), onnx_qlayers.get_qgemm),
QuantizePattern(("Gemm",), onnx_qlayers.get_qgemm),
QuantizePattern(("Add",), onnx_qlayers.get_qadd),
]
[docs]@contextmanager
def custom_pattern_scope(patterns):
"""Register a custom pattern in the context to be used at quantization time.
A pattern is understood as a sequence of continuous operations in the graph,
whose representation can converge in an ``OnnxLayer``.
Args:
patterns (dict): a list of sequence of nodes (keys) and their mapper function (values).
"""
# Use of global parameters
global CUSTOM_PATTERNS_MAP
# Transform input patterns in a valid format
qpatterns = []
for pattern, func in patterns.items():
qpatterns.append(_custom_pattern_to_qpattern(pattern, func))
try:
# Extend CUSTOM_PATTERNS_MAP with new qpatterns
CUSTOM_PATTERNS_MAP.extend(qpatterns)
yield
finally:
# Restore to previous state
CUSTOM_PATTERNS_MAP.clear()
def _custom_pattern_to_qpattern(pattern, func):
assert callable(func), f"function has to be a callable. Receives: {func}"
if len(signature(func).parameters) != 2:
raise RuntimeError("function must have two inputs: sequence_nodes and graph")
if isinstance(pattern, str):
pattern = (pattern,)
if not (isinstance(pattern, tuple) and all(isinstance(x, str) for x in pattern)):
raise ValueError(f"Pattern must be a string-tuple. Receives: {pattern}")
return QuantizePattern(pattern, func)