utils
Utility functions needed by the LVAE model.
Interpolate
Bases: Module
Wrapper for torch.nn.functional.interpolate.
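A minimal sketch of such a wrapper, assuming the constructor simply stores interpolation settings and forwards them to the functional call (the actual constructor signature in careamics may differ):

```python
import torch
from torch import nn
import torch.nn.functional as F

class Interpolate(nn.Module):
    """Module wrapper around F.interpolate (constructor arguments assumed)."""

    def __init__(self, scale_factor=None, mode="nearest"):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegate to the functional API with the stored settings.
        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
```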
StableExponential
Everything here is done on the tensor passed to the constructor: calling exp() computes self._tensor.exp(), and calling log() computes torch.log(self._tensor.exp()).

What is done here is that the definition of exp() has been changed to a numerically stable variant. This naturally changes the result of log() as well, but log() is still the mathematical log: it takes the logarithm of whatever comes out of exp().
log()
Note that if you have the output of exp(), you could simply apply torch.log() to it and get identical numbers.
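A minimal sketch of one stable exp/log pair consistent with the description above (illustrative; the internals are an assumption, not the exact careamics implementation):

```python
import torch

class StableExponential:
    """Stable exp()/log() pair over the tensor given at construction."""

    def __init__(self, tensor: torch.Tensor):
        self._tensor = tensor

    def exp(self) -> torch.Tensor:
        # Standard exp for non-positive inputs, linear 1 + x for positive
        # inputs, so large values cannot overflow. Clamping keeps the
        # unused branch of torch.where finite as well.
        return torch.where(
            self._tensor > 0,
            1 + self._tensor.clamp(min=0),
            torch.exp(self._tensor.clamp(max=0)),
        )

    def log(self) -> torch.Tensor:
        # Mathematical log of exp(): equivalent to torch.log(self.exp()),
        # but computed stably branch by branch.
        return torch.where(
            self._tensor > 0,
            torch.log1p(self._tensor.clamp(min=0)),
            self._tensor.clamp(max=0),
        )
```

With this definition, torch.log(StableExponential(t).exp()) and StableExponential(t).log() agree up to floating-point error, which is exactly the property the note above relies on.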
StableLogVar
__init__(logvar, enable_stable=True, var_eps=1e-06)
Args:
logvar (torch.Tensor): The log-variance tensor to wrap.
enable_stable (bool): Whether to use the stable parameterization.
var_eps (float): Minimum value returned by var().
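A hedged sketch of how these pieces could fit together, reusing the StableExponential sketch above (the get() method and the internals are assumptions, not the careamics implementation):

```python
import torch

class StableLogVar:
    """Log-variance wrapper; the stable path is assumed to route through
    StableExponential and to enforce the var_eps floor on the variance."""

    def __init__(
        self,
        logvar: torch.Tensor,
        enable_stable: bool = True,
        var_eps: float = 1e-6,
    ):
        self._lv = logvar
        self._enable_stable = enable_stable
        self._eps = var_eps

    def get(self) -> torch.Tensor:
        if not self._enable_stable:
            return self._lv
        # Variance via the stable exponential, floored at var_eps,
        # then mapped back to log-space.
        return torch.log(StableExponential(self._lv).exp() + self._eps)
```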
allow_numpy(func)
Keyword arguments are passed through as-is. Positional arguments are checked: any numpy array among them is converted to a torch tensor.
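A minimal sketch of a decorator matching this description (the conversion call, torch.from_numpy, is an assumption):

```python
import functools
import numpy as np
import torch

def allow_numpy(func):
    """Convert numpy positional arguments to torch tensors;
    keyword arguments are left untouched."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        args = tuple(
            torch.from_numpy(a) if isinstance(a, np.ndarray) else a for a in args
        )
        return func(*args, **kwargs)

    return wrapper

# Usage: a decorated function accepts numpy arrays transparently.
@allow_numpy
def double(x: torch.Tensor) -> torch.Tensor:
    return 2 * x

print(double(np.arange(3)))  # tensor([0, 2, 4])
```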
crop_img_tensor(x, size)
Crops a tensor of shape (batch, channels, h, w) to a desired height and width given by a tuple.

Args:
x (torch.Tensor): Input image.
size (list or tuple): Desired size (height, width).

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | The cropped tensor |
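A sketch assuming a center crop (the cropping policy is an assumption; the library may crop differently):

```python
import torch

def crop_img_tensor(x: torch.Tensor, size) -> torch.Tensor:
    """Center-crop a (batch, channels, h, w) tensor to (height, width)."""
    h, w = size
    _, _, H, W = x.shape
    top = (H - h) // 2
    left = (W - w) // 2
    return x[:, :, top : top + h, left : left + w]

# Example: crop a 1x1x8x8 tensor down to 4x6.
x = torch.arange(64.0).reshape(1, 1, 8, 8)
print(crop_img_tensor(x, (4, 6)).shape)  # torch.Size([1, 1, 4, 6])
```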
free_bits_kl(kl, free_bits, batch_average=False, eps=1e-06)
Computes the free-bits version of the KL divergence. It ensures that the KL does not go to zero for any latent dimension, and hence encourages more efficient use of the latent variables, leading to better representation learning.

NOTE: Takes in the KL with shape (batch size, layers) and returns the free-bits KL (for optimization) with shape (layers,), which is the average free-bits KL per layer in the current batch. If batch_average is False (default), the free bits are applied per layer and per batch element. Otherwise, the free bits are still per layer, but are assigned on average to the whole batch. In both cases the batch average is returned, so it is simply a matter of mean(clamp(KL)) versus clamp(mean(KL)).

Args:
kl (torch.Tensor): KL divergence with shape (batch size, layers).
free_bits (float): Free-bits threshold.
batch_average (bool, optional): See the note above.
eps (float, optional): Small tolerance used when checking whether free_bits is effectively zero.
Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | The KL with free bits |
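The note above corresponds directly to a small clamp-and-average, sketched here under the assumption that eps is only used to test whether free_bits is effectively zero:

```python
import torch

def free_bits_kl(
    kl: torch.Tensor,
    free_bits: float,
    batch_average: bool = False,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Sketch of the free-bits KL described above; kl has shape
    (batch, layers), the result has shape (layers,)."""
    if free_bits < eps:
        # No free bits requested: plain per-layer batch average.
        return kl.mean(0)
    if batch_average:
        # clamp(mean(KL)): free bits assigned to the batch on average.
        return kl.mean(0).clamp(min=free_bits)
    # mean(clamp(KL)): free bits per layer and per batch element.
    return kl.clamp(min=free_bits).mean(0)
```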
kl_normal_mc(z, p_mulv, q_mulv)
One-sample estimation of the element-wise KL divergence between two diagonal multivariate normal distributions. Works with any number of dimensions; broadcasting is supported (be careful).

Args:
z: Sample at which the densities are evaluated.
p_mulv: Mean and log-variance defining p.
q_mulv: Mean and log-variance defining q.
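A one-sample estimate of KL(q || p) is simply log q(z) - log p(z) evaluated at a sample z from q. A sketch under the assumption that each *_mulv argument is a plain (mean, log-variance) pair (the library may wrap these in its Stable* helpers instead):

```python
import math
import torch

def _log_normal(z, mu, logvar):
    # Element-wise log-density of a diagonal Gaussian.
    return -0.5 * (logvar + (z - mu) ** 2 / logvar.exp() + math.log(2 * math.pi))

def kl_normal_mc(z, p_mulv, q_mulv):
    """One-sample, element-wise KL(q || p) estimate: log q(z) - log p(z)."""
    p_mu, p_lv = p_mulv
    q_mu, q_lv = q_mulv
    return _log_normal(z, q_mu, q_lv) - _log_normal(z, p_mu, p_lv)
```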
pad_img_tensor(x, size)
Pads a tensor of shape (B, C, [Z], Y, X) to the desired spatial dimensions.
Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | The padded tensor |
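A sketch assuming centered zero-padding of the trailing spatial dimensions (both the centering and the zero fill are assumptions):

```python
import torch
import torch.nn.functional as F

def pad_img_tensor(x: torch.Tensor, size) -> torch.Tensor:
    """Zero-pad the trailing spatial dims of a (B, C, [Z], Y, X) tensor
    up to `size`; assumes each target dim is >= the current dim."""
    spatial = x.shape[2:]
    pad = []
    # F.pad expects (before, after) pairs starting from the last dimension.
    for target, current in zip(reversed(size), reversed(spatial)):
        total = target - current
        pad.extend([total // 2, total - total // 2])
    return F.pad(x, pad)

# Example: pad a (1, 1, 5, 5) tensor up to (8, 8).
x = torch.ones(1, 1, 5, 5)
print(pad_img_tensor(x, (8, 8)).shape)  # torch.Size([1, 1, 8, 8])
```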