Source code for quantizeml.layers.recorders

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2022 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

# Public API of this module: the recording context manager and the recorder classes.
__all__ = ["recording", "Recorder", "TensorRecorder", "FixedPointRecorder", "QFloatRecorder"]

import os
import keras
import tensorflow as tf
from contextlib import contextmanager

from ..tensors import FixedPoint, QFloat

# Name of the environment variable that carries the recording state ("1" means enabled).
RECORDING_ENV = "RECORDING_ENABLED"


@contextmanager
def recording(enable):
    """Enable or disable recording for the duration of a ``with`` block.

    The recording state is stored in the RECORDING_ENABLED environment
    variable. Whatever state was there before entering the block (including
    the variable being unset) is restored when leaving it, even if the body
    raises.

    Args:
        enable (bool): True to enable recording, False to disable it
    """
    value = "1" if enable else "0"
    _prev_state = os.environ.get(RECORDING_ENV, None)
    try:
        os.environ[RECORDING_ENV] = value
        yield
    finally:
        # Restore the state found on entry
        if _prev_state is not None:
            os.environ[RECORDING_ENV] = _prev_state
        else:
            # The variable was unset on entry. Use a default so that cleanup
            # does not raise KeyError (and hide an exception coming from the
            # body) if the variable was already removed inside the block.
            os.environ.pop(RECORDING_ENV, None)
class Recorder:
    """Base class exposing a read-only 'recording' flag.

    The flag is not stored on the instance: it is read from the
    RECORDING_ENABLED environment variable on every access, so all objects
    deriving from this class report the same value. It cannot be set through
    the property.
    """

    @property
    def recording(self):
        """Tell whether recording mode is currently enabled.

        Returns:
            bool: True if the environment variable is set to "1", False
                otherwise (including when it is unset).
        """
        return os.environ.get(RECORDING_ENV, "0") == "1"
class TensorRecorder(Recorder, keras.layers.Layer):
    """Layer that keeps a copy of a tf.Tensor extracted from a graph.

    This is mainly used to recover FixedPoint alignment shift information.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The backing variable is only created when a first tensor is recorded
        self._var = None

    @property
    def value(self):
        """Return the recorded value.

        Returns:
            tf.Tensor: content of the stored record, or None if nothing has
                been recorded yet.
        """
        if self._var is None:
            return None
        return self._var.value()

    def call(self, inputs):
        """Pass the inputs through, recording them if recording is True.

        Args:
            inputs (tf.Tensor): new values.

        Returns:
            tf.Tensor: the inputs.
        """
        if self.recording:
            if self._var is not None:
                # A record already exists: overwrite it with the new values
                self._var.assign(inputs)
            else:
                # First record: create a variable to copy values from the graph
                self._var = tf.Variable(
                    inputs,
                    trainable=False,
                    name=self.name + "/record",
                    synchronization=tf.VariableSynchronization.ON_READ,
                )
        return inputs
class FixedPointRecorder(Recorder, keras.layers.Layer):
    """Layer that keeps a copy of a FixedPoint extracted from a graph.

    This is mainly used to recover FixedPoint quantized weights.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # value_bits is kept as a plain attribute, while the values and
        # frac_bits are each delegated to their own TensorRecorder.
        self._value_bits = None
        self._values = TensorRecorder()
        self._frac_bits = TensorRecorder()

    @property
    def value(self):
        """Return the recorded value.

        Returns:
            :obj:`FixedPoint`: the FixedPoint rebuilt from the stored record,
                or None if nothing has been recorded yet.
        """
        if self._value_bits is None:
            return None
        return FixedPoint(self._values.value, self._value_bits, self._frac_bits.value)

    def call(self, inputs):
        """Pass the inputs through, recording them if recording is True.

        Args:
            inputs (:obj:`FixedPoint`): new values.

        Returns:
            :obj:`FixedPoint`: the inputs.
        """
        if self.recording:
            # Keep each component of the FixedPoint separately
            self._value_bits = inputs.value_bits
            self._values(inputs.values)
            self._frac_bits(inputs.frac_bits)
        return inputs
class QFloatRecorder(Recorder, keras.layers.Layer):
    """Wrapper class to store and retrieve a QFloat extracted from a graph.

    This is mainly used to recover QFloat quantized weights.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The FixedPoint part and the scales of the QFloat are recorded
        # by two dedicated sub-recorders.
        self._fp = FixedPointRecorder()
        self._scales = TensorRecorder()

    @property
    def value(self):
        """Get the recorded value.

        Returns:
            :obj:`QFloat`: value of the stored record or None.
        """
        # Read the property once: each access to FixedPointRecorder.value
        # builds a new FixedPoint, so reuse the same object for the None
        # check and for the QFloat construction.
        fp = self._fp.value
        if fp is None:
            return None
        return QFloat(fp, self._scales.value)

    def call(self, inputs):
        """Record the values of the inputs if recording is True.

        Args:
            inputs (:obj:`QFloat`): new values.

        Returns:
            :obj:`QFloat`: the inputs.
        """
        if self.recording:
            self._fp(inputs.fp)
            self._scales(inputs.scales)
        return inputs