Skip to content

utils

Script for utility functions needed by the LVAE model.

Interpolate #

Bases: Module

Wrapper for torch.nn.functional.interpolate.

Source code in src/careamics/models/lvae/utils.py
class Interpolate(nn.Module):
    """Wrapper for torch.nn.functional.interpolate.

    Exactly one of ``size`` or ``scale`` must be provided; they are forwarded
    as the ``size`` and ``scale_factor`` arguments of ``interpolate``.
    """

    # Interpolation modes for which torch accepts a non-None ``align_corners``.
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, size=None, scale=None, mode="bilinear", align_corners=False):
        """
        Constructor.

        Parameters
        ----------
        size: int or tuple of int, optional
            Output spatial size. Mutually exclusive with `scale`.
        scale: float or tuple of float, optional
            Spatial scale factor. Mutually exclusive with `size`.
        mode: str, optional
            Interpolation mode passed to ``interpolate``. Default is "bilinear".
        align_corners: bool, optional
            Passed to ``interpolate`` for the modes that support it. For the
            other modes (e.g. "nearest", "area") the default `False` is
            treated as "unset". Default is `False`.
        """
        super().__init__()
        assert (size is None) == (
            scale is not None
        ), "Exactly one of `size` and `scale` must be given."
        self.size = size
        self.scale = scale
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resize `x` with the configured size/scale and mode."""
        align_corners = self.align_corners
        if self.mode not in self._ALIGN_CORNERS_MODES and not align_corners:
            # torch raises for nearest/area-style modes whenever align_corners
            # is not None; False carries no meaning there, so send None.
            # An explicit True is still forwarded so torch reports the misuse.
            align_corners = None
        # Fully-qualified call so this does not depend on what the module-level
        # alias `F` is bound to.
        return torch.nn.functional.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale,
            mode=self.mode,
            align_corners=align_corners,
        )

StableExponential #

Class that redefines exp() to improve numerical stability. Accordingly, log() must be redefined too. The two operations nevertheless remain inverses of each other, meaning that x = log(exp(x)) and x = exp(log(x)) hold.

Definition:

$$
\exp(x) = \begin{cases} e^{x} & \text{if } x \le 0 \\ x + 1 & \text{if } x > 0 \end{cases}
$$

$$
\log(x) = \begin{cases} x & \text{if } x \le 0 \\ \log(1 + x) & \text{if } x > 0 \end{cases}
$$

NOTE 1: Within the class, every operation is applied to the tensor passed to the constructor (stored as `self._raw_tensor`). Therefore, calling exp() computes the stable exponential of that tensor, and calling log() computes torch.log of that stable exponential.

NOTE 2: Applying torch.log() to the output of exp() gives the same result as calling the class's log() method (up to floating-point underflow of exp() for very negative inputs).

Source code in src/careamics/models/lvae/utils.py
class StableExponential:
    """
    Piecewise exponential (and matching logarithm) designed to avoid overflow.

    For the tensor ``x`` given to the constructor:

        exp(x) = {
            exp(x) if x<=0
            x+1    if x>0
        }

        log(x) = {
            x        if x<=0
            log(1+x) if x>0
        }

    so ``log()`` equals ``torch.log(exp())`` computed branch-wise. Both
    methods operate on the tensor stored at construction time
    (``self._raw_tensor``); they take no arguments.

    Each branch only ever sees values inside its own domain: ``torch.exp``
    is applied to the input clipped to ``<= 0`` (so it cannot overflow), and
    ``torch.log`` is applied to ``1 + (input clipped to >= 0)`` (so its
    argument is at least 1).
    """

    def __init__(self, tensor):
        self._raw_tensor = tensor
        split = self.posneg_separation(tensor)
        self.pos_f, self.neg_f = split["filter"]
        self.pos_data, self.neg_data = split["value"]

    def posneg_separation(self, tensor):
        """
        Split `tensor` into strictly-positive and non-positive parts.

        Returns
        -------
        dict
            ``"filter"``: ``[tensor > 0, tensor <= 0]`` boolean masks;
            ``"value"``: ``[clip(tensor, min=0), clip(tensor, max=0)]``.
        """
        masks = [tensor > 0, tensor <= 0]
        clipped = [torch.clip(tensor, min=0), torch.clip(tensor, max=0)]
        return {"filter": masks, "value": clipped}

    def exp(self):
        """Stable exponential: ``exp(x)`` where x<=0, ``1 + x`` where x>0."""
        negative_branch = torch.exp(self.neg_data) * self.neg_f
        positive_branch = (1 + self.pos_data) * self.pos_f
        return negative_branch + positive_branch

    def log(self):
        """Log of the stable exponential: ``x`` where x<=0, ``log(1 + x)`` where x>0."""
        negative_branch = self.neg_data * self.neg_f
        positive_branch = torch.log(1 + self.pos_data) * self.pos_f
        return negative_branch + positive_branch

StableLogVar #

Class that provides a numerically stable implementation of log-variance. Specifically, it uses the exp() and log() formulas defined in the StableExponential class.

Source code in src/careamics/models/lvae/utils.py
class StableLogVar:
    """
    Class that provides a numerically stable implementation of Log-Variance.
    Specifically, it uses the exp() and log() formulas defined in `StableExponential` class.

    When stability is enabled, the variance is computed with the piecewise
    exponential of `StableExponential` (``exp(x)`` for x<=0, ``x+1`` for x>0)
    plus a small `var_eps`, instead of a plain ``torch.exp``.
    """

    def __init__(
        self, logvar: torch.Tensor, enable_stable: bool = True, var_eps: float = 1e-6
    ):
        """
        Constructor.

        Parameters
        ----------
        logvar: torch.Tensor
            The input (true) logvar vector, to be converted in the Stable version.
        enable_stable: bool, optional
            Whether to compute the stable version of log-variance. Default is `True`.
        var_eps: float, optional
            The minimum value attainable by the variance. Default is `1e-6`.
        """
        self._lv = logvar
        self._enable_stable = enable_stable
        self._eps = var_eps

    def get(self) -> torch.Tensor:
        """
        Get the log-variance.

        Returns the raw input when stability is disabled; otherwise
        ``log(get_var())``, i.e. the log of the stabilized variance
        (which includes `var_eps`).
        """
        # Truthiness (not ``is False``) so falsy flags such as 0 or
        # numpy.False_ also disable the stable path.
        if not self._enable_stable:
            return self._lv
        return torch.log(self.get_var())

    def get_var(self) -> torch.Tensor:
        """
        Get Variance from Log-Variance.

        ``exp(logvar)`` when stability is disabled; otherwise the stable
        exponential of `StableExponential` plus `var_eps`.
        """
        if not self._enable_stable:
            return torch.exp(self._lv)
        return StableExponential(self._lv).exp() + self._eps

    def get_std(self) -> torch.Tensor:
        """Get the standard deviation, ``sqrt(get_var())``."""
        return torch.sqrt(self.get_var())

    @property
    def is_3D(self) -> bool:
        """Check if the _lv tensor is 3D.

        Recall that, in this framework, tensors have shape (B, C, [Z], Y, X).
        """
        return self._lv.dim() == 5

    def centercrop_to_size(self, size: int) -> None:
        """
        Centercrop the log-variance tensor to the desired size.

        The last two dimensions are cropped to ``(size, size)`` in place.
        Implemented only in the case of 2D tensors.

        Parameters
        ----------
        size: int
            The desired spatial side length of the log-variance tensor.

        Raises
        ------
        AssertionError
            If the tensor is 3D, if ``shape[-1] - size`` is not a positive even
            number, or if ``shape[-2] < size``.
        """
        assert not self.is_3D, "Centercrop is implemented only for 2D tensors."

        width = self._lv.shape[-1]
        if width == size:
            return

        diff = width - size
        assert diff > 0 and diff % 2 == 0
        height = self._lv.shape[-2]
        assert height >= size, "Centercrop cannot enlarge the tensor."
        # Crop with plain slicing so this does not depend on torchvision
        # (`center_crop` is not part of torch.nn.functional). Offsets follow
        # torchvision's center_crop: round((dim - size) / 2).
        top = int(round((height - size) / 2.0))
        left = diff // 2
        self._lv = self._lv[..., top : top + size, left : left + size]

is_3D property #

Check if the _lv tensor is 3D.

Recall that, in this framework, tensors have shape (B, C, [Z], Y, X).

__init__(logvar, enable_stable=True, var_eps=1e-06) #

Constructor.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `logvar` | `Tensor` | The input (true) log-variance, to be converted into the stable version. | required |
| `enable_stable` | `bool` | Whether to compute the stable version of log-variance. Default is `True`. | `True` |
| `var_eps` | `float` | The minimum value attainable by the variance. Default is `1e-6`. | `1e-06` |
Source code in src/careamics/models/lvae/utils.py
def __init__(
    self, logvar: torch.Tensor, enable_stable: bool = True, var_eps: float = 1e-6
):
    """
    Store the log-variance and the stabilization settings.

    Parameters
    ----------
    logvar: torch.Tensor
        The input (true) log-variance, to be converted into the stable version.
    enable_stable: bool, optional
        Whether to compute the stable version of log-variance. Default is `True`.
    var_eps: float, optional
        The minimum value attainable by the variance. Default is `1e-6`.
    """
    # Independent attribute assignments; no computation happens here.
    self._eps = var_eps
    self._enable_stable = enable_stable
    self._lv = logvar

centercrop_to_size(size) #

Centercrop the log-variance tensor to the desired size.

Parameters:

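| Name | Type | Description | Default |
|------|------|-------------|---------|
| `size` | `int` | The desired spatial side length of the log-variance tensor; the last two dimensions are cropped to `(size, size)`. | required |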
Source code in src/careamics/models/lvae/utils.py
def centercrop_to_size(self, size: int) -> None:
    """
    Centercrop the log-variance tensor to the desired size.

    The last two dimensions are cropped to ``(size, size)`` in place.
    Implemented only in the case of 2D tensors.

    Parameters
    ----------
    size: int
        The desired spatial side length of the log-variance tensor.

    Raises
    ------
    AssertionError
        If the tensor is 3D, if ``shape[-1] - size`` is not a positive even
        number, or if ``shape[-2] < size``.
    """
    assert not self.is_3D, "Centercrop is implemented only for 2D tensors."

    width = self._lv.shape[-1]
    if width == size:
        return

    diff = width - size
    assert diff > 0 and diff % 2 == 0
    height = self._lv.shape[-2]
    assert height >= size, "Centercrop cannot enlarge the tensor."
    # Crop with plain slicing so this does not depend on torchvision
    # (`center_crop` is not part of torch.nn.functional). Offsets follow
    # torchvision's center_crop: round((dim - size) / 2).
    top = int(round((height - size) / 2.0))
    left = diff // 2
    self._lv = self._lv[..., top : top + size, left : left + size]

get_var() #

Get Variance from Log-Variance.

Source code in src/careamics/models/lvae/utils.py
def get_var(self) -> torch.Tensor:
    """
    Get Variance from Log-Variance.

    Returns ``exp(logvar)`` when stability is disabled; otherwise the stable
    exponential of `StableExponential` plus ``self._eps``.
    """
    # Truthiness (not ``is False``) so falsy flags such as 0 or numpy.False_
    # also select the plain exponential.
    if not self._enable_stable:
        return torch.exp(self._lv)
    return StableExponential(self._lv).exp() + self._eps

StableMean #

Source code in src/careamics/models/lvae/utils.py
class StableMean:
    """Thin wrapper around a mean tensor exposing `get`, `is_3D` and `centercrop_to_size`."""

    def __init__(self, mean):
        """
        Constructor.

        Parameters
        ----------
        mean: torch.Tensor
            The mean tensor, presumably of shape (B, C, [Z], Y, X) — confirm
            at call sites.
        """
        self._mean = mean

    def get(self) -> torch.Tensor:
        """Return the wrapped mean tensor unchanged."""
        return self._mean

    @property
    def is_3D(self) -> bool:
        """Check if the _mean tensor is 3D.

        Recall that, in this framework, tensors have shape (B, C, [Z], Y, X).
        """
        return self._mean.dim() == 5

    def centercrop_to_size(self, size: int) -> None:
        """Centercrop the mean tensor to the desired size.

        The last two dimensions are cropped to ``(size, size)`` in place.
        Implemented only in the case of 2D tensors.

        Parameters
        ----------
        size: int
            The desired spatial side length of the mean tensor.

        Raises
        ------
        AssertionError
            If the tensor is 3D, if ``shape[-1] - size`` is not a positive even
            number, or if ``shape[-2] < size``.
        """
        assert not self.is_3D, "Centercrop is implemented only for 2D tensors."

        width = self._mean.shape[-1]
        if width == size:
            return

        diff = width - size
        assert diff > 0 and diff % 2 == 0
        height = self._mean.shape[-2]
        assert height >= size, "Centercrop cannot enlarge the tensor."
        # Crop with plain slicing so this does not depend on torchvision
        # (`center_crop` is not part of torch.nn.functional). Offsets follow
        # torchvision's center_crop: round((dim - size) / 2).
        top = int(round((height - size) / 2.0))
        left = diff // 2
        self._mean = self._mean[..., top : top + size, left : left + size]

is_3D property #

Check if the _mean tensor is 3D.

Recall that, in this framework, tensors have shape (B, C, [Z], Y, X).

centercrop_to_size(size) #

Centercrop the mean tensor to the desired size.

Implemented only in the case of 2D tensors.

Parameters:

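| Name | Type | Description | Default |
|------|------|-------------|---------|
| `size` | `int` | The desired spatial side length of the mean tensor; the last two dimensions are cropped to `(size, size)`. | required |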
Source code in src/careamics/models/lvae/utils.py
def centercrop_to_size(self, size: int) -> None:
    """Centercrop the mean tensor to the desired size.

    The last two dimensions are cropped to ``(size, size)`` in place.
    Implemented only in the case of 2D tensors.

    Parameters
    ----------
    size: int
        The desired spatial side length of the mean tensor.

    Raises
    ------
    AssertionError
        If the tensor is 3D, if ``shape[-1] - size`` is not a positive even
        number, or if ``shape[-2] < size``.
    """
    assert not self.is_3D, "Centercrop is implemented only for 2D tensors."

    width = self._mean.shape[-1]
    if width == size:
        return

    diff = width - size
    assert diff > 0 and diff % 2 == 0
    height = self._mean.shape[-2]
    assert height >= size, "Centercrop cannot enlarge the tensor."
    # Crop with plain slicing so this does not depend on torchvision
    # (`center_crop` is not part of torch.nn.functional). Offsets follow
    # torchvision's center_crop: round((dim - size) / 2).
    top = int(round((height - size) / 2.0))
    left = diff // 2
    self._mean = self._mean[..., top : top + size, left : left + size]

allow_numpy(func) #

Keyword arguments are passed through unchanged. Positional arguments are checked: any NumPy array is converted to a torch Tensor.

Source code in src/careamics/models/lvae/utils.py
def allow_numpy(func):
    """
    Decorator that converts positional NumPy-array arguments to torch tensors.

    Positional arguments that are ``np.ndarray`` instances are converted with
    ``torch.Tensor(arg)``; all other positional arguments and every keyword
    argument are passed through unchanged. The wrapped function's name and
    docstring are preserved.

    Parameters
    ----------
    func: callable
        The function to wrap.

    Returns
    -------
    callable
        The wrapping function.
    """
    import functools

    @functools.wraps(func)
    def numpy_wrapper(*args, **kwargs):
        # torch.Tensor(...) (legacy constructor) yields the default floating
        # dtype regardless of the array's dtype; kept as-is because callers
        # may rely on receiving float tensors.
        new_args = tuple(
            torch.Tensor(arg) if isinstance(arg, np.ndarray) else arg for arg in args
        )
        return func(*new_args, **kwargs)

    return numpy_wrapper

crop_img_tensor(x, size) #

Crops a tensor of shape (batch, channels, h, w) to a desired height and width given by a tuple.

Parameters:

- `x` (torch.Tensor): Input image.
- `size` (list or tuple): Desired size (height, width).

Returns:

- The cropped tensor.
Source code in src/careamics/models/lvae/utils.py
def crop_img_tensor(x, size) -> torch.Tensor:
    """Crop a tensor to a target spatial size.

    Crops a tensor of shape (batch, channels, h, w) to a desired height and
    width given by a tuple. Delegates to ``_pad_crop_img`` in ``"crop"`` mode.

    Parameters
    ----------
    x : torch.Tensor
        Input image.
    size : list or tuple
        Desired size (height, width).

    Returns
    -------
    torch.Tensor
        The cropped tensor.
    """
    operation = "crop"
    return _pad_crop_img(x, size, operation)

kl_normal_mc(z, p_mulv, q_mulv) #

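One-sample (Monte Carlo) estimate of the element-wise KL divergence between two diagonal multivariate normal distributions, computed as log q(z) - log p(z). Any number of dimensions is supported, and broadcasting applies (be careful: mismatched shapes may broadcast silently). Parameters: `z` is the point at which both log-densities are evaluated; `p_mulv` and `q_mulv` are (mean, log-variance) tuples describing p and q.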

Source code in src/careamics/models/lvae/utils.py
def kl_normal_mc(z, p_mulv, q_mulv):
    """
    One-sample Monte Carlo estimate of the element-wise KL divergence between
    two diagonal multivariate normal distributions.

    Computes ``log q(z) - log p(z)`` element-wise. Any number of dimensions is
    supported and broadcasting applies (be careful: mismatched shapes may
    broadcast silently).

    Parameters
    ----------
    z : torch.Tensor
        Point at which both log-densities are evaluated. Presumably a sample
        drawn from q (the estimate targets KL(q || p) only then) — confirm at
        call sites.
    p_mulv : tuple
        ``(mean, logvar)`` pair for p; ``mean.get()`` and ``logvar.get_std()``
        supply the Normal's location and scale.
    q_mulv : tuple
        ``(mean, logvar)`` pair for q, with the same structure as `p_mulv`.

    Returns
    -------
    torch.Tensor
        Element-wise ``log q(z) - log p(z)``.

    Raises
    ------
    AssertionError
        If `p_mulv` or `q_mulv` is not a tuple.
    """
    for pair in (p_mulv, q_mulv):
        assert isinstance(pair, tuple)

    (p_mu, p_lv), (q_mu, q_lv) = p_mulv, q_mulv

    p_normal = Normal(p_mu.get(), p_lv.get_std())
    q_normal = Normal(q_mu.get(), q_lv.get_std())
    return q_normal.log_prob(z) - p_normal.log_prob(z)

pad_img_tensor(x, size) #

Pads a tensor

Pads a tensor of shape (B, C, [Z], Y, X) to desired spatial dimensions.

Parameters:

- `x` (torch.Tensor): Input image of shape (B, C, [Z], Y, X).
- `size` (list or tuple): Desired size ([Z*], Y*, X*).

Returns:

- The padded tensor.
Source code in src/careamics/models/lvae/utils.py
def pad_img_tensor(x: torch.Tensor, size: Sequence[int]) -> torch.Tensor:
    """Pad a tensor to the desired spatial dimensions.

    Pads a tensor of shape (B, C, [Z], Y, X) to the desired spatial size.
    Delegates to ``_pad_crop_img`` in ``"pad"`` mode.

    Parameters
    ----------
    x : torch.Tensor
        Input image of shape (B, C, [Z], Y, X).
    size : list or tuple
        Desired size ([Z*], Y*, X*).

    Returns
    -------
    torch.Tensor
        The padded tensor.
    """
    operation = "pad"
    return _pad_crop_img(x, size, operation)