
psnr

PSNR metrics compatible with torchmetrics.

SIPSNR #

Bases: Metric

Scale Invariant PSNR metric using a global data range.

By default, the metric is averaged over channels, but it can also be computed for a specific channel by setting output_channel to the desired channel index.

Adapted from juglab/ScaleInvPSNR, this version of PSNR rescales the predictions and ground truth to have a similar range, then computes the PSNR using a global data range accumulated over all batches. For a scale-invariant version of PSNR with a per-sample data range, see SampleSIPSNR.

Scale invariance can be turned off with use_scale_invariance=False, in which case the metric is equivalent to torchmetrics.image.PeakSignalNoiseRatio with data_range equal to the difference between the global max and min over all batches.

Note that, unlike torchmetrics.image.PeakSignalNoiseRatio, this implementation supports 3D data and can be computed on a single channel.
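
The scale-invariance trick can be sketched with the original juglab formulation: the ground truth is zero-meaned and divided by its standard deviation, and the prediction is fitted to it by least squares before the PSNR is taken. Below is a minimal numpy sketch; the helper names `_zero_mean` and `_fix` are illustrative and are not the module's actual helpers.

```python
import numpy as np

def _zero_mean(x):
    return x - x.mean()

def _fix(gt, pred):
    # least-squares scaling of the (zero-meaned) prediction onto the ground truth
    pred = _zero_mean(pred)
    a = (gt * pred).sum() / (pred * pred).sum()
    return a * pred

def scale_invariant_psnr(gt, pred):
    # the data range is normalised by the ground-truth std, which cancels
    # with the same scaling applied directly to gt below
    data_range = (gt.max() - gt.min()) / gt.std()
    gt_ = _zero_mean(gt) / gt.std()
    pred_ = _fix(gt_, pred)
    mse = np.mean((gt_ - pred_) ** 2)
    return 10 * np.log10(data_range**2 / mse)
```

Because the prediction is zero-meaned and refitted by least squares, any affine transform `a * pred + b` of the prediction leaves the score unchanged, which is the "scale invariant" property.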

Parameters:

    n_channels (int, required)
        Number of channels in the input images.
    output_channel (int, default: -1)
        Channel to compute the metric on. If -1, the metric is computed on all channels.
    use_scale_invariance (bool, default: True)
        Whether to use scale invariance. If False, the metric is equivalent to PSNR with global data range.
    **kwargs (Any, default: {})
        Additional keyword arguments passed to the parent Metric class.

Attributes:

    glob_max (Tensor)
        Global maximum values for each channel.
    glob_min (Tensor)
        Global minimum values for each channel.
    mse_log (Tensor)
        Logarithm of the mean squared error, summed over batches.
    total (Tensor)
        Total number of samples processed.
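
The state accumulation performed by update and compute can be sketched for a single channel in plain numpy. This is illustrative only: it omits scale invariance, multi-channel handling, the eps terms, and distributed reduction.

```python
import numpy as np

class GlobalRangePSNR:
    """Single-channel sketch of SIPSNR's state accumulation (no scale invariance)."""

    def __init__(self):
        self.glob_max, self.glob_min = -np.inf, np.inf
        self.mse_log, self.total = 0.0, 0

    def update(self, preds, target):
        # track the global data range over all batches seen so far
        self.glob_max = max(self.glob_max, target.max())
        self.glob_min = min(self.glob_min, target.min())
        # one MSE per sample, accumulated in log space
        mse = ((preds - target) ** 2).mean(axis=tuple(range(1, target.ndim)))
        self.mse_log += np.log10(mse).sum()
        self.total += target.shape[0]

    def compute(self):
        data_range = self.glob_max - self.glob_min
        # equal to the mean over samples of 10 * log10(data_range**2 / mse_i)
        return 10 * (np.log10(data_range**2) - self.mse_log / self.total)
```

Because the range is global and the per-sample log-MSE terms are simply summed, feeding the same data in one batch or split across several batches yields the same result.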

Source code in src/careamics/lightning/dataset_ng/metrics/psnr.py
class SIPSNR(Metric):
    """Scale Invariant PSNR metric using a global data range.

    By default, the metric is averaged over channels, but it can also be computed for a
    specific channel by setting `output_channel` to the desired channel index.

    Adapted from juglab/ScaleInvPSNR, this version of PSNR rescales the predictions and
    ground truth to have similar range, then computes the PSNR using a global data range
    accumulated over all batches. For a scale-invariant version of PSNR with per-sample
    data range, see `SampleSIPSNR`.

    Scale invariance can be turned off using `use_scale_invariance=False`, in which case
    the metric is equivalent to `torchmetrics.image.PeakSignalNoiseRatio`, with
    `data_range` equal to the difference between the global max and min over all
    batches.

    Note that as opposed to `torchmetrics.image.PeakSignalNoiseRatio`, this
    implementation is compatible with 3D and can be computed on a single channel.

    Parameters
    ----------
    n_channels : int
        Number of channels in the input images.
    output_channel : int, default=-1
        Channel to compute the metric on. If -1, the metric is computed on all channels.
    use_scale_invariance : bool, default=True
        Whether to use scale invariance. If False, the metric is equivalent to PSNR with
        global data range.
    **kwargs : Any
        Additional keyword arguments passed to the parent Metric class.

    Attributes
    ----------
    glob_max : Tensor
        Global maximum values for each channel.
    glob_min : Tensor
        Global minimum values for each channel.
    mse_log : Tensor
        Logarithm of the mean squared error summed over batches.
    total : Tensor
        Total number of samples processed.
    """

    is_differentiable: bool | None = True
    higher_is_better: bool | None = True
    full_state_update: bool = True

    def __init__(
        self,
        n_channels: int,
        output_channel: int = -1,
        use_scale_invariance: bool = True,
        **kwargs: Any,
    ):
        """Initialize a global scale invariant PSNR metric.

        Parameters
        ----------
        n_channels : int
            Number of channels in the input images.
        output_channel : int, default=-1
            Channel to compute the metric on. If -1, the metric is computed on all
            channels.
        use_scale_invariance : bool, default=True
            Whether to use scale invariance. If False, the metric is equivalent to PSNR
            with global data range.
        **kwargs : Any
            Additional keyword arguments passed to the parent Metric class.
        """
        super().__init__(**kwargs)

        self.eps = torch.finfo(torch.float32).eps
        self.output_channel = output_channel
        self.use_scale_invariance = use_scale_invariance

        if self.output_channel != -1 and (
            self.output_channel < 0 or self.output_channel >= n_channels
        ):
            raise ValueError(
                f"Invalid `output_channel` value ({self.output_channel}), must be equal"
                f" to -1 to compute an average over all channels, or between 0 and "
                f"{n_channels - 1} to compute the metric on a single channel."
            )

        self.add_state(
            "glob_max",
            default=tensor([float("-inf") for _ in range(n_channels)]),
            dist_reduce_fx="max",
        )
        self.add_state(
            "glob_min",
            default=tensor([float("inf") for _ in range(n_channels)]),
            dist_reduce_fx="min",
        )
        self.add_state(
            "mse_log",
            default=tensor([0.0 for _ in range(n_channels)]),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "total",
            default=tensor([0.0 for _ in range(n_channels)]),
            dist_reduce_fx="sum",
        )

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update the metric states with values computed from a new batch.

        Parameters
        ----------
        preds : Tensor
            Predicted images tensor of shape (B, C, (Z), Y, X).
        target : Tensor
            Ground truth images tensor of shape (B, C, (Z), Y, X).
        """
        batch_size = target.shape[0]
        shape = target.shape
        dims = tuple(range(2, len(shape)))

        # compute min/max of the batches and channels
        batch_min = torch.amin(target, dim=(0,) + dims)
        batch_max = torch.amax(target, dim=(0,) + dims)
        # implementation note: in the original function (`scale_invariant_psnr`), the
        # `data_range` is divided by `np.std(gt)`. This mathematically cancels out with
        # the scaling applied directly to `gt` (`_zero_mean(gt) / np.std(gt)`). Here,
        # we compute a global data range but still consider that the same scaling
        # factor is applied to the data range and `gt` (either the sample std, or
        # a global one), so that they cancel out.

        # fix range of gt and prediction
        if self.use_scale_invariance:
            tar_rescaled, pred_rescaled = _normalise_range(target, preds)
        else:
            tar_rescaled, pred_rescaled = target, preds

        # compute mse
        mse = torch.mean((tar_rescaled - pred_rescaled) ** 2 + self.eps, dim=dims)

        # update states
        self.glob_max: torch.Tensor = torch.maximum(self.glob_max, batch_max)
        self.glob_min: torch.Tensor = torch.minimum(self.glob_min, batch_min)
        self.mse_log: torch.Tensor = self.mse_log + torch.log10(mse).sum(dim=0)
        self.total: torch.Tensor = self.total + batch_size

    def compute(self) -> Tensor:
        """Compute the final metric value.

        Returns
        -------
        torch.Tensor
            Tensor of length C containing the computed PSNR for each channel.
        """
        glob_data_range = self.glob_max - self.glob_min + self.eps
        psnr = 10 * (torch.log10(glob_data_range**2) - self.mse_log / self.total)

        if self.output_channel == -1:
            return torch.mean(psnr)
        else:
            return psnr[self.output_channel]

__init__(n_channels, output_channel=-1, use_scale_invariance=True, **kwargs) #

Initialize a global scale invariant PSNR metric.

Parameters:

    n_channels (int, required)
        Number of channels in the input images.
    output_channel (int, default: -1)
        Channel to compute the metric on. If -1, the metric is computed on all channels.
    use_scale_invariance (bool, default: True)
        Whether to use scale invariance. If False, the metric is equivalent to PSNR with global data range.
    **kwargs (Any, default: {})
        Additional keyword arguments passed to the parent Metric class.
Source code in src/careamics/lightning/dataset_ng/metrics/psnr.py
def __init__(
    self,
    n_channels: int,
    output_channel: int = -1,
    use_scale_invariance: bool = True,
    **kwargs: Any,
):
    """Initialize a global scale invariant PSNR metric.

    Parameters
    ----------
    n_channels : int
        Number of channels in the input images.
    output_channel : int, default=-1
        Channel to compute the metric on. If -1, the metric is computed on all
        channels.
    use_scale_invariance : bool, default=True
        Whether to use scale invariance. If False, the metric is equivalent to PSNR
        with global data range.
    **kwargs : Any
        Additional keyword arguments passed to the parent Metric class.
    """
    super().__init__(**kwargs)

    self.eps = torch.finfo(torch.float32).eps
    self.output_channel = output_channel
    self.use_scale_invariance = use_scale_invariance

    if self.output_channel != -1 and (
        self.output_channel < 0 or self.output_channel >= n_channels
    ):
        raise ValueError(
            f"Invalid `output_channel` value ({self.output_channel}), must be equal"
            f" to -1 to compute an average over all channels, or between 0 and "
            f"{n_channels - 1} to compute the metric on a single channel."
        )

    self.add_state(
        "glob_max",
        default=tensor([float("-inf") for _ in range(n_channels)]),
        dist_reduce_fx="max",
    )
    self.add_state(
        "glob_min",
        default=tensor([float("inf") for _ in range(n_channels)]),
        dist_reduce_fx="min",
    )
    self.add_state(
        "mse_log",
        default=tensor([0.0 for _ in range(n_channels)]),
        dist_reduce_fx="sum",
    )
    self.add_state(
        "total",
        default=tensor([0.0 for _ in range(n_channels)]),
        dist_reduce_fx="sum",
    )

compute() #

Compute the final metric value.

Returns:

    Tensor
        Tensor of length C containing the computed PSNR for each channel.

Source code in src/careamics/lightning/dataset_ng/metrics/psnr.py
def compute(self) -> Tensor:
    """Compute the final metric value.

    Returns
    -------
    torch.Tensor
        Tensor of length C containing the computed PSNR for each channel.
    """
    glob_data_range = self.glob_max - self.glob_min + self.eps
    psnr = 10 * (torch.log10(glob_data_range**2) - self.mse_log / self.total)

    if self.output_channel == -1:
        return torch.mean(psnr)
    else:
        return psnr[self.output_channel]

update(preds, target) #

Update the metric states with values computed from a new batch.

Parameters:

    preds (Tensor, required)
        Predicted images tensor of shape (B, C, (Z), Y, X).
    target (Tensor, required)
        Ground truth images tensor of shape (B, C, (Z), Y, X).
Source code in src/careamics/lightning/dataset_ng/metrics/psnr.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric states with values computed from a new batch.

    Parameters
    ----------
    preds : Tensor
        Predicted images tensor of shape (B, C, (Z), Y, X).
    target : Tensor
        Ground truth images tensor of shape (B, C, (Z), Y, X).
    """
    batch_size = target.shape[0]
    shape = target.shape
    dims = tuple(range(2, len(shape)))

    # compute min/max of the batches and channels
    batch_min = torch.amin(target, dim=(0,) + dims)
    batch_max = torch.amax(target, dim=(0,) + dims)
    # implementation note: in the original function (`scale_invariant_psnr`), the
    # `data_range` is divided by `np.std(gt)`. This mathematically cancels out with
    # the scaling applied directly to `gt` (`_zero_mean(gt) / np.std(gt)`). Here,
    # we compute a global data range but still consider that the same scaling
    # factor is applied to the data range and `gt` (either the sample std, or
    # a global one), so that they cancel out.

    # fix range of gt and prediction
    if self.use_scale_invariance:
        tar_rescaled, pred_rescaled = _normalise_range(target, preds)
    else:
        tar_rescaled, pred_rescaled = target, preds

    # compute mse
    mse = torch.mean((tar_rescaled - pred_rescaled) ** 2 + self.eps, dim=dims)

    # update states
    self.glob_max: torch.Tensor = torch.maximum(self.glob_max, batch_max)
    self.glob_min: torch.Tensor = torch.minimum(self.glob_min, batch_min)
    self.mse_log: torch.Tensor = self.mse_log + torch.log10(mse).sum(dim=0)
    self.total: torch.Tensor = self.total + batch_size

SampleSIPSNR #

Bases: Metric

Scale Invariant PSNR metric with per-sample data range.

By default, the metric is averaged over channels, but it can also be computed for a specific channel by setting output_channel to the desired channel index.

Adapted from juglab/ScaleInvPSNR, this version of PSNR rescales the predictions and ground truth to have a similar range, then computes the PSNR using each patch's own data range.

Scale invariance can be turned off with use_scale_invariance=False, in which case the metric is equivalent to torchmetrics.image.PeakSignalNoiseRatio with data_range computed per patch as the difference between that patch's max and min, the resulting values then being averaged.

Note that, unlike torchmetrics.image.PeakSignalNoiseRatio, this implementation supports 3D and multi-channel images.

Parameters:

    n_channels (int, required)
        Number of channels in the input images.
    output_channel (int, default: -1)
        Channel to compute the metric on. If -1, the metric is computed on all channels.
    use_scale_invariance (bool, default: True)
        Whether to use scale invariance. If False, the metric is equivalent to PSNR with per-sample data range.
    **kwargs (Any, default: {})
        Additional keyword arguments passed to the parent Metric class.

Attributes:

    psnr_sum (Tensor)
        Sum of PSNR values for each channel.
    total (Tensor)
        Total number of samples processed.
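
The difference from SIPSNR is that the data range is computed per sample rather than accumulated globally. A single-channel numpy sketch of the use_scale_invariance=False behaviour (illustrative only; omits the eps terms and channel handling):

```python
import numpy as np

def per_sample_range_psnr(preds, target):
    """Sketch of SampleSIPSNR without scale invariance: each sample uses
    its own data range, and the per-sample PSNR values are averaged."""
    dims = tuple(range(1, target.ndim))
    data_range = target.max(axis=dims) - target.min(axis=dims)  # one range per sample
    mse = ((preds - target) ** 2).mean(axis=dims)               # one MSE per sample
    return (10 * np.log10(data_range**2 / mse)).mean()
```

With a per-sample range, a bright sample and a dim sample with the same relative error score the same, whereas a global range would penalise the dim sample more.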

Source code in src/careamics/lightning/dataset_ng/metrics/psnr.py
class SampleSIPSNR(Metric):
    """Scale Invariant PSNR metric with per-sample data range.

    By default, the metric is averaged over channels, but it can also be computed for a
    specific channel by setting `output_channel` to the desired channel index.

    Adapted from juglab/ScaleInvPSNR, this version of PSNR rescales the predictions and
    ground truth to have similar range, then computes the PSNR using each patch's data
    range.

    Scale invariance can be turned off using `use_scale_invariance=False`, in which case
    the metric is equivalent to `torchmetrics.image.PeakSignalNoiseRatio`, with
    `data_range` equal to the difference between each patch's max and min, for each
    patch, then averaged.

    Note that as opposed to `torchmetrics.image.PeakSignalNoiseRatio`, this
    implementation is compatible with 3D and multi-channel images.

    Parameters
    ----------
    n_channels : int
        Number of channels in the input images.
    output_channel : int, default=-1
        Channel to compute the metric on. If -1, the metric is computed on all channels.
    use_scale_invariance : bool, default=True
        Whether to use scale invariance. If False, the metric is equivalent to PSNR with
        per-sample data range.
    **kwargs : Any
        Additional keyword arguments passed to the parent Metric class.

    Attributes
    ----------
    psnr_sum : Tensor
        Sum of PSNR values for each channel.
    total : Tensor
        Total number of samples processed.
    """

    is_differentiable: bool | None = True
    higher_is_better: bool | None = True
    full_state_update: bool = False

    def __init__(
        self,
        n_channels: int,
        output_channel: int = -1,
        use_scale_invariance: bool = True,
        **kwargs: Any,
    ):
        """Initialize a per-sample scale invariant PSNR metric.

        Parameters
        ----------
        n_channels : int
            Number of channels in the input images.
        output_channel : int, default=-1
            Channel to compute the metric on. If -1, the metric is computed on all
            channels.
        use_scale_invariance : bool, default=True
            Whether to use scale invariance. If False, the metric is equivalent to PSNR
            with per-sample data range.
        **kwargs : Any
            Additional keyword arguments passed to the parent Metric class.
        """
        super().__init__(**kwargs)

        self.eps = torch.finfo(torch.float32).eps
        self.output_channel = output_channel
        self.use_scale_invariance = use_scale_invariance

        if self.output_channel != -1 and (
            self.output_channel < 0 or self.output_channel >= n_channels
        ):
            raise ValueError(
                f"Invalid `output_channel` value ({self.output_channel}), must be equal"
                f" to -1 to compute an average over all channels, or between 0 and "
                f"{n_channels - 1} to compute the metric on a single channel."
            )

        self.add_state(
            "psnr_sum",
            default=tensor([0.0 for _ in range(n_channels)]),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "total",
            default=tensor([0.0 for _ in range(n_channels)]),
            dist_reduce_fx="sum",
        )

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update the metric states with values computed from a new batch.

        Parameters
        ----------
        preds : Tensor
            Predicted images tensor of shape (B, C, (Z), Y, X).
        target : Tensor
            Ground truth images tensor of shape (B, C, (Z), Y, X).
        """
        batch_size = target.shape[0]
        shape = target.shape
        dims = tuple(range(2, len(shape)))

        # compute min/max of the batches and channels
        batch_min = torch.amin(target, dim=dims)
        batch_max = torch.amax(target, dim=dims)
        data_range = batch_max - batch_min + self.eps
        # implementation note: in the original function (`scale_invariant_psnr`), the
        # `data_range` is divided by `np.std(gt)`. This mathematically cancels out with
        # the scaling applied directly to `gt` (`_zero_mean(gt) / np.std(gt)`).

        # normalize range of gt and prediction
        if self.use_scale_invariance:
            tar_rescaled, pred_rescaled = _normalise_range(target, preds)
        else:
            tar_rescaled, pred_rescaled = target, preds

        # compute mse
        mse = torch.mean((tar_rescaled - pred_rescaled) ** 2 + self.eps, dim=dims)

        # update states
        self.psnr_sum: torch.Tensor = self.psnr_sum + torch.sum(
            10 * torch.log10(data_range**2 / mse), dim=0
        )
        self.total: torch.Tensor = self.total + batch_size

    def compute(self) -> Tensor:
        """Compute the final metric value.

        Returns
        -------
        torch.Tensor
            Tensor of length C containing the computed PSNR for each channel.
        """
        psnr = self.psnr_sum / self.total

        if self.output_channel == -1:
            return torch.mean(psnr)
        else:
            return psnr[self.output_channel]

__init__(n_channels, output_channel=-1, use_scale_invariance=True, **kwargs) #

Initialize a per-sample scale invariant PSNR metric.

Parameters:

    n_channels (int, required)
        Number of channels in the input images.
    output_channel (int, default: -1)
        Channel to compute the metric on. If -1, the metric is computed on all channels.
    use_scale_invariance (bool, default: True)
        Whether to use scale invariance. If False, the metric is equivalent to PSNR with per-sample data range.
    **kwargs (Any, default: {})
        Additional keyword arguments passed to the parent Metric class.
Source code in src/careamics/lightning/dataset_ng/metrics/psnr.py
def __init__(
    self,
    n_channels: int,
    output_channel: int = -1,
    use_scale_invariance: bool = True,
    **kwargs: Any,
):
    """Initialize a per-sample scale invariant PSNR metric.

    Parameters
    ----------
    n_channels : int
        Number of channels in the input images.
    output_channel : int, default=-1
        Channel to compute the metric on. If -1, the metric is computed on all
        channels.
    use_scale_invariance : bool, default=True
        Whether to use scale invariance. If False, the metric is equivalent to PSNR
        with per-sample data range.
    **kwargs : Any
        Additional keyword arguments passed to the parent Metric class.
    """
    super().__init__(**kwargs)

    self.eps = torch.finfo(torch.float32).eps
    self.output_channel = output_channel
    self.use_scale_invariance = use_scale_invariance

    if self.output_channel != -1 and (
        self.output_channel < 0 or self.output_channel >= n_channels
    ):
        raise ValueError(
            f"Invalid `output_channel` value ({self.output_channel}), must be equal"
            f" to -1 to compute an average over all channels, or between 0 and "
            f"{n_channels - 1} to compute the metric on a single channel."
        )

    self.add_state(
        "psnr_sum",
        default=tensor([0.0 for _ in range(n_channels)]),
        dist_reduce_fx="sum",
    )
    self.add_state(
        "total",
        default=tensor([0.0 for _ in range(n_channels)]),
        dist_reduce_fx="sum",
    )

compute() #

Compute the final metric value.

Returns:

    Tensor
        Tensor of length C containing the computed PSNR for each channel.

Source code in src/careamics/lightning/dataset_ng/metrics/psnr.py
def compute(self) -> Tensor:
    """Compute the final metric value.

    Returns
    -------
    torch.Tensor
        Tensor of length C containing the computed PSNR for each channel.
    """
    psnr = self.psnr_sum / self.total

    if self.output_channel == -1:
        return torch.mean(psnr)
    else:
        return psnr[self.output_channel]

update(preds, target) #

Update the metric states with values computed from a new batch.

Parameters:

    preds (Tensor, required)
        Predicted images tensor of shape (B, C, (Z), Y, X).
    target (Tensor, required)
        Ground truth images tensor of shape (B, C, (Z), Y, X).
Source code in src/careamics/lightning/dataset_ng/metrics/psnr.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric states with values computed from a new batch.

    Parameters
    ----------
    preds : Tensor
        Predicted images tensor of shape (B, C, (Z), Y, X).
    target : Tensor
        Ground truth images tensor of shape (B, C, (Z), Y, X).
    """
    batch_size = target.shape[0]
    shape = target.shape
    dims = tuple(range(2, len(shape)))

    # compute min/max of the batches and channels
    batch_min = torch.amin(target, dim=dims)
    batch_max = torch.amax(target, dim=dims)
    data_range = batch_max - batch_min + self.eps
    # implementation note: in the original function (`scale_invariant_psnr`), the
    # `data_range` is divided by `np.std(gt)`. This mathematically cancels out with
    # the scaling applied directly to `gt` (`_zero_mean(gt) / np.std(gt)`).

    # normalize range of gt and prediction
    if self.use_scale_invariance:
        tar_rescaled, pred_rescaled = _normalise_range(target, preds)
    else:
        tar_rescaled, pred_rescaled = target, preds

    # compute mse
    mse = torch.mean((tar_rescaled - pred_rescaled) ** 2 + self.eps, dim=dims)

    # update states
    self.psnr_sum: torch.Tensor = self.psnr_sum + torch.sum(
        10 * torch.log10(data_range**2 / mse), dim=0
    )
    self.total: torch.Tensor = self.total + batch_size