metrics

Metrics utilities for LVAE training and evaluation.

RunningPSNR #

Track the running PSNR over validation batches.

Source code in src/careamics/lvae_training/metrics.py
class RunningPSNR:
    """Track the running PSNR over validation batches."""

    def __init__(self) -> None:
        self.mse_sum: Tensor
        self.count: float
        self.max_value: float | None
        self.min_value: float | None
        self.reset()

    def reset(self) -> None:
        """Reset accumulated statistics."""
        self.mse_sum = torch.tensor(0.0)
        self.count = 0.0
        self.max_value = None
        self.min_value = None

    def update(self, rec: Tensor, tar: Tensor) -> None:
        """Update statistics with a batch of reconstructed and target images."""
        ins_max = torch.max(tar).item()
        ins_min = torch.min(tar).item()
        if self.max_value is None:
            self.max_value = ins_max
            self.min_value = ins_min
        else:
            self.max_value = max(self.max_value, ins_max)
            self.min_value = min(self.min_value, ins_min)

        mse = (rec - tar) ** 2
        elementwise_mse = torch.mean(mse.view(len(mse), -1), dim=1)
        self.mse_sum = self.mse_sum + torch.nansum(elementwise_mse)
        invalid_elements = int(torch.sum(torch.isnan(elementwise_mse)).item())
        self.count += float(len(elementwise_mse) - invalid_elements)

    def get(self) -> Tensor | None:
        """Return the current PSNR."""
        if (
            self.count == 0
            or self.max_value is None
            or self.min_value is None
            or torch.isnan(self.mse_sum)
        ):
            return None
        rmse = torch.sqrt(self.mse_sum / self.count)
        return 20 * torch.log10((self.max_value - self.min_value) / rmse)
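
A minimal usage sketch with synthetic tensors (the validation loop and data are hypothetical; the class is assumed importable from careamics.lvae_training.metrics):

import torch
from careamics.lvae_training.metrics import RunningPSNR

running_psnr = RunningPSNR()  # reset() is called by the constructor
for _ in range(4):  # hypothetical validation batches
    target = torch.rand(8, 64, 64)  # (batch, H, W) ground truth
    reconstruction = target + 0.05 * torch.randn_like(target)
    running_psnr.update(reconstruction, target)
psnr = running_psnr.get()  # running PSNR as a tensor, or None if no valid batch was seen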

get() #

Return the current PSNR.

Source code in src/careamics/lvae_training/metrics.py
def get(self) -> Tensor | None:
    """Return the current PSNR."""
    if (
        self.count == 0
        or self.max_value is None
        or self.min_value is None
        or torch.isnan(self.mse_sum)
    ):
        return None
    rmse = torch.sqrt(self.mse_sum / self.count)
    return 20 * torch.log10((self.max_value - self.min_value) / rmse)

reset() #

Reset accumulated statistics.

Source code in src/careamics/lvae_training/metrics.py
def reset(self) -> None:
    """Reset accumulated statistics."""
    self.mse_sum = torch.tensor(0.0)
    self.count = 0.0
    self.max_value = None
    self.min_value = None

update(rec, tar) #

Update statistics with a batch of reconstructed and target images.

Source code in src/careamics/lvae_training/metrics.py
def update(self, rec: Tensor, tar: Tensor) -> None:
    """Update statistics with a batch of reconstructed and target images."""
    ins_max = torch.max(tar).item()
    ins_min = torch.min(tar).item()
    if self.max_value is None:
        self.max_value = ins_max
        self.min_value = ins_min
    else:
        self.max_value = max(self.max_value, ins_max)
        self.min_value = min(self.min_value, ins_min)

    mse = (rec - tar) ** 2
    elementwise_mse = torch.mean(mse.view(len(mse), -1), dim=1)
    self.mse_sum = self.mse_sum + torch.nansum(elementwise_mse)
    invalid_elements = int(torch.sum(torch.isnan(elementwise_mse)).item())
    self.count += float(len(elementwise_mse) - invalid_elements)

PSNR(gt, pred, range_=None) #

Compute PSNR for tensors shaped as (batch, H, W).

Source code in src/careamics/lvae_training/metrics.py
@allow_numpy
def PSNR(
    gt: Tensor,
    pred: Tensor,
    range_: Tensor | None = None,
) -> Tensor:
    """Compute PSNR for tensors shaped as (batch, H, W)."""
    if len(gt.shape) != 3:
        msg = "Images must be in shape: (batch, H, W)"
        raise ValueError(msg)
    gt_flat = gt.view(len(gt), -1)
    pred_flat = pred.view(len(gt), -1)
    return _psnr_internal(gt_flat, pred_flat, range_=range_)
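
A short usage sketch with synthetic data; the exact reduction and the default handling of range_=None are delegated to the internal _psnr_internal helper, which is not shown on this page:

import torch
from careamics.lvae_training.metrics import PSNR

gt = torch.rand(8, 128, 128)  # (batch, H, W)
pred = gt + 0.1 * torch.randn_like(gt)
psnr_values = PSNR(gt, pred)  # PSNR tensor returned by _psnr_internal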

RangeInvariantPsnr(gt, pred) #

Compute range-invariant PSNR for grayscale images.

Source code in src/careamics/lvae_training/metrics.py
@allow_numpy
def RangeInvariantPsnr(gt: Tensor, pred: Tensor) -> Tensor:
    """Compute range-invariant PSNR for grayscale images."""
    if len(gt.shape) != 3:
        msg = "Images must be in shape: (batch, H, W)"
        raise ValueError(msg)
    gt_flat = gt.view(len(gt), -1)
    pred_flat = pred.view(len(gt), -1)
    range_values = torch.max(gt_flat, dim=1).values - torch.min(gt_flat, dim=1).values
    ra = range_values / torch.std(gt_flat, dim=1)
    gt_norm = zero_mean(gt_flat) / torch.std(gt_flat, dim=1, keepdim=True)
    return _psnr_internal(zero_mean(gt_norm), fix(gt_norm, pred_flat), ra)
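
Because the prediction is zero-meaned and rescaled by a per-image least-squares factor (see fix and fix_range below), the score is insensitive to a shift and scale of the prediction, up to numerical precision. A small sketch illustrating that property with synthetic data:

import torch
from careamics.lvae_training.metrics import RangeInvariantPsnr

gt = torch.rand(4, 64, 64)
pred = gt + 0.05 * torch.randn_like(gt)
print(RangeInvariantPsnr(gt, pred))
print(RangeInvariantPsnr(gt, 3.0 * pred + 10.0))  # same score up to numerical precision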

avg_psnr(target, prediction) #

Compute mean and standard error of PSNR.

Source code in src/careamics/lvae_training/metrics.py
def avg_psnr(
    target: ArrayCollection,
    prediction: ArrayCollection,
) -> ChannelStats:
    """Compute mean and standard error of PSNR."""
    return _avg_psnr(target, prediction, PSNR)

avg_range_inv_psnr(target, prediction) #

Compute mean and standard error of range-invariant PSNR.

Source code in src/careamics/lvae_training/metrics.py
def avg_range_inv_psnr(
    target: ArrayCollection,
    prediction: ArrayCollection,
) -> ChannelStats:
    """Compute mean and standard error of range-invariant PSNR."""
    return _avg_psnr(target, prediction, RangeInvariantPsnr)

avg_ssim(target, prediction) #

Compute mean and standard deviation of SSIM.

Source code in src/careamics/lvae_training/metrics.py
def avg_ssim(target: np.ndarray, prediction: np.ndarray) -> ChannelStats:
    """Compute mean and standard deviation of SSIM."""
    ssim_values = [
        structural_similarity(
            target[i],
            prediction[i],
            data_range=(target[i].max() - target[i].min()),
        )
        for i in range(len(target))
    ]
    return float(np.mean(ssim_values)), float(np.std(ssim_values))

compute_SE(arr) #

Compute the standard error of the mean.

Source code in src/careamics/lvae_training/metrics.py
def compute_SE(arr: Sequence[float]) -> float:
    """Compute the standard error of the mean."""
    return float(np.std(arr) / np.sqrt(len(arr)))

compute_custom_ssim(gt_, pred_, ssim_obj_dict) #

Compute SSIM using custom per-channel scorers.

Source code in src/careamics/lvae_training/metrics.py
def compute_custom_ssim(
    gt_: Sequence[np.ndarray],
    pred_: Sequence[np.ndarray],
    ssim_obj_dict: dict[int, MicroSSIM | MicroMS3IM],
) -> list[ChannelStats]:
    """Compute SSIM using custom per-channel scorers."""
    ms_ssim_values: dict[int, list[float]] = defaultdict(list)
    channels = gt_[0].shape[-1]
    for i in range(len(gt_)):
        for ch_idx in range(channels):
            tar_tmp = gt_[i][..., ch_idx]
            pred_tmp = pred_[i][..., ch_idx]
            ms_ssim_values[ch_idx].append(
                ssim_obj_dict[ch_idx].score(tar_tmp, pred_tmp)
            )
    return [
        (float(np.mean(ms_ssim_values[i])), compute_SE(ms_ssim_values[i]))
        for i in range(channels)
    ]
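
A hedged usage sketch, assuming two-channel, channel-last images and mirroring the fit/score calls that also appear in compute_stats below; the microssim import path and the exact input format accepted by fit are assumptions, not guarantees:

import numpy as np
from microssim import MicroSSIM  # import path assumed; adjust to your installation
from careamics.lvae_training.metrics import compute_custom_ssim

# Hypothetical channel-last stack: five (H, W, C) images with two channels.
gt_stack = [np.random.rand(64, 64, 2).astype(np.float32) for _ in range(5)]
pred_stack = [g + 0.05 * np.random.randn(*g.shape).astype(np.float32) for g in gt_stack]

# One fitted scorer per channel, as compute_custom_ssim expects.
ssim_obj_dict = {}
for ch_idx in range(2):
    scorer = MicroSSIM()
    scorer.fit(
        [g[..., ch_idx] for g in gt_stack],
        [p[..., ch_idx] for p in pred_stack],
    )
    ssim_obj_dict[ch_idx] = scorer

per_channel_stats = compute_custom_ssim(gt_stack, pred_stack, ssim_obj_dict)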

compute_masked_psnr(mask, tar1, tar2, pred1, pred2) #

Compute PSNR on masked regions for two target/prediction pairs.

Source code in src/careamics/lvae_training/metrics.py
def compute_masked_psnr(
    mask: np.ndarray,
    tar1: np.ndarray,
    tar2: np.ndarray,
    pred1: np.ndarray,
    pred2: np.ndarray,
) -> tuple[ChannelStats, ChannelStats]:
    """Compute PSNR on masked regions for two target/prediction pairs."""
    mask_bool = mask.astype(bool)[..., 0]
    tmp_tar1 = tar1[mask_bool].reshape((len(tar1), -1, 1))
    tmp_pred1 = pred1[mask_bool].reshape((len(tar1), -1, 1))
    tmp_tar2 = tar2[mask_bool].reshape((len(tar2), -1, 1))
    tmp_pred2 = pred2[mask_bool].reshape((len(tar2), -1, 1))
    psnr1 = avg_range_inv_psnr(tmp_tar1, tmp_pred1)
    psnr2 = avg_range_inv_psnr(tmp_tar2, tmp_pred2)
    return psnr1, psnr2

compute_multiscale_ssim(gt_, pred_, range_invariant=True) #

Compute channel-wise multiscale SSIM.

Source code in src/careamics/lvae_training/metrics.py
def compute_multiscale_ssim(
    gt_: np.ndarray,
    pred_: np.ndarray,
    range_invariant: bool = True,
) -> list[ChannelStats]:
    """Compute channel-wise multiscale SSIM."""
    ms_ssim_values: dict[int, list[float]] = {i: [] for i in range(gt_.shape[-1])}
    for ch_idx in range(gt_.shape[-1]):
        tar_tmp = gt_[..., ch_idx]
        pred_tmp = pred_[..., ch_idx]
        if range_invariant:
            ms_ssim_values[ch_idx] = [
                range_invariant_multiscale_ssim(
                    tar_tmp[i : i + 1],
                    pred_tmp[i : i + 1],
                )
                for i in range(tar_tmp.shape[0])
            ]
        else:
            metric = MultiScaleStructuralSimilarityIndexMeasure(
                data_range=tar_tmp.max() - tar_tmp.min()
            )
            ms_ssim_values[ch_idx] = [
                metric(
                    torch.Tensor(pred_tmp[i : i + 1, None]),
                    torch.Tensor(tar_tmp[i : i + 1, None]),
                ).item()
                for i in range(tar_tmp.shape[0])
            ]

    return [
        (float(np.mean(ms_ssim_values[i])), compute_SE(ms_ssim_values[i]))
        for i in range(gt_.shape[-1])
    ]
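
A short usage sketch with a synthetic channel-last stack; the (N, H, W, C) layout is assumed, and the images are kept large enough for the default five MS-SSIM scales:

import numpy as np
from careamics.lvae_training.metrics import compute_multiscale_ssim

gt = np.random.rand(6, 256, 256, 2).astype(np.float32)  # hypothetical (N, H, W, C) stack
pred = gt + 0.05 * np.random.randn(*gt.shape).astype(np.float32)

stats = compute_multiscale_ssim(gt, pred, range_invariant=True)
for ch_idx, (mean_val, se_val) in enumerate(stats):
    print(f"channel {ch_idx}: MS-SSIM {mean_val:.3f} +- {se_val:.3f}")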

compute_stats(highres_data, pred_unnorm, verbose=True) #

Compute PSNR- and SSIM-based metrics on high-SNR data.

Source code in src/careamics/lvae_training/metrics.py
def compute_stats(
    highres_data: Sequence[np.ndarray],
    pred_unnorm: Sequence[np.ndarray],
    verbose: bool = True,
) -> HighSNRDict:
    """Compute PSNR- and SSIM-based metrics on high-SNR data."""
    psnr_list: list[ChannelStats] = []
    microssim_list: list[ChannelStats] = []
    ms3im_list: list[ChannelStats] = []
    ssim_list: list[ChannelStats] = []
    msssim_list: list[ChannelStats] = []

    channel_count = highres_data[0].shape[-1]
    for ch_idx in range(channel_count):
        gt_ch, pred_ch = _get_list_of_images_from_gt_pred(
            list(highres_data),
            list(pred_unnorm),
            ch_idx,
        )
        psnr_list.append(avg_range_inv_psnr(gt_ch, pred_ch))

        microssim_obj = MicroSSIM()
        microssim_obj.fit(gt_ch, pred_ch)
        mssim_scores = [
            microssim_obj.score(gt_ch[i], pred_ch[i]) for i in range(len(gt_ch))
        ]
        microssim_list.append((float(np.mean(mssim_scores)), compute_SE(mssim_scores)))

        m3sim_obj = MicroMS3IM()
        m3sim_obj.fit(gt_ch, pred_ch)
        ms3im_scores = [
            m3sim_obj.score(gt_ch[i], pred_ch[i]) for i in range(len(gt_ch))
        ]
        ms3im_list.append((float(np.mean(ms3im_scores)), compute_SE(ms3im_scores)))

        ssim_scores = [
            structural_similarity(
                gt_ch[i],
                pred_ch[i],
                data_range=gt_ch[i].max() - gt_ch[i].min(),
            )
            for i in range(len(gt_ch))
        ]
        ssim_list.append((float(np.mean(ssim_scores)), compute_SE(ssim_scores)))

        ms_ssim_scores = []
        for i in range(len(gt_ch)):
            metric = MultiScaleStructuralSimilarityIndexMeasure(
                data_range=gt_ch[i].max() - gt_ch[i].min()
            )
            ms_ssim_scores.append(
                metric(
                    torch.Tensor(pred_ch[i][None, None]),
                    torch.Tensor(gt_ch[i][None, None]),
                ).item()
            )
        msssim_list.append((float(np.mean(ms_ssim_scores)), compute_SE(ms_ssim_scores)))

    if verbose:

        def ssim_str(values: ChannelStats) -> str:
            return f"{np.round(values[0], 3):.3f}+-{np.round(values[1], 3):.3f}"

        def psnr_str(values: ChannelStats) -> str:
            return f"{np.round(values[0], 2)}+-{np.round(values[1], 3)}"

        print(
            "PSNR on Highres",
            "\t".join(psnr_str(value) for value in psnr_list),
        )
        print(
            "MicroSSIM on Highres",
            "\t".join(ssim_str(value) for value in microssim_list),
        )
        print(
            "MicroS3IM on Highres",
            "\t".join(ssim_str(value) for value in ms3im_list),
        )
        print(
            "SSIM on Highres",
            "\t".join(ssim_str(value) for value in ssim_list),
        )
        print(
            "MSSSIM on Highres",
            "\t".join(ssim_str(value) for value in msssim_list),
        )

    return {
        "rangeinvpsnr": psnr_list,
        "microssim": microssim_list,
        "ms3im": ms3im_list,
        "ssim": ssim_list,
        "msssim": msssim_list,
    }
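
A hedged end-to-end sketch with synthetic two-channel data; real callers would pass unnormalised high-SNR acquisitions and the corresponding model predictions:

import numpy as np
from careamics.lvae_training.metrics import compute_stats

highres_data = [np.random.rand(256, 256, 2).astype(np.float32) for _ in range(4)]
pred_unnorm = [
    img + 0.05 * np.random.randn(*img.shape).astype(np.float32)
    for img in highres_data
]

stats = compute_stats(highres_data, pred_unnorm, verbose=True)
print(stats["rangeinvpsnr"])  # one (mean, standard error) pair per channel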

fix(gt, x) #

Zero-mean tensors and match prediction range to the ground truth.

Source code in src/careamics/lvae_training/metrics.py
def fix(gt: Tensor, x: Tensor) -> Tensor:
    """Zero-mean tensors and match prediction range to the ground truth."""
    gt_ = zero_mean(gt)
    return fix_range(gt_, zero_mean(x))

fix_range(gt, x) #

Rescale a tensor to match the range of the ground truth.

Source code in src/careamics/lvae_training/metrics.py
def fix_range(gt: Tensor, x: Tensor) -> Tensor:
    """Rescale a tensor to match the range of the ground truth."""
    denom = torch.sum(x * x, dim=1, keepdim=True)
    a = torch.sum(gt * x, dim=1, keepdim=True) / denom
    return x * a
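
The scale factor a is the closed-form least-squares solution: for each row, a = <gt, x> / <x, x> minimises ||gt - a * x||^2. A small numerical sketch with zero-mean inputs, as fix would provide:

import torch
from careamics.lvae_training.metrics import fix_range, zero_mean

gt = zero_mean(torch.randn(2, 1000))
x = 5.0 * gt + 0.1 * torch.randn_like(gt)  # prediction off by a scale factor
rescaled = fix_range(gt, x)
print(torch.mean((gt - x) ** 2, dim=1))         # large error before rescaling
print(torch.mean((gt - rescaled) ** 2, dim=1))  # close to the noise floor afterwards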

range_invariant_multiscale_ssim(gt_, pred_) #

Compute range-invariant multiscale SSIM for one channel.

Source code in src/careamics/lvae_training/metrics.py
@allow_numpy
def range_invariant_multiscale_ssim(
    gt_: ArrayBatch,
    pred_: ArrayBatch,
) -> float:
    """Compute range-invariant multiscale SSIM for one channel."""
    shape = gt_.shape
    gt_tensor = torch.as_tensor(gt_.reshape((shape[0], -1)))
    pred_tensor = torch.as_tensor(pred_.reshape((shape[0], -1)))
    gt_tensor = zero_mean(gt_tensor)
    pred_tensor = zero_mean(pred_tensor)
    pred_tensor = fix(gt_tensor, pred_tensor)
    pred_tensor = pred_tensor.reshape(shape)
    gt_tensor = gt_tensor.reshape(shape)
    ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(
        data_range=gt_tensor.max() - gt_tensor.min()
    )
    return ms_ssim(
        torch.as_tensor(pred_tensor[:, None]),
        torch.as_tensor(gt_tensor[:, None]),
    ).item()

zero_mean(x) #

Return a zero-mean tensor along the channel dimension.

Source code in src/careamics/lvae_training/metrics.py
def zero_mean(x: Tensor) -> Tensor:
    """Return a zero-mean tensor along the channel dimension."""
    return x - torch.mean(x, dim=1, keepdim=True)