utils

Script for utility functions needed by the LVAE model.

Interpolate #

Bases: Module

Wrapper for torch.nn.functional.interpolate.

Source code in src/careamics/models/lvae/utils.py
class Interpolate(nn.Module):
    """Wrapper for torch.nn.functional.interpolate."""

    def __init__(self, size=None, scale=None, mode="bilinear", align_corners=False):
        super().__init__()
        assert (size is None) == (scale is not None)
        self.size = size
        self.scale = scale
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        out = F.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale,
            mode=self.mode,
            align_corners=self.align_corners,
        )
        return out
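
A minimal usage sketch, assuming careamics is installed and the module is importable as careamics.models.lvae.utils (matching the source path shown above): exactly one of size or scale must be given, and the layer simply forwards to F.interpolate.

import torch

from careamics.models.lvae.utils import Interpolate

# Double the spatial resolution of a (batch, channels, height, width) tensor.
upsample = Interpolate(scale=2, mode="bilinear", align_corners=False)

x = torch.randn(1, 3, 16, 16)
y = upsample(x)
print(y.shape)  # torch.Size([1, 3, 32, 32])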

StableExponential #

The idea is that everything operates on the tensor passed to the constructor: calling exp() means we want to compute self._tensor.exp(), and calling log() means we want to compute torch.log(self._tensor.exp()).

What has been done here is that the definition of exp() has been changed. This naturally changes the result of log() as well, but log() is still the mathematical log: it takes math.log() of whatever comes out of exp().

Source code in src/careamics/models/lvae/utils.py
class StableExponential:
    """
    The idea is that everything operates on the tensor passed to the constructor:
    calling exp() means we want to compute self._tensor.exp(), and calling log()
    means we want to compute torch.log(self._tensor.exp()).

    What has been done here is that the definition of exp() has been changed. This
    naturally changes the result of log() as well, but log() is still the mathematical
    log: it takes math.log() of whatever comes out of exp().
    """  # TODO document

    def __init__(self, tensor):
        self._raw_tensor = tensor
        posneg_dic = self.posneg_separation(self._raw_tensor)
        self.pos_f, self.neg_f = posneg_dic["filter"]
        self.pos_data, self.neg_data = posneg_dic["value"]

    def posneg_separation(self, tensor):
        pos = tensor > 0
        pos_tensor = torch.clip(tensor, min=0)

        neg = tensor <= 0
        neg_tensor = torch.clip(tensor, max=0)

        return {"filter": [pos, neg], "value": [pos_tensor, neg_tensor]}

    def exp(self):
        return torch.exp(self.neg_data) * self.neg_f + (1 + self.pos_data) * self.pos_f

    def log(self):
        """
        Note that if you have the output of exp(), you could simply apply torch.log()
        to it and obtain identical numbers.
        """
        return self.neg_data * self.neg_f + torch.log(1 + self.pos_data) * self.pos_f
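
To make the piecewise definition concrete, here is a small sketch (same import-path assumption as above). For entries <= 0, exp() is the ordinary exponential and log() returns the entry itself; for entries > 0, exp() is the linear function 1 + x and log() returns log(1 + x), so log() always matches torch.log() of exp().

import torch

from careamics.models.lvae.utils import StableExponential

x = torch.tensor([-2.0, 0.0, 3.0])
stable = StableExponential(x)

print(stable.exp())             # tensor([0.1353, 1.0000, 4.0000]): exp(-2), exp(0), 1 + 3
print(stable.log())             # tensor([-2.0000, 0.0000, 1.3863]): -2, 0, log(1 + 3)
print(torch.log(stable.exp()))  # same values as stable.log()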

log() #

Note that if you have the output of exp(), you could simply apply torch.log() to it and obtain identical numbers.

Source code in src/careamics/models/lvae/utils.py
def log(self):
    """
    Note that if you have the output of exp(), you could simply apply torch.log()
    to it and obtain identical numbers.
    """
    return self.neg_data * self.neg_f + torch.log(1 + self.pos_data) * self.pos_f

StableLogVar #

Source code in src/careamics/models/lvae/utils.py
class StableLogVar:

    def __init__(self, logvar, enable_stable=True, var_eps=1e-6):
        """
        Args:
            var_eps: the minimum value that var() can return.
        """
        self._lv = logvar
        self._enable_stable = enable_stable
        self._eps = var_eps

    def get(self):
        if self._enable_stable is False:
            return self._lv

        return torch.log(self.get_var())

    def get_var(self):
        if self._enable_stable is False:
            return torch.exp(self._lv)
        return StableExponential(self._lv).exp() + self._eps

    def get_std(self):
        return torch.sqrt(self.get_var())

    def centercrop_to_size(self, size):
        if self._lv.shape[-1] == size:
            return

        diff = self._lv.shape[-1] - size
        assert diff > 0 and diff % 2 == 0
        self._lv = F.center_crop(self._lv, (size, size))
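
A short usage sketch (same import-path assumption): wrap a raw log-variance tensor and query the stabilised log-variance, variance, and standard deviation.

import torch

from careamics.models.lvae.utils import StableLogVar

logvar = StableLogVar(torch.zeros(2, 4, 8, 8), enable_stable=True, var_eps=1e-6)

var = logvar.get_var()  # StableExponential(0).exp() + var_eps, i.e. ~1.000001 everywhere
std = logvar.get_std()  # square root of the variance, ~1.0
lv = logvar.get()       # log of the stabilised variance, ~1e-6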

__init__(logvar, enable_stable=True, var_eps=1e-06) #

Args:
    var_eps: the minimum value that var() can return.

Source code in src/careamics/models/lvae/utils.py
def __init__(self, logvar, enable_stable=True, var_eps=1e-6):
    """
    Args:
        var_eps: the minimum value that var() can return.
    """
    self._lv = logvar
    self._enable_stable = enable_stable
    self._eps = var_eps

allow_numpy(func) #

All keyword arguments are passed as-is. Positional arguments are checked and, if they are NumPy arrays, they are converted to torch tensors.

Source code in src/careamics/models/lvae/utils.py
def allow_numpy(func):
    """
    All keyword arguments are passed as-is. Positional arguments are checked and, if they
    are NumPy arrays, they are converted to torch tensors.
    """

    def numpy_wrapper(*args, **kwargs):
        new_args = []
        for arg in args:
            if isinstance(arg, np.ndarray):
                arg = torch.Tensor(arg)
            new_args.append(arg)
        new_args = tuple(new_args)

        output = func(*new_args, **kwargs)
        return output

    return numpy_wrapper
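
A short sketch of the decorator in action; the function name describe is hypothetical and used only for illustration. Positional NumPy arrays are converted to torch tensors before the wrapped function runs, while keyword arguments and non-array positionals pass through untouched.

import numpy as np
import torch

from careamics.models.lvae.utils import allow_numpy

@allow_numpy
def describe(x):
    # By the time execution reaches here, positional numpy arrays are torch tensors.
    return type(x), x.shape

print(describe(np.zeros((2, 3))))   # (<class 'torch.Tensor'>, torch.Size([2, 3]))
print(describe(torch.zeros(2, 3)))  # already a tensor, passed through unchanged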

crop_img_tensor(x, size) #

Crops a tensor of shape (batch, channels, h, w) to a desired height and width given by a tuple.

Args:
    x (torch.Tensor): Input image
    size (list or tuple): Desired size (height, width)

Returns:
    torch.Tensor: The cropped tensor
Source code in src/careamics/models/lvae/utils.py
def crop_img_tensor(x, size) -> torch.Tensor:
    """Crops a tensor.
    Crops a tensor of shape (batch, channels, h, w) to a desired height and width
    given by a tuple.
    Args:
        x (torch.Tensor): Input image
        size (list or tuple): Desired size (height, width)

    Returns
    -------
        The cropped tensor
    """
    return _pad_crop_img(x, size, "crop")
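
For instance (same import-path assumption), cropping a (batch, channels, 64, 64) tensor down to 60x60:

import torch

from careamics.models.lvae.utils import crop_img_tensor

x = torch.randn(2, 1, 64, 64)
cropped = crop_img_tensor(x, (60, 60))
print(cropped.shape)  # torch.Size([2, 1, 60, 60])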

free_bits_kl(kl, free_bits, batch_average=False, eps=1e-06) #

Computes the free-bits version of the KL divergence. This ensures that the KL does not go to zero for any latent dimension and hence encourages more efficient use of the latent variables, leading to better representation learning.

NOTE: Takes in the KL with shape (batch size, layers) and returns the free-bits KL (for optimization) with shape (layers,), which is the average free-bits KL per layer in the current batch. If batch_average is False (default), the free bits are applied per layer and per batch element. Otherwise, the free bits are still per layer, but are assigned on average to the whole batch. In both cases the batch average is returned, so it is simply a matter of doing mean(clamp(KL)) or clamp(mean(KL)).

Args:
    kl (torch.Tensor): KL divergence of shape (batch size, layers)
    free_bits (float): free-bits threshold
    batch_average (bool, optional): clamp the batch-averaged KL instead of each element
    eps (float, optional): threshold below which free_bits is treated as zero

Returns:
    torch.Tensor: The KL with free bits
Source code in src/careamics/models/lvae/utils.py
def free_bits_kl(
    kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
) -> torch.Tensor:
    """
    Computes free-bits version of KL divergence.
    Ensures that the KL doesn't go to zero for any latent dimension.
    Hence, it contributes to use latent variables more efficiently,
    leading to better representation learning.

    NOTE:
    Takes in the KL with shape (batch size, layers), returns the KL with
    free bits (for optimization) with shape (layers,), which is the average
    free-bits KL per layer in the current batch.
    If batch_average is False (default), the free bits are per layer and
    per batch element. Otherwise, the free bits are still per layer, but
    are assigned on average to the whole batch. In both cases, the batch
    average is returned, so it's simply a matter of doing mean(clamp(KL))
    or clamp(mean(KL)).

    Args:
        kl (torch.Tensor)
        free_bits (float)
        batch_average (bool, optional)
        eps (float, optional)

    Returns
    -------
        The KL with free bits
    """
    assert kl.dim() == 2
    if free_bits < eps:
        return kl.mean(0)
    if batch_average:
        return kl.mean(0).clamp(min=free_bits)
    return kl.clamp(min=free_bits).mean(0)
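
A small numeric sketch (values chosen purely for illustration) showing the difference between the two modes: with batch_average=False each KL entry is clamped before averaging over the batch, whereas with batch_average=True the batch mean is clamped.

import torch

from careamics.models.lvae.utils import free_bits_kl

# KL of shape (batch size, layers): 2 samples, 3 layers.
kl = torch.tensor([[0.1, 2.0, 0.0],
                   [0.9, 1.0, 0.0]])

print(free_bits_kl(kl, free_bits=0.0))                      # tensor([0.5000, 1.5000, 0.0000]): plain batch mean
print(free_bits_kl(kl, free_bits=0.5))                      # tensor([0.7000, 1.5000, 0.5000]): mean(clamp(KL))
print(free_bits_kl(kl, free_bits=0.5, batch_average=True))  # tensor([0.5000, 1.5000, 0.5000]): clamp(mean(KL))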

kl_normal_mc(z, p_mulv, q_mulv) #

One-sample estimation of the element-wise KL divergence between two diagonal multivariate normal distributions. Works for any number of dimensions; broadcasting is supported (be careful).

Args:
    z: a sample, typically drawn from q
    p_mulv: tuple of (mean, log-variance) wrappers parameterising p
    q_mulv: tuple of (mean, log-variance) wrappers parameterising q

Returns:
    The element-wise estimate log q(z) - log p(z)

Source code in src/careamics/models/lvae/utils.py
def kl_normal_mc(z, p_mulv, q_mulv):
    """
    One-sample estimation of element-wise KL between two diagonal
    multivariate normal distributions. Any number of dimensions,
    broadcasting supported (be careful).
    :param z:
    :param p_mulv:
    :param q_mulv:
    :return:
    """
    assert isinstance(p_mulv, tuple)
    assert isinstance(q_mulv, tuple)
    p_mu, p_lv = p_mulv
    q_mu, q_lv = q_mulv

    p_std = p_lv.get_std()
    q_std = q_lv.get_std()

    p_distrib = Normal(p_mu.get(), p_std)
    q_distrib = Normal(q_mu.get(), q_std)
    return q_distrib.log_prob(z) - p_distrib.log_prob(z)
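
A hedged calling sketch: both p_mulv and q_mulv are (mean, log-variance) tuples whose elements expose get() and get_std() respectively. The _Mean wrapper below is purely illustrative (it is not part of the library) and only mimics the interface the means are expected to have; StableLogVar from this module is used for the log-variances.

import torch

from careamics.models.lvae.utils import StableLogVar, kl_normal_mc


class _Mean:
    """Illustrative stand-in exposing the get() interface expected for the means."""

    def __init__(self, mu):
        self._mu = mu

    def get(self):
        return self._mu


shape = (2, 4, 8, 8)
p = (_Mean(torch.zeros(shape)), StableLogVar(torch.zeros(shape)))
q = (_Mean(torch.randn(shape)), StableLogVar(torch.randn(shape)))

z = q[0].get() + q[1].get_std() * torch.randn(shape)  # one sample from q
kl = kl_normal_mc(z, p, q)
print(kl.shape)  # element-wise KL estimate, same shape as z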

pad_img_tensor(x, size) #

Pads a tensor of shape (B, C, [Z], Y, X) to the desired spatial dimensions.

Args:
    x (torch.Tensor): Input image of shape (B, C, [Z], Y, X)
    size (list or tuple): Desired size ([Z*], Y*, X*)

Returns:
    torch.Tensor: The padded tensor
Source code in src/careamics/models/lvae/utils.py
def pad_img_tensor(x: torch.Tensor, size: Iterable[int]) -> torch.Tensor:
    """Pads a tensor

    Pads a tensor of shape (B, C, [Z], Y, X) to desired spatial dimensions.

    Parameters
    ----------
        x (torch.Tensor): Input image of shape (B, C, [Z], Y, X)
        size (list or tuple): Desired size  ([Z*], Y*, X*)

    Returns
    -------
        The padded tensor
    """
    return _pad_crop_img(x, size, "pad")
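
And symmetrically (same import-path assumption), padding a (B, C, Y, X) tensor from 60x60 up to 64x64:

import torch

from careamics.models.lvae.utils import pad_img_tensor

x = torch.randn(2, 1, 60, 60)
padded = pad_img_tensor(x, (64, 64))
print(padded.shape)  # torch.Size([2, 1, 64, 64])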