# loss_utils
## `free_bits_kl(kl, free_bits, batch_average=False, eps=1e-06)`
Compute the free-bits version of the KL divergence.
This function ensures that the KL term does not collapse to zero for any latent dimension. It thereby encourages more efficient use of the latent variables, leading to better representation learning.
NOTE: Takes the KL with shape (batch size, layers) and returns the free-bits KL (used for optimization) with shape (layers,), i.e., the batch-averaged free-bits KL per layer. If batch_average is False (default), free bits are applied per layer and per batch element; otherwise they are applied per layer to the batch average. In both cases the batch average is returned, so the choice is simply between mean(clamp(KL)) and clamp(mean(KL)).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `kl` | `Tensor` | The KL divergence tensor with shape (batch size, layers). | required |
| `free_bits` | `float` | The free-bits value. Set to 0.0 to disable free bits. | required |
| `batch_average` | `bool` | Whether to average the KL over the batch before clamping to the free-bits value. | `False` |
| `eps` | `float` | A small value to avoid numerical instability. | `1e-06` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The free-bits version of the KL divergence with shape (layers,). |
Source code in src/careamics/losses/lvae/loss_utils.py
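The note above fully specifies the two modes, so the computation reduces to a choice between mean(clamp(KL)) and clamp(mean(KL)). Below is a minimal sketch consistent with that description; it is not necessarily the exact source, but follows the documented semantics:

```python
from torch import Tensor


def free_bits_kl(
    kl: Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
) -> Tensor:
    # kl has shape (batch size, layers).
    if free_bits < eps:
        # Free bits disabled: return the plain per-layer batch average.
        return kl.mean(0)
    if batch_average:
        # clamp(mean(KL)): average over the batch first, then clamp per layer.
        return kl.mean(0).clamp(min=free_bits)
    # mean(clamp(KL)): clamp per element and layer, then average over the batch.
    return kl.clamp(min=free_bits).mean(0)
```

For example, with `kl` of shape (8, 3) and `free_bits=1.0`, each of the 3 returned per-layer values is at least 1.0 in either mode, so no layer's KL term can drop below the free-bits floor during optimization.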
## `get_kl_weight(kl_annealing, kl_start, kl_annealtime, kl_weight, current_epoch)`
Compute the weight of the KL loss when annealing is used.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `kl_annealing` | `bool` | Whether to use KL annealing. | required |
| `kl_start` | `int` | The epoch at which to start annealing. | required |
| `kl_annealtime` | `int` | The number of epochs over which annealing is applied. | required |
| `kl_weight` | `float` | The weight for the KL loss. | required |
| `current_epoch` | `int` | The current epoch. | required |
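The schedule itself is not shown in this excerpt. A common choice, and a plausible reading of these parameters, is a linear ramp from 0 at `kl_start` to `kl_weight` after `kl_annealtime` epochs. The sketch below assumes that schedule; the actual source in src/careamics/losses/lvae/loss_utils.py may differ:

```python
def get_kl_weight(
    kl_annealing: bool,
    kl_start: int,
    kl_annealtime: int,
    kl_weight: float,
    current_epoch: int,
) -> float:
    # Assumed linear annealing schedule; the actual implementation may differ.
    if not kl_annealing:
        return kl_weight  # constant weight when annealing is disabled
    # Ramp linearly from 0 at kl_start to kl_weight after kl_annealtime epochs.
    progress = (current_epoch - kl_start) / kl_annealtime
    return kl_weight * min(max(progress, 0.0), 1.0)
```

Under this assumed schedule, the weight stays at 0 before `kl_start`, reaches `kl_weight` at `kl_start + kl_annealtime`, and remains constant afterwards.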