# Source code for viqa.utils.visualization

"""Module for visualization functions.

Examples
--------
    .. doctest-skip::

    >>> import numpy as np
    >>> from viqa.utils.visualization import visualize_2d, visualize_3d
    >>> img = np.random.rand(100, 100)
    >>> visualize_2d(img)
    >>> img = np.random.rand(100, 100, 100)
    >>> visualize_3d(img, (50, 50, 50))
"""

# Authors
# -------
# Author: Lukas Behammer
# Research Center Wels
# University of Applied Sciences Upper Austria, 2023
# CT Research Group
#
# Modifications
# -------------
# Original code, 2024, Lukas Behammer
#
# License
# -------
# BSD-3-Clause License

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np

from viqa.utils._module import try_import

widgets, has_ipywidgets = try_import("ipywidgets")

FIGSIZE_CNR_2D = (10, 5.5)
FIGSIZE_CNR_3D = (10, 8)
FIGSIZE_SNR_2D = (7, 7)
FIGSIZE_SNR_3D = (10, 4)


def _visualize_cnr_2d(
    img,
    signal_center,
    background_center,
    radius,
    region_type,
    export_path=None,
    show=True,
    **kwargs,
):
    """Plot the background and signal regions of a 2D CNR calculation.

    Two panels titled ``"Background"`` and ``"Signal"`` are drawn side by side.
    Both display ``img[..., ::-1]`` and, if the region type has an outline, a
    square or circle around the respective center.

    Parameters
    ----------
    img : np.ndarray
        The image to display.
    signal_center : sequence
        Center of the signal region. Elements 0 and 1 are used as the x and y
        plot coordinates of the outline.
    background_center : sequence
        Center of the background region. Elements 0 and 1 are used as the x
        and y plot coordinates of the outline.
    radius : int or float
        Radius of the circle, or half the side length of the square.
    region_type : str
        ``"cubic"``, ``"cube"`` or ``"square"`` draw a square; ``"spherical"``,
        ``"sphere"`` or ``"circle"`` draw a circle. Any other value draws no
        outline.
    export_path : str or Path, optional
        If given, the figure is saved to this path.
    show : bool, default True
        If True, ``matplotlib.pyplot.show`` is called.
    kwargs :
        ``figsize`` (default ``FIGSIZE_CNR_2D``) and ``dpi`` (default 300) are
        used to create the figure. All other keyword arguments are ignored.

    Returns
    -------
    None
    """
    figsize = kwargs.pop("figsize", FIGSIZE_CNR_2D)
    dpi = kwargs.pop("dpi", 300)

    def _region_patch(center, edgecolor):
        # A patch can belong to only one Axes, so a new one is built per panel.
        if region_type in ("cubic", "cube", "square"):
            return patches.Rectangle(
                (center[0] - radius, center[1] - radius),
                radius * 2,
                radius * 2,
                linewidth=1,
                edgecolor=edgecolor,
                facecolor="none",
            )
        if region_type in ("spherical", "sphere", "circle"):
            return patches.Circle(
                (center[0], center[1]),
                radius,
                linewidth=1,
                edgecolor=edgecolor,
                facecolor="none",
            )
        # Region types without an outline: draw nothing instead of failing on
        # an unbound variable (same fallback as the SNR helpers).
        return None

    fig, axs = plt.subplots(1, 2, figsize=figsize, dpi=dpi)
    fig.suptitle("Regions for CNR Calculation", y=0.92)

    panels = (
        ("Background", background_center, "#ca0020"),
        ("Signal", signal_center, "#0571b0"),
    )
    for ax, (title, center, edgecolor) in zip(axs, panels):
        # Both panels must use the same orientation so that the two outlines
        # refer to the same picture.
        # NOTE(review): the background panel previously showed the unflipped
        # image; confirm that the flipped orientation is the intended one.
        ax.imshow(img[..., ::-1], cmap="gray")
        ax.set_title(title)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.invert_yaxis()

        outline = _region_patch(center, edgecolor)
        if outline is not None:
            ax.add_patch(outline)

    # Save before showing: once ``plt.show()`` returns the figure may already
    # have been released, and a later ``savefig`` would write an empty canvas.
    if export_path:
        fig.savefig(export_path, bbox_inches="tight", pad_inches=0.5)
    if show:
        plt.show()


def _visualize_cnr_3d(
    img,
    signal_center,
    background_center,
    radius,
    region_type,
    export_path=None,
    show=True,
    **kwargs,
):
    """Plot the background and signal regions of a 3D CNR calculation.

    A 2x3 grid is drawn. The upper row shows the three orthogonal slices
    through ``background_center``, the lower row the three slices through
    ``signal_center``. Each view has dashed guide lines at the center
    coordinates and, if the region type has an outline, a square or circle
    around the center.

    Parameters
    ----------
    img : np.ndarray
        The volume to display; indexed with three indices.
    signal_center : sequence of int
        Center of the signal region, used as slice indices along the three
        axes.
    background_center : sequence of int
        Center of the background region, used as slice indices along the
        three axes.
    radius : int or float
        Radius of the circle, or half the side length of the square.
    region_type : str
        ``"cubic"``, ``"cube"`` or ``"square"`` draw a square; ``"spherical"``,
        ``"sphere"`` or ``"circle"`` draw a circle. Any other value draws no
        outline.
    export_path : str or Path, optional
        If given, the figure is saved to this path.
    show : bool, default True
        If True, ``matplotlib.pyplot.show`` is called.
    kwargs :
        ``figsize`` (default ``FIGSIZE_CNR_3D``) and ``dpi`` (default 300) are
        used to create the figure. All other keyword arguments are passed to
        ``matplotlib.pyplot.subplots``.

    Returns
    -------
    None
    """
    figsize = kwargs.pop("figsize", FIGSIZE_CNR_3D)
    dpi = kwargs.pop("dpi", 300)

    axis_names = ("x", "y", "z")
    # One colour per axis, shared by the view titles and the guide lines.
    axis_colors = ("#d7191c", "#fdae61", "#2c7bb6")

    def _plane(center, normal):
        # 2D slice through ``center`` perpendicular to axis ``normal``.
        if normal == 0:
            return np.rot90(img[center[0], :, ::-1])
        if normal == 1:
            return np.rot90(img[:, center[1], ::-1])
        return np.rot90(img[::-1, :, center[2]], -1)

    def _region_patch(center_2d, edgecolor):
        # A patch can belong to only one Axes, so a new one is built per view.
        if region_type in ("cubic", "cube", "square"):
            return patches.Rectangle(
                (center_2d[0] - radius, center_2d[1] - radius),
                radius * 2,
                radius * 2,
                linewidth=1,
                edgecolor=edgecolor,
                facecolor="none",
            )
        if region_type in ("spherical", "sphere", "circle"):
            return patches.Circle(
                (center_2d[0], center_2d[1]),
                radius,
                linewidth=1,
                edgecolor=edgecolor,
                facecolor="none",
            )
        # Region types without an outline: draw nothing instead of failing on
        # an unbound variable (same fallback as the SNR helpers).
        return None

    def _draw_row(row_axes, center, outline_color):
        # Draw the x-, y- and z-normal views for one region center.
        for normal, ax in enumerate(row_axes):
            # The two remaining axes span the view: the lower index is plotted
            # horizontally, the higher index vertically.
            horiz, vert = (i for i in range(3) if i != normal)

            ax.imshow(_plane(center, normal), cmap="gray")
            ax.set_title(
                f"{axis_names[normal]}-axis, slice: {center[normal]}",
                c=axis_colors[normal],
            )
            ax.set_xlabel(axis_names[horiz])
            ax.set_ylabel(axis_names[vert])
            ax.invert_yaxis()
            ax.axvline(x=center[horiz], color=axis_colors[horiz], linestyle="--")
            ax.axhline(y=center[vert], color=axis_colors[vert], linestyle="--")

            outline = _region_patch((center[horiz], center[vert]), outline_color)
            if outline is not None:
                ax.add_patch(outline)

    fig, axs = plt.subplots(2, 3, figsize=figsize, dpi=dpi, **kwargs)
    fig.suptitle(
        "Background (Upper) and Signal Region (Lower) for CNR Calculation", y=0.92
    )
    _draw_row(axs[0], background_center, "#ffffbf")
    _draw_row(axs[1], signal_center, "#abd9e9")

    # Save before showing: once ``plt.show()`` returns the figure may already
    # have been released, and a later ``savefig`` would write an empty canvas.
    if export_path:
        fig.savefig(export_path, bbox_inches="tight", pad_inches=0.5)
    if show:
        plt.show()


def _visualize_snr_2d(
    img,
    signal_center,
    radius,
    region_type,
    signal=None,
    export_path=None,
    show=True,
    **kwargs,
):
    """Plot the signal region of a 2D SNR calculation.

    The image is displayed as ``img[..., ::-1]``. Depending on ``region_type``
    a square or circle is drawn around ``signal_center``, or ``signal`` is
    overlaid semi-transparently, or nothing is added.

    Parameters
    ----------
    img : np.ndarray
        The image to display.
    signal_center : sequence
        Center of the signal region. Elements 0 and 1 are used as the x and y
        plot coordinates of the outline.
    radius : int or float
        Radius of the circle, or half the side length of the square.
    region_type : str
        ``"cubic"``, ``"cube"`` or ``"square"`` draw a square; ``"spherical"``,
        ``"sphere"`` or ``"circle"`` draw a circle; ``"full"`` overlays
        ``signal``; ``"original"`` shows the image only. ``"full"`` without a
        ``signal`` is treated as ``"original"``.
    signal : np.ndarray, optional
        Array overlaid on the image for ``region_type="full"``. Must have the
        same shape as ``img``.
    export_path : str or Path, optional
        If given, the figure is saved to this path.
    show : bool, default True
        If True, ``matplotlib.pyplot.show`` is called.
    kwargs :
        ``figsize`` (default ``FIGSIZE_SNR_2D``) and ``dpi`` (default 300) are
        used to create the figure. All other keyword arguments are ignored.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``region_type`` is not one of the supported values, or if
        ``region_type`` is ``"full"`` and ``signal`` does not have the same
        shape as ``img``.
    """
    valid_region_types = (
        "cubic",
        "cube",
        "square",
        "spherical",
        "sphere",
        "circle",
        "full",
        "original",
    )
    if region_type not in valid_region_types:
        raise ValueError("Invalid region type.")
    if region_type == "full":
        if signal is None:
            # Nothing to overlay; fall back to showing the plain image.
            region_type = "original"
        elif signal.shape != img.shape:
            raise ValueError("The signal must have the same shape as the image.")

    figsize = kwargs.pop("figsize", FIGSIZE_SNR_2D)
    dpi = kwargs.pop("dpi", 300)

    fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
    fig.suptitle("Signal Region for SNR Calculation", y=0.92)

    ax.imshow(img[..., ::-1], cmap="gray")

    if region_type == "full":
        # The overlay needs the same flip as the image, otherwise it is shown
        # mirrored relative to the pixels it belongs to.
        ax.imshow(signal[..., ::-1], cmap="viridis", alpha=0.5)

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.invert_yaxis()

    if region_type in ("cubic", "cube", "square"):
        outline = patches.Rectangle(
            (signal_center[0] - radius, signal_center[1] - radius),
            radius * 2,
            radius * 2,
            linewidth=1,
            edgecolor="#0571b0",
            facecolor="none",
        )
    elif region_type in ("spherical", "sphere", "circle"):
        outline = patches.Circle(
            (signal_center[0], signal_center[1]),
            radius,
            linewidth=1,
            edgecolor="#0571b0",
            facecolor="none",
        )
    else:
        outline = None

    if outline is not None:
        ax.add_patch(outline)

    # Save before showing: once ``plt.show()`` returns the figure may already
    # have been released, and a later ``savefig`` would write an empty canvas.
    if export_path:
        fig.savefig(export_path, bbox_inches="tight", pad_inches=0.5)
    if show:
        plt.show()


def _visualize_snr_3d(
    img,
    signal_center,
    radius,
    region_type,
    signal=None,
    export_path=None,
    show=True,
    **kwargs,
):
    """Plot the signal region of a 3D SNR calculation.

    The three orthogonal slices through ``signal_center`` are drawn side by
    side, each with dashed guide lines at the center coordinates. Depending on
    ``region_type`` a square or circle is drawn around the center, or the
    matching slice of ``signal`` is overlaid semi-transparently, or nothing is
    added.

    Parameters
    ----------
    img : np.ndarray
        The volume to display; indexed with three indices.
    signal_center : sequence of int
        Center of the signal region, used as slice indices along the three
        axes.
    radius : int or float
        Radius of the circle, or half the side length of the square.
    region_type : str
        ``"cubic"``, ``"cube"`` or ``"square"`` draw a square; ``"spherical"``,
        ``"sphere"`` or ``"circle"`` draw a circle; ``"full"`` overlays
        ``signal``; ``"original"`` shows the slices only. ``"full"`` without a
        ``signal`` is treated as ``"original"``.
    signal : np.ndarray, optional
        Array overlaid on the slices for ``region_type="full"``. Must have the
        same shape as ``img``.
    export_path : str or Path, optional
        If given, the figure is saved to this path.
    show : bool, default True
        If True, ``matplotlib.pyplot.show`` is called.
    kwargs :
        ``figsize`` (default ``FIGSIZE_SNR_3D``) and ``dpi`` (default 300) are
        used to create the figure. All other keyword arguments are ignored.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``region_type`` is not one of the supported values, or if
        ``region_type`` is ``"full"`` and ``signal`` does not have the same
        shape as ``img``.
    """
    valid_region_types = (
        "cubic",
        "cube",
        "square",
        "spherical",
        "sphere",
        "circle",
        "full",
        "original",
    )
    if region_type not in valid_region_types:
        raise ValueError("Invalid region type.")
    if region_type == "full":
        if signal is None:
            # Nothing to overlay; fall back to showing the plain slices.
            region_type = "original"
        elif signal.shape != img.shape:
            raise ValueError("The signal must have the same shape as the image.")

    figsize = kwargs.pop("figsize", FIGSIZE_SNR_3D)
    dpi = kwargs.pop("dpi", 300)

    axis_names = ("x", "y", "z")
    # One colour per axis, shared by the view titles and the guide lines.
    axis_colors = ("#d7191c", "#fdae61", "#2c7bb6")

    def _plane(volume, normal):
        # 2D slice of ``volume`` through ``signal_center`` perpendicular to
        # axis ``normal``. Image and overlay go through the same transform so
        # that they stay aligned.
        if normal == 0:
            return np.rot90(volume[signal_center[0], :, ::-1])
        if normal == 1:
            return np.rot90(volume[:, signal_center[1], ::-1])
        return np.rot90(volume[::-1, :, signal_center[2]], -1)

    def _region_patch(center_2d):
        # A patch can belong to only one Axes, so a new one is built per view.
        if region_type in ("cubic", "cube", "square"):
            return patches.Rectangle(
                (center_2d[0] - radius, center_2d[1] - radius),
                radius * 2,
                radius * 2,
                linewidth=1,
                edgecolor="#ffffbf",
                facecolor="none",
            )
        if region_type in ("spherical", "sphere", "circle"):
            return patches.Circle(
                (center_2d[0], center_2d[1]),
                radius,
                linewidth=1,
                edgecolor="#ffffbf",
                facecolor="none",
            )
        return None

    fig, axs = plt.subplots(1, 3, figsize=figsize, dpi=dpi)
    fig.suptitle("Signal Region for SNR Calculation", y=0.92)

    for normal, ax in enumerate(axs):
        # The two remaining axes span the view: the lower index is plotted
        # horizontally, the higher index vertically.
        horiz, vert = (i for i in range(3) if i != normal)

        ax.imshow(_plane(img, normal), cmap="gray")
        if region_type == "full":
            ax.imshow(_plane(signal, normal), cmap="viridis", alpha=0.5)

        ax.set_title(
            f"{axis_names[normal]}-axis, slice: {signal_center[normal]}",
            c=axis_colors[normal],
        )
        ax.set_xlabel(axis_names[horiz])
        ax.set_ylabel(axis_names[vert])
        ax.invert_yaxis()
        ax.axvline(x=signal_center[horiz], color=axis_colors[horiz], linestyle="--")
        ax.axhline(y=signal_center[vert], color=axis_colors[vert], linestyle="--")

        outline = _region_patch((signal_center[horiz], signal_center[vert]))
        if outline is not None:
            ax.add_patch(outline)

    # Save before showing: once ``plt.show()`` returns the figure may already
    # have been released, and a later ``savefig`` would write an empty canvas.
    if export_path:
        fig.savefig(export_path, bbox_inches="tight", pad_inches=0.5)
    if show:
        plt.show()


def _create_slider_widget(**kwargs):
    """Create an ``ipywidgets.IntSlider``.

    Parameters
    ----------
    kwargs :
        Keyword arguments for ``IntSlider``. ``min`` defaults to 0, ``step``
        to 1 and ``continuous_update`` to False when not given.

    Returns
    -------
    The created ``IntSlider`` instance.

    Raises
    ------
    ImportError
        If ipywidgets is not installed.
    """
    if has_ipywidgets:
        # Start from the defaults and let caller-supplied values override them.
        slider_options = {"min": 0, "step": 1, "continuous_update": False}
        slider_options.update(kwargs)
        return widgets.IntSlider(**slider_options)

    raise ImportError(
        "ipywidgets is not installed. Please install it to use this function."
    )


def visualize_2d(img, export_path=None, **kwargs):
    """
    Visualize a 2D image.

    The function visualizes a 2D image. If ``export_path`` is provided, the
    visualization is saved to the specified path.

    Parameters
    ----------
    img : np.ndarray
        The 2D image to visualize.
    export_path : str or Path, optional
        The path to save the visualization.
    kwargs :
        Additional keyword arguments for the plot. ``figsize`` (default
        ``(6, 6)``) and ``dpi`` (default 300) are used to create the figure;
        everything else is passed to ``matplotlib.pyplot.imshow``. ``cmap``
        defaults to ``"gray"``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the image is not 2D.
    """
    if img.ndim != 2:
        raise ValueError("The image must be 2D.")

    figsize = kwargs.pop("figsize", (6, 6))
    dpi = kwargs.pop("dpi", 300)
    kwargs.setdefault("cmap", "gray")

    fig = plt.figure(figsize=figsize, dpi=dpi)
    plt.imshow(img, **kwargs)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.gca().invert_yaxis()

    # Save before showing: once ``plt.show()`` returns the figure may already
    # have been released, and a later ``savefig`` would write an empty canvas.
    if export_path:
        fig.savefig(export_path, bbox_inches="tight", pad_inches=0.5)
    plt.show()


def visualize_3d(img, slices, export_path=None, **kwargs):
    """
    Visualize 3D image slices in 3 different planes.

    The function visualizes the 3D image slices in the ``x``, ``y`` and ``z``
    direction. If ``export_path`` is provided, the visualization is saved to
    the specified path.

    Parameters
    ----------
    img : np.ndarray
        The 3D image to visualize.
    slices : tuple
        The slices to visualize in the ``x``, ``y`` and ``z`` direction. The
        slices must be positive or negative integers (Python or NumPy
        integers). Negative values count from the end of the axis.
    export_path : str or Path, optional
        The path to save the visualization.
    kwargs :
        Additional keyword arguments for the plot. ``figsize`` (default
        ``(14, 6)``) and ``dpi`` (default 300) are used to create the figure;
        everything else is passed to :py:func:`matplotlib.pyplot.subplots`.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the number of slices is not 3 or if the slices are not integers.
        If the image is not 3D.
        If the slices are out of bounds.
    """
    if len(slices) != 3:
        raise ValueError("The number of slices must be 3.")
    if not all(isinstance(slice_, (int, np.integer)) for slice_ in slices):
        raise ValueError("All slices must be integers.")
    if img.ndim != 3:
        raise ValueError("The image must be 3D.")
    # Valid indices along an axis of length n are -n .. n-1; n itself is
    # already past the end.
    if not all(
        -img.shape[i] <= slice_ < img.shape[i] for i, slice_ in enumerate(slices)
    ):
        raise ValueError("The slices are out of bounds.")

    # Guide lines are drawn in plot coordinates, which start at 0, so negative
    # indices are mapped to their non-negative equivalent.
    line_pos = tuple(slice_ % img.shape[i] for i, slice_ in enumerate(slices))

    axis_names = ("x", "y", "z")
    # One colour per axis, shared by the view titles and the guide lines.
    axis_colors = ("#1b9e77", "#d95f02", "#7570b3")

    x, y, z = slices
    planes = (
        np.rot90(img[x, :, ::-1]),
        np.rot90(img[:, y, ::-1]),
        np.rot90(img[::-1, :, z], -1),
    )

    figsize = kwargs.pop("figsize", (14, 6))
    dpi = kwargs.pop("dpi", 300)

    fig, axs = plt.subplots(1, 3, figsize=figsize, dpi=dpi, **kwargs)

    for normal, (ax, plane) in enumerate(zip(axs, planes)):
        # The two remaining axes span the view: the lower index is plotted
        # horizontally, the higher index vertically.
        horiz, vert = (i for i in range(3) if i != normal)

        ax.imshow(plane, cmap="gray")
        ax.set_xlabel(axis_names[horiz])
        ax.set_ylabel(axis_names[vert])
        ax.invert_yaxis()
        ax.axhline(y=line_pos[vert], color=axis_colors[vert], linestyle="--")
        ax.axvline(x=line_pos[horiz], color=axis_colors[horiz], linestyle="--")
        ax.set_title(
            f"{axis_names[normal]}-axis, slice: {slices[normal]}",
            c=axis_colors[normal],
        )

    # Save before showing: once ``plt.show()`` returns the figure may already
    # have been released, and a later ``savefig`` would write an empty canvas.
    if export_path:
        fig.savefig(export_path, bbox_inches="tight", pad_inches=0.5)
    plt.show()