Source code for viqa._metrics

# Authors
# -------
# Author: Lukas Behammer
# Research Center Wels
# University of Applied Sciences Upper Austria, 2023
# CT Research Group
#
# Modifications
# -------------
# Original code, 2024, Lukas Behammer
#
# License
# -------
# BSD-3-Clause License

"""Module for the internal metrics classes.

This module contains the abstract classes for the metrics.
"""

from abc import ABC, abstractmethod

from viqa.utils import _check_imgs, export_results, load_data


class Metric:
    """Common base for all metric classes.

    Holds the metric configuration in ``self.parameters`` and the most
    recently computed result in ``self.score_val``. Instances compare with
    ``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=`` by ``score_val``, either
    against another :class:`Metric` (its ``score_val`` is used) or against
    any other object (the object itself is used).

    Parameters
    ----------
    data_range
        Stored as ``parameters["data_range"]``. Must be truthy when
        ``normalize`` is truthy.
    normalize
        Stored as ``parameters["normalize"]``.
    **kwargs
        Extra entries merged into ``parameters``. They may override the
        defaults ``chromatic=False`` and ``roi=None``.

    Raises
    ------
    ValueError
        If ``normalize`` is truthy and ``data_range`` is falsy (e.g. None).

    Notes
    -----
    ``Metric`` itself does not use :class:`abc.ABCMeta`, so the
    ``@abstractmethod`` markers below are only enforced in subclasses that
    also inherit from :class:`abc.ABC`.
    """

    # Defining __eq__ without __hash__ already makes instances unhashable;
    # stated explicitly here because score_val is mutable.
    __hash__ = None

    def __init__(self, data_range, normalize, **kwargs):
        parameters = {
            "data_range": data_range,
            "normalize": normalize,
            "chromatic": False,
            "roi": None,
        }
        # Caller-supplied options take precedence over the defaults above.
        parameters.update(kwargs)
        self.parameters = parameters
        self.score_val = None
        self._name = None
        # data_range/normalize are named parameters, so kwargs can never
        # override them: the locals equal the stored values.
        if normalize and not data_range:
            raise ValueError("If normalize is True, data_range must be specified")

    @abstractmethod
    def score(self, *args):
        """Calculate the score."""

    @abstractmethod
    def print_score(self, *args):
        """Print the score."""

    def export_results(self, path, filename):
        """Export the score to a csv file.

        Parameters
        ----------
        path : str
            The path where the csv file should be saved.
        filename : str
            The name of the csv file.

        Notes
        -----
        The arguments get passed to :py:func:`.viqa.utils.export_results`.
        """
        # Resolves to the module-level utility, not this method.
        export_results([self], path, filename)

    @staticmethod
    def _comparable(other):
        """Return ``other.score_val`` for a Metric, otherwise ``other``."""
        return other.score_val if isinstance(other, Metric) else other

    def __eq__(self, other):
        return self.score_val == self._comparable(other)

    def __ne__(self, other):
        return self.score_val != self._comparable(other)

    def __lt__(self, other):
        return self.score_val < self._comparable(other)

    def __le__(self, other):
        return self.score_val <= self._comparable(other)

    def __gt__(self, other):
        return self.score_val > self._comparable(other)

    def __ge__(self, other):
        return self.score_val >= self._comparable(other)

    def __repr__(self):
        return f"{self.__class__.__name__}(result={self.score_val})"
class FullReferenceMetricsInterface(ABC, Metric):
    """Abstract base for metrics that compare a modified image to a reference.

    Sets ``self.type`` to ``"full-reference"`` and provides
    :meth:`load_images`, which validates an image pair with the options
    stored in ``self.parameters``.
    """

    def __init__(self, data_range, normalize, **kwargs):
        super().__init__(data_range, normalize, **kwargs)
        self.type = "full-reference"

    def load_images(self, img_r, img_m):
        """Load the images and perform checks.

        Parameters
        ----------
        img_r : np.ndarray, viqa.ImageArray, torch.Tensor, str or os.PathLike
            The reference image.
        img_m : np.ndarray, viqa.ImageArray, torch.Tensor, str or os.PathLike
            The modified image.

        Returns
        -------
        img_r : viqa.ImageArray
            The loaded reference image as an :py:class:`viqa.utils.ImageArray`.
        img_m : viqa.ImageArray
            The loaded modified image as an :py:class:`viqa.utils.ImageArray`.
        """
        # Forward the stored configuration unchanged to the shared checker.
        options = {
            key: self.parameters[key]
            for key in ("data_range", "normalize", "chromatic", "roi")
        }
        reference, modified = _check_imgs(img_r=img_r, img_m=img_m, **options)
        return reference, modified
class NoReferenceMetricsInterface(ABC, Metric):
    """Abstract base for metrics that score a single image without a reference.

    Sets ``self.type`` to ``"no-reference"`` and provides
    :meth:`load_images`, which loads one image with the options stored in
    ``self.parameters``.
    """

    def __init__(self, data_range, normalize, **kwargs):
        super().__init__(data_range, normalize, **kwargs)
        self.type = "no-reference"

    def load_images(self, img):
        """Load the image.

        Uses the :py:func:`.viqa.utils.load_data` function to load the image.

        Parameters
        ----------
        img : np.ndarray, viqa.ImageArray, torch.Tensor, str or os.PathLike
            The image to load.

        Returns
        -------
        img : viqa.ImageArray
            The loaded image as an :py:class:`viqa.utils.ImageArray`.
        """
        # Only these stored options are forwarded; "chromatic" is not
        # passed to load_data here.
        options = {
            key: self.parameters[key] for key in ("data_range", "normalize", "roi")
        }
        return load_data(img=img, **options)