#!/usr/bin/env python
# *****************************************************************************
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Tools for Knowledge Distillation training.
Originated from https://keras.io/examples/vision/knowledge_distillation/.
Reference Hinton et al. (2015) https://arxiv.org/abs/1503.02531
"""
import tensorflow as tf
from functools import partial
from tensorflow import GradientTape
from keras import Model
from keras.losses import KLDivergence, CategoricalCrossentropy
class Distiller(Model):
    """ The class that will be used to train the student model using the
    distillation knowledge method.
    Reference `Hinton et al. (2015) <https://arxiv.org/abs/1503.02531>`_.
    Args:
        student (keras.Model): the student model
        teacher (keras.Model): the well trained teacher model
        alpha (float, optional): weight applied to student_loss_fn, with
            (1 - alpha) applied to distillation_loss_fn. Defaults to 0.1.
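    Typical usage (a minimal sketch; ``student_net``, ``teacher_net`` and
    ``train_ds`` are assumed to be built beforehand):
        >>> distiller = Distiller(student=student_net, teacher=teacher_net,
        ...                       alpha=0.1)
        >>> distiller.compile(optimizer, metrics, student_loss_fn,
        ...                   distillation_loss_fn)
        >>> distiller.fit(train_ds, epochs=5)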
    """
    def __init__(self, student, teacher, alpha=0.1):
        super().__init__()
        self.teacher = teacher
        self.student = student
        self.student_loss_fn = None
        self.distillation_loss_fn = None
        self.alpha = alpha
    @property
    def base_model(self):
        return self.student
    @property
    def layers(self):
        return self.base_model.layers
    def compile(self,
                optimizer,
                metrics,
                student_loss_fn,
                distillation_loss_fn):
        """ Configure the distiller.
        Args:
            optimizer (keras.optimizers.Optimizer): Keras optimizer
                for the student weights
            metrics (keras.metrics.Metric): Keras metrics for evaluation
            student_loss_fn (keras.losses.Loss): loss function of difference
                between student predictions and ground-truth
            distillation_loss_fn (keras.losses.Loss): loss function of
                difference between student predictions and teacher predictions
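        Example (an illustrative configuration; other Keras losses and metrics
        can be used as well):
            >>> distiller.compile(
            ...     optimizer=tf.keras.optimizers.Adam(1e-4),
            ...     metrics=[tf.keras.metrics.CategoricalAccuracy()],
            ...     student_loss_fn=tf.keras.losses.CategoricalCrossentropy(
            ...         from_logits=True),
            ...     distillation_loss_fn=KLDistillationLoss(temperature=3))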
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
    def train_step(self, data):
        # Unpack data
        x, y = data
        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)
        with GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)
            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, student_predictions)
            loss = self.alpha * student_loss + (1 -
                                                self.alpha) * distillation_loss
        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({
            "student_loss": student_loss,
            "distillation_loss": distillation_loss
        })
        return results
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_prediction = self.student(x, training=False)
        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)
        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results
    def save(self, *args, **kwargs):
        return self.base_model.save(*args, **kwargs)
    def save_weights(self, *args, **kwargs):
        return self.base_model.save_weights(*args, **kwargs)
    def load_weights(self, *args, **kwargs):
        return self.base_model.load_weights(*args, **kwargs)
class DeitDistiller(Distiller):
    """Distiller class to train the student model using the
    Knowledge Distillation (KD) method, found on https://arxiv.org/pdf/2012.12877.pdf
    The main difference with the classic KD is that the student has to produce two potential
    classification outputs. This type of training is based on the assumption that each output
    has sufficiently interacted with the whole model, therefore the main architecture can be
    trained through two different sources, as follows:
        >>> output, output_kd = student(input)
        >>> output_tc = teacher(input)
        >>> student_loss = student_loss_fn(y_true, output)
        >>> distillation_loss = distillation_loss_fn(output_tc, output_kd)
    This means each loss is expected to receive a different input, unlike classical KD,
    where the student's prediction is shared by both losses. However, given that each
    classifier has interacted with the student model, the gradient of each loss will contribute
    to the update of the model weights according to the alpha weighting.
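    The two losses are then combined with the same alpha weighting as in the base
    Distiller:
        >>> loss = alpha * student_loss + (1 - alpha) * distillation_loss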
    Args:
        student (keras.Model): the student model
        teacher (keras.Model): the well trained teacher model
        alpha (float, optional): weight applied to student_loss_fn, with
            (1 - alpha) applied to distillation_loss_fn. Defaults to 0.1.
        temperature (float, optional): if ``distiller_type`` passed to compile is equal
            to 'soft', this value will be used as the temperature parameter of
            KLDistillationLoss. Defaults to 1.0.
    """
    def __init__(self, student, *args, temperature=1.0, **kwargs):
        assert len(student.outputs) == 2, "Student must be a multi-output model, with 2 outputs"
        # Append an output with the average of the two heads
        y = tf.math.add_n(student.outputs) / 2
        _student = Model(student.inputs, student.outputs + [y], name=student.name)
        super().__init__(_student, *args, **kwargs)
        self._student = student
        self.temperature = temperature
    @property
    def base_model(self):
        return self._student
    def compile(self, optimizer, metrics, student_loss_fn, distiller_type):
        """ Configure the distiller.
        Args:
            optimizer (keras.optimizers.Optimizer): Keras optimizer
                for the student weights
            metrics (keras.metrics.Metric): Keras metrics for evaluation
            student_loss_fn (keras.losses.Loss): loss function of difference
                between student predictions and ground-truth
            distiller_type (str): loss function type defining the difference
                between student predictions and teacher predictions, one of ['soft', 'hard',
                'none'], which will result in using KLDistillationLoss, CategoricalCrossentropy
                or student_loss_fn only, respectively.
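        Example (a sketch using the 'hard' variant; ``deit_distiller`` is an
        already built DeitDistiller instance):
            >>> deit_distiller.compile(
            ...     optimizer=tf.keras.optimizers.Adam(1e-4),
            ...     metrics=[tf.keras.metrics.CategoricalAccuracy()],
            ...     student_loss_fn=tf.keras.losses.CategoricalCrossentropy(
            ...         from_logits=True),
            ...     distiller_type='hard')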
        """
        assert distiller_type in ['soft', 'hard', 'none']
        def _loss_forward(y_true, y_pred, loss_fn, index=0, **kwargs):
            if isinstance(y_pred, (tuple, list)):
                y_pred = y_pred[index]
            return loss_fn(y_true, y_pred, **kwargs)
        def _compile_distillation_loss_fn():
            if distiller_type == "soft":
                distillation_loss_fn = KLDistillationLoss(temperature=self.temperature)
            else:
                # Following https://arxiv.org/pdf/2012.12877.pdf, this variant takes the
                # hard decision of the teacher as the true label. Therefore, a softmax is
                # appended to the teacher outputs and a label smoothing of 0.1 is applied
                y = tf.math.softmax(self.teacher.outputs[0], axis=-1)
                self.teacher = Model(self.teacher.inputs, y, name=self.teacher.name)
                distillation_loss_fn = CategoricalCrossentropy(
                    from_logits=True, label_smoothing=0.1)
            return partial(_loss_forward, loss_fn=distillation_loss_fn, index=1)
        if distiller_type == "none" or self.teacher is None:
            # In this case, we just train the first output of student
            self.teacher = distillation_loss_fn = None
            self.student = Model(self.student.inputs,
                                 self.student.outputs[0], name=self.student.name)
            self.student.compile(optimizer, student_loss_fn, metrics)
        else:
            distillation_loss_fn = _compile_distillation_loss_fn()
            student_loss_fn = partial(_loss_forward, loss_fn=student_loss_fn)
            super().compile(optimizer, metrics, student_loss_fn, distillation_loss_fn)
    def _update_metrics(self, metrics):
        # Rename keys in the result dictionary for a more explicit display
        return {k.replace('output_1_', 'head_').replace('output_2_', 'dist_head_')
                 .replace('output_3_', ''): v for k, v in metrics.items()}
    def train_step(self, data):
        if self.teacher is None:
            return self.student.train_step(data)
        return self._update_metrics(super().train_step(data))
    def test_step(self, data):
        if self.teacher is None:
            return self.student.test_step(data)
        return self._update_metrics(super().test_step(data))
class KLDistillationLoss(KLDivergence):
    """
    The `KLDistillationLoss` is a simple wrapper around the KLDivergence loss
    that accepts raw predictions instead of probability distributions.
    Before invoking the KLDivergence loss, it converts the input predictions to
    probabilities by dividing them by a constant 'temperature' and applying a
    softmax.
    Args:
        temperature (float): temperature for softening probability
            distributions. Larger temperature gives softer distributions.
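    Example (a sketch on arbitrary dummy logits):
        >>> loss_fn = KLDistillationLoss(temperature=3)
        >>> teacher_logits = tf.constant([[2.0, 0.5, -1.0]])
        >>> student_logits = tf.constant([[1.5, 0.2, -0.5]])
        >>> loss = loss_fn(teacher_logits, student_logits)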
    """
    def __init__(self, temperature=3):
        self.temperature = temperature
        super().__init__()
    def __call__(self, y_true, y_pred, sample_weight=None):
        # Following https://github.com/facebookresearch/deit/blob/main/losses.py#L63,
        # the result of KLDivergence must be scaled by temperature ** 2
        scale_factor = tf.constant(self.temperature ** 2, dtype=tf.float32)
        return super().__call__(
            tf.nn.softmax(y_true / self.temperature, axis=1),
            tf.nn.softmax(y_pred / self.temperature, axis=1)) * scale_factor