#!/usr/bin/env python
# coding: utf-8
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Helper to load 10 samples of ImageNet-like data.
"""
__all__ = ["get_preprocessed_samples"]
import csv
import os
import numpy as np
from tensorflow.image import decode_jpeg
from tensorflow.io import read_file
from akida_models.imagenet import preprocessing
from akida_models.utils import fetch_file
[docs]
def get_preprocessed_samples(image_size=224, num_channels=3):
    """
    Load and preprocess a 10 ImageNet-like images for testing.
    Args:
        image_size (int, optional): The target size for the images. Defaults to 224.
        num_channels (int, optional): The number of channels in the images. Defaults to 3.
    Returns:
        x_test, labels_test (tuple): 4D and 1D numpy array of the preprocessed images and their
                        corresponding labels
    """
    num_images = 10
    file_path = fetch_file(
        fname="imagenet_like.zip",
        origin="https://data.brainchip.com/dataset-mirror/imagenet_like/imagenet_like.zip",
        cache_subdir='datasets/imagenet_like',
        extract=True)
    data_folder = os.path.dirname(file_path)
    x_test, x_test_files = _get_images(data_folder, num_images, image_size, num_channels)
    labels_test = _get_labels(data_folder, num_images, x_test_files)
    return x_test, labels_test 
def _get_images(data_folder, num_images, image_size, num_channels):
    """
    Load and preprocess ImageNet-like test images.
    Args:
        data_folder (str): Folder where images are located.
        num_images (int): Number of images to load.
        image_size (int): Target size for the images.
        num_channels (int): Number of channels in the images.
    Returns:
        Tuple (`np.ndarray`, List[str]): Preprocessed images and corresponding file names.
    """
    # Load images for test set
    x_test_files = []
    x_test = np.zeros((num_images, image_size, image_size, num_channels)).astype('uint8')
    for idx in range(num_images):
        test_file = 'image_' + str(idx + 1).zfill(2) + '.jpg'
        x_test_files.append(test_file)
        img_path = os.path.join(data_folder, test_file)
        base_image = read_file(img_path)
        image = decode_jpeg(base_image, channels=num_channels)
        image = preprocessing.preprocess_image(image, (image_size, image_size))
        x_test[idx, :, :, :] = np.expand_dims(image, axis=0)
    return x_test, x_test_files
def _get_labels(data_folder, num_images, x_test_files):
    """
    Parse labels file for ImageNet-like test samples.
    Args:
        data_folder (str): Folder where labels file is located.
        num_images (int): Number of images.
        x_test_files (List[str]): List of file names for test samples.
    Returns:
        labels_test (np.ndarray): NumPy array of labels for the test samples.
    """
    # Parse labels file
    fname = os.path.join(data_folder, 'labels_validation.txt')
    validation_labels = dict()
    with open(fname, newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=' ')
        for row in reader:
            validation_labels[row[0]] = row[1]
    # Get labels for the test set by index
    labels_test = np.zeros(num_images)
    for i in range(num_images):
        labels_test[i] = int(validation_labels[x_test_files[i]])
    return labels_test