# Source code for viqa.nr_metrics.cnr

"""Module for calculating the contrast-to-noise ratio (CNR) for an image.

Examples
--------
    .. doctest-requires:: numpy

        >>> import numpy as np
        >>> from viqa import CNR
        >>> img = np.random.rand(256, 256)
        >>> cnr = CNR(data_range=1, normalize=False)
        >>> cnr
        CNR(score_val=None)
        >>> score = cnr.score(img,
        ...                   background_center=(128, 128),
        ...                   signal_center=(32, 32),
        ...                   radius=16)
"""

# Authors
# -------
# Author: Lukas Behammer
# Research Center Wels
# University of Applied Sciences Upper Austria, 2023
# CT Research Group
#
# Modifications
# -------------
# Original code, 2024, Lukas Behammer
# Add interactive center selection, 2024, Michael Stidi
# Add automatic center detection, 2024, Michael Stidi
# Update automatic center detection, 2024, Lukas Behammer
#
# License
# -------
# BSD-3-Clause License

from warnings import warn

import numpy as np

from viqa._metrics import NoReferenceMetricsInterface
from viqa.utils import (
    FIGSIZE_CNR_2D,
    FIGSIZE_CNR_3D,
    _check_border_too_close,
    _create_slider_widget,
    _get_binary,
    _to_cubic,
    _to_grayscale,
    _to_spherical,
    _visualize_cnr_2d,
    _visualize_cnr_3d,
    find_largest_region,
    load_data,
)
from viqa.utils._module import check_interactive_vis_deps, is_ipython, try_import

# Optional dependencies for the interactive center selection. The second value
# of each pair is handed to ``check_interactive_vis_deps`` in
# ``CNR.set_centers`` before any widget is created.
widgets, has_ipywidgets = try_import("ipywidgets")
display, has_ipython = try_import("IPython.display", "display")

# Module-level state shared between the widget callbacks of ``CNR.set_centers``:
# the most recently selected signal center, background center and radius.
glob_signal_center = ()
glob_background_center = ()
glob_radius = None

# Figure sizes rendered as strings with an "in" suffix; used as the default
# width/height of the widget output area.
FIGSIZE_CNR_2D_ = tuple(f"{inches}in" for inches in FIGSIZE_CNR_2D)
FIGSIZE_CNR_3D_ = tuple(f"{inches}in" for inches in FIGSIZE_CNR_3D)


class CNR(NoReferenceMetricsInterface):
    """Class to calculate the contrast-to-noise ratio (CNR) for an image.

    Attributes
    ----------
    score_val : float
        CNR score value of the last calculation.
    parameters : dict
        Dictionary containing the parameters for CNR calculation.

    Parameters
    ----------
    data_range : {1, 255, 65535}, default=255
        Data range of the returned data in data loading. Is used for image loading
        when ``normalize`` is True. Passed to :py:func:`viqa.utils.load_data`.
    normalize : bool, default False
        If True, the input images are normalized to the ``data_range`` argument.
    **kwargs : optional
        Additional parameters for data loading. The keyword arguments are passed to
        :py:func:`viqa.utils.load_data`.

    Other Parameters
    ----------------
    chromatic : bool, default False
        If True, the input images are expected to be RGB images.

        .. note::
            Currently not supported.

    """

    # Keys under which the region definition is stored in ``parameters``.
    _REGION_KEYS = ("signal_center", "background_center", "radius")

    def __init__(self, data_range=255, normalize=False, **kwargs) -> None:
        """Construct method."""
        super().__init__(data_range=data_range, normalize=normalize, **kwargs)
        self._name = "CNR"

    def score(self, img, **kwargs):
        """Calculate the contrast-to-noise ratio (CNR) for an image.

        Parameters
        ----------
        img : np.ndarray or Tensor or str or os.PathLike
            Image to calculate score of.
        **kwargs : optional
            Additional parameters for CNR calculation. The keyword arguments are
            passed to :py:func:`.viqa.nr_metrics.cnr.contrast_to_noise_ratio`.
            Each of ``signal_center``, ``background_center`` and ``radius`` that
            is not given is taken from the class attribute :py:attr:`parameters`.

        Returns
        -------
        score_val : float
            CNR score value.

        Raises
        ------
        ValueError
            If a center or the radius is neither given nor stored in
            :py:attr:`parameters`.
        """
        img = self.load_images(img)

        # Only fall back to the stored value for keys the caller did not pass,
        # so an explicitly given center or radius is never silently replaced.
        missing = [key for key in self._REGION_KEYS if key not in kwargs]
        if missing:
            if not set(missing).issubset(self.parameters.keys()):
                raise ValueError("No center or radius provided.")
            for key in missing:
                kwargs[key] = self.parameters[key]

        # write kwargs to .parameters attribute
        self.parameters.update(kwargs)
        score_val = contrast_to_noise_ratio(img, **kwargs)
        self.score_val = score_val
        return score_val

    def print_score(self, decimals=2):
        """Print the CNR score value of the last calculation.

        Parameters
        ----------
        decimals : int, default=2
            Number of decimal places to print the score value.

        Warns
        -----
        RuntimeWarning
            If :py:attr:`score_val` is not available.
        """
        if self.score_val is None:
            warn("No score value for CNR. Run score() first.", RuntimeWarning)
            return
        print(f"CNR: {np.round(self.score_val, decimals)}")

    def visualize_centers(
        self,
        img,
        signal_center=None,
        background_center=None,
        radius=None,
        export_path=None,
        **kwargs,
    ):
        """Visualize the centers for CNR calculation.

        The visualization shows the signal and background regions in a matplotlib
        plot. If export_path is provided, the plot is saved to the given path.

        Parameters
        ----------
        img : np.ndarray or Tensor or str or os.PathLike
            Image to visualize.
        signal_center : Tuple(int), optional
            Center of the signal. Order is ``(x, y)`` for 2D images and
            ``(x, y, z)`` for 3D images. If not given the class attribute
            :py:attr:`parameters` is used.
        background_center : Tuple(int), optional
            Center of the background. Order is ``(x, y)`` for 2D images and
            ``(x, y, z)`` for 3D images. If not given the class attribute
            :py:attr:`parameters` is used.
        radius : int, optional
            Width of the regions. If not given the class attribute
            :py:attr:`parameters` is used.
        export_path : str or os.PathLike, optional
            Path to export the visualization to.
        **kwargs : optional
            Additional parameters for visualization. The keyword arguments are
            passed to :py:func:`matplotlib.pyplot.subplots`.

        Raises
        ------
        ValueError
            If a center or the radius is neither given nor stored in
            :py:attr:`parameters`.
            If the given centers are not in the same dimension as the image.
            If the center is too close to the border.
            If the image is not 2D or 3D.
        TypeError
            If the center is not a tuple of integers.
            If the radius is not a positive integer.
        """
        # Load image
        img = self.load_images(img)

        # Resolve each region parameter on its own: an empty/zero/None argument
        # counts as "not given" and is replaced by the stored value.
        region = {
            "signal_center": signal_center,
            "background_center": background_center,
            "radius": radius,
        }
        missing = [key for key, value in region.items() if not value]
        if missing:
            if not set(missing).issubset(self.parameters.keys()):
                raise ValueError("No center or radius provided.")
            for key in missing:
                region[key] = self.parameters[key]
        signal_center = region["signal_center"]
        background_center = region["background_center"]
        radius = region["radius"]

        # An image only counts as RGB when it has three axes AND the last one
        # has length 3; a 2D image that happens to be 3 pixels wide does not.
        is_rgb = img.ndim == 3 and img.shape[-1] == 3
        expected_dims = img.ndim - 1 if is_rgb else img.ndim
        if expected_dims != len(signal_center) or expected_dims != len(
            background_center
        ):
            raise ValueError("Centers have to be in the same dimension as img.")

        # Check if radius is an integer and positive
        if not isinstance(radius, int) or radius <= 0:
            raise TypeError("Radius has to be an integer.")

        # check if signal_center and background_center are tuples of integers and
        # not too close to the border
        _check_border_too_close(signal_center, radius)
        _check_border_too_close(background_center, radius)

        # Pick the plotting helper that matches the image layout
        if is_rgb:
            img = _to_grayscale(img)
            visualize = _visualize_cnr_2d
        elif img.ndim == 3:
            visualize = _visualize_cnr_3d
        elif img.ndim == 2:
            visualize = _visualize_cnr_2d
        else:
            raise ValueError("No visualization possible for non 2d or non 3d images.")

        visualize(
            img=img,
            signal_center=signal_center,
            background_center=background_center,
            radius=radius,
            export_path=export_path,
            **kwargs,
        )

    def set_centers(
        self,
        img,
        **kwargs,
    ):
        """Visualize and set the centers for CNR calculation interactively.

        The visualization shows the signal and background regions in a matplotlib
        plot.

        Parameters
        ----------
        img : np.ndarray or Tensor or str or os.PathLike
            Image to visualize.
        **kwargs : optional
            Additional parameters for visualization. The keyword arguments are
            passed to :py:meth:`visualize_centers` and
            :py:func:`matplotlib.pyplot.subplots`. ``background_center``,
            ``signal_center``, and ``radius`` can be provided as starting points
            for the interactive center selection. If not provided, the center of
            the image is used as the starting point.

            .. caution::
                The class attribute :py:attr:`parameters` takes precedence over
                the given parameters.

        Warnings
        --------
        This method is only available in an IPython environment. If not in an
        IPython environment, the method will try to visualize the centers in a
        non-interactive environment.

        Raises
        ------
        ImportError
            If the visualization fails in a non-interactive environment. The
            original error is attached as ``__cause__``.
        ValueError
            If the given centers are not in the same dimension as the image.
            If the center is too close to the border.
            If the image is not 2D or 3D.
        TypeError
            If the center is not a tuple of integers.
            If the radius is not a positive integer.
        """
        if not is_ipython():
            try:
                warn("Trying to visualize in a non-interactive environment.")
                self.visualize_centers(img, **kwargs)
                return
            except Exception as err:
                # Chain the original error so the real cause (missing centers,
                # bad radius, ...) stays visible in the traceback.
                raise ImportError(
                    "Failed to visualize in a non-interactive "
                    "environment. Please use an IPython environment or try giving "
                    "the parameters for the centers and radius directly."
                ) from err

        check_interactive_vis_deps(has_ipywidgets, has_ipython)

        # Load image
        img = self.load_images(img)

        # Check if img is 2D RGB image
        if img.ndim == 3 and img.shape[-1] == 3:
            img = _to_grayscale(img)

        # Reject unsupported layouts before any axis of img.shape is indexed.
        if img.ndim not in (2, 3):
            raise ValueError("No visualization possible for non 2d or non 3d images.")

        is_3d = img.ndim == 3 and img.shape[-1] != 3
        axes = "xyz" if is_3d else "xy"
        visualize = _visualize_cnr_3d if is_3d else _visualize_cnr_2d
        default_figsize = FIGSIZE_CNR_3D_ if is_3d else FIGSIZE_CNR_2D_

        # Starting values. The three keys are always removed from kwargs:
        # kwargs is forwarded to the plotting helper together with explicit
        # center/radius arguments, and leftovers would be passed twice.
        center_point = tuple(val // 2 for val in img.shape)
        background_start = kwargs.pop("background_center", center_point)
        signal_start = kwargs.pop("signal_center", center_point)
        radius_start = kwargs.pop("radius", 1)
        if set(self._REGION_KEYS).issubset(self.parameters.keys()):
            # Stored parameters take precedence over the given ones.
            background_start = self.parameters["background_center"]
            signal_start = self.parameters["signal_center"]
            radius_start = self.parameters["radius"]

        # Validate the starting values before they are published anywhere.
        if img.ndim != len(signal_start) or img.ndim != len(background_start):
            raise ValueError("Centers have to be in the same dimension as img.")

        # Check if radius is an integer and positive
        if not isinstance(radius_start, int) or radius_start <= 0:
            raise TypeError("Radius has to be an integer.")

        # check if signal_center and background_center are tuples of integers and
        # not too close to the border
        _check_border_too_close(signal_start, radius_start)
        _check_border_too_close(background_start, radius_start)

        # Size the output area; plain numbers are rendered as inches so that a
        # user-supplied matplotlib ``figsize`` also yields a size string.
        def _output_layout(out, figsize_variable):
            figsize = kwargs.get("figsize", figsize_variable)
            width, height = (
                size if isinstance(size, str) else f"{size}in"
                for size in figsize[:2]
            )
            out.layout = {
                "width": width,
                "height": height,
            }

        # Save values from the global variables to the .parameters attribute
        def _save_values(_):
            global glob_signal_center, glob_background_center, glob_radius
            self.parameters.update(
                {
                    "signal_center": glob_signal_center,
                    "background_center": glob_background_center,
                    "radius": glob_radius,
                }
            )

        # Save values to the global variables
        def _write_values_to_global(signal_value, background_value, radius_value):
            global glob_signal_center, glob_background_center, glob_radius
            glob_signal_center = signal_value
            glob_background_center = background_value
            glob_radius = radius_value

        # Redraw callback; receives one keyword argument per slider.
        def _update_visualization(radius, **coordinates):
            signal_center = tuple(
                coordinates[f"signal_center_{axis}"] for axis in axes
            )
            background_center = tuple(
                coordinates[f"background_center_{axis}"] for axis in axes
            )
            _write_values_to_global(signal_center, background_center, radius)
            visualize(
                img=img,
                signal_center=signal_center,
                background_center=background_center,
                radius=radius,
                **kwargs,
            )

        # Write values to globals and attributes
        _write_values_to_global(signal_start, background_start, radius_start)
        _save_values(None)

        # Create slider for radius
        slider_radius = _create_slider_widget(
            max=min(img.shape) // 2,
            min=1,
            value=radius_start,
            description="Radius",
        )
        slider_radius.style = {"handle_color": "#f7f7f7"}

        # Create one slider per axis for a center, bounded by the current radius
        def _make_center_sliders(start):
            margin = slider_radius.value + 1
            return [
                _create_slider_widget(
                    min=margin,
                    max=img.shape[dim] - margin,
                    value=start[dim],
                )
                for dim in range(len(axes))
            ]

        background_sliders = _make_center_sliders(background_start)
        signal_sliders = _make_center_sliders(signal_start)

        # Update min and max values of sliders dynamically
        def _update_values(change):
            margin = change.new + 1
            for sliders in (background_sliders, signal_sliders):
                for dim, slider in enumerate(sliders):
                    slider.min = margin
                    slider.max = img.shape[dim] - margin

        slider_radius.observe(_update_values, "value")

        # Create button to save values
        save_button = widgets.Button(description="Save Current Values")
        save_button.on_click(_save_values)

        # Set style of widgets
        if is_3d:
            background_colors = ("#d7191c", "#fdae61", "#2c7bb6")
            signal_colors = ("#d7191c", "#fdae61", "#2c7bb6")
        else:
            background_colors = ("#ca0020", "#f4a582")
            signal_colors = ("#0571b0", "#92c5de")
        for slider, color in zip(background_sliders, background_colors):
            slider.style = {"handle_color": color}
        for slider, color in zip(signal_sliders, signal_colors):
            slider.style = {"handle_color": color}

        # Create output
        controls = {}
        for axis, slider in zip(axes, background_sliders):
            controls[f"background_center_{axis}"] = slider
        for axis, slider in zip(axes, signal_sliders):
            controls[f"signal_center_{axis}"] = slider
        controls["radius"] = slider_radius
        out = widgets.interactive_output(_update_visualization, controls)
        _output_layout(out, default_figsize)

        # Create UI
        def _labeled(text, slider):
            return widgets.VBox([widgets.Label(text), slider])

        def _row(children, justify):
            return widgets.HBox(
                children,
                layout=widgets.Layout(justify_content=justify),
            )

        background_boxes = [
            _labeled(f"{axis.upper()} Coordinate (Background)", slider)
            for axis, slider in zip(axes, background_sliders)
        ]
        signal_boxes = [
            _labeled(f"{axis.upper()} Coordinate (Signal)", slider)
            for axis, slider in zip(axes, signal_sliders)
        ]
        if is_3d:
            # One row per region: x, y and z side by side
            rows = [
                _row(background_boxes, "space-around"),
                _row(signal_boxes, "space-around"),
            ]
        else:
            # One row per axis: background next to signal
            rows = [
                _row(list(pair), "space-around")
                for pair in zip(background_boxes, signal_boxes)
            ]
        rows.append(_row([slider_radius, save_button], "center"))

        ui = widgets.VBox(rows, layout=widgets.Layout(padding="10px 60px"))
        display(out, ui)
def _check_cnr_center(center, radius, label):
    """Raise if ``center`` holds a non-integer or a coordinate smaller than ``radius``."""
    for coordinate in center:
        if not isinstance(coordinate, (int, np.integer)):
            raise TypeError(f"{label} center has to be a tuple of integers.")
        # check if center is too close to the border
        if abs(coordinate) - radius < 0:
            raise ValueError(
                f"{label} center has to be at least the radius away from the border."
            )


def contrast_to_noise_ratio(
    img,
    background_center,
    signal_center,
    radius,
    region_type="cubic",
    auto_center=False,
    iterations=5,
    **kwargs,
):
    """Calculate the contrast-to-noise ratio (CNR) for an image.

    Parameters
    ----------
    img : np.ndarray or Tensor or str or os.PathLike
        Image to calculate score of.
    background_center : Tuple(int)
        Center of the background. Order is ``(x, y)`` for 2D images and
        ``(x, y, z)`` for 3D images.
    signal_center : Tuple(int)
        Center of the signal. Order is ``(x, y)`` for 2D images and ``(x, y, z)``
        for 3D images.
    radius : int
        Width of the regions.
    region_type : {'cubic', 'spherical', 'full', 'original'}, optional
        Type of region to calculate the CNR. Default is 'cubic'. Gets passed to
        :py:func:`viqa.utils.find_largest_region` if `auto_center` is True.
        If `auto_center` is False, the following options are available:
        If 'full' or 'original' the original image is used as signal and
        background region. If 'cubic' a cubic region around the center is used.
        Alias for 'cubic' are 'cube' and 'square'.
        If 'spherical' a spherical region around the center is used.
        Alias for 'spherical' are 'sphere' and 'circle'.
    auto_center : bool, default False
        Automatically find the center of the image. `signal_center`,
        `background_center` and `radius` are ignored if True.
    iterations : int, optional
        Number of iterations for morphological operations if `auto_center` is
        True. Default is 5.
    **kwargs : optional
        Additional parameters for data loading. The keyword arguments are passed
        to :py:func:`viqa.utils.load_data`.

    Returns
    -------
    score_val : float
        CNR score value. ``0`` if the standard deviation of the background
        region is zero.

    Raises
    ------
    TypeError
        If a center is not a tuple of integers.
        If the radius is not a positive integer.
    ValueError
        If a center coordinate is smaller than the radius.
        If the centers are not in the same dimension as the image.
        If the passed region type is not valid and `auto_center` is False.

    Notes
    -----
    By default cubic regions are used to calculate the CNR. The calculation is
    based on the following formula:

    .. math::
        CNR = \\frac{\\mu_{signal} - \\mu_{background}}{\\sigma_{background}}

    where :math:`\\mu` is the mean and :math:`\\sigma` is the standard deviation.

    .. important::
        The background region should be chosen in a homogeneous area, while the
        signal region should be chosen in an area with a high contrast.

    References
    ----------
    .. [1] Desai, N., Singh, A., & Valentino, D. J. (2010). Practical evaluation of
        image quality in computed radiographic (CR) imaging systems. Medical Imaging
        2010: Physics of Medical Imaging, 7622, 76224Q.
        https://doi.org/10.1117/12.844640
    """
    img = load_data(img, **kwargs)

    # Truthiness is used here and below so both places always agree on whether
    # automatic detection is active.
    if auto_center:
        warn(
            "Signal and background center are automatically detected. Parameters "
            "signal_center, background_center and radius are ignored."
        )
        binary_foreground = _get_binary(
            img, lower_threshold=90, upper_threshold=100, show=False
        )
        binary_background = _get_binary(
            img, lower_threshold=0, upper_threshold=90, show=False
        )
        signal_center, signal_radius, signal_region = find_largest_region(
            img=binary_foreground, region_type=region_type, iterations=iterations
        )
        background_center, background_radius, background_region = find_largest_region(
            img=binary_background, region_type=region_type, iterations=iterations
        )
        radius = min(signal_radius, background_radius)

        # Mask the original image with the regions
        if region_type not in {
            "original",
            "cubic",
            "cube",
            "square",
            "spherical",
            "sphere",
            "circle",
        }:
            signal = np.ma.array(img, mask=~signal_region, copy=True)
            background = np.ma.array(img, mask=~background_region, copy=True)
        elif region_type == "original":
            signal = img
            background = img

    # Validate the radius first because the center checks do arithmetic with it.
    # NumPy integer scalars are accepted alongside Python ints.
    if not isinstance(radius, (int, np.integer)) or radius <= 0:
        raise TypeError("Radius has to be an integer.")

    # check if signal_center and background_center are tuples of integers and
    # far enough from the border
    _check_cnr_center(signal_center, radius, "Signal")
    _check_cnr_center(background_center, radius, "Background")

    # Check if img and centers have the same dimension
    if img.ndim != len(signal_center) or img.ndim != len(background_center):
        raise ValueError("Centers have to be in the same dimension as img.")

    # Define regions
    if region_type in {"full", "original"}:
        if not auto_center:
            signal = img
            background = img
    elif region_type in {"cubic", "cube", "square"}:
        signal = _to_cubic(img, signal_center, radius)
        background = _to_cubic(img, background_center, radius)
    elif region_type in {"spherical", "sphere", "circle"}:
        signal = _to_spherical(img, signal_center, radius)
        background = _to_spherical(img, background_center, radius)
    elif not auto_center:
        raise ValueError("Region type not valid.")

    # Calculate CNR; a perfectly uniform background would divide by zero
    background_std = np.std(background)
    if background_std == 0:
        cnr_val = 0
    else:
        cnr_val = (np.mean(signal) - np.mean(background)) / background_std

    return np.float64(cnr_val)