Loss Utils

free_bits_kl(kl, free_bits, batch_average=False, eps=1e-06)

Compute free-bits version of KL divergence.

This function ensures that the KL does not collapse to zero for any latent dimension. It therefore encourages the model to use its latent variables more efficiently, leading to better representation learning.

NOTE: Takes in the KL with shape (batch size, layers) and returns the free-bits KL (for optimization) with shape (layers,), i.e. the average free-bits KL per layer over the current batch. If batch_average is False (default), the free bits are applied per layer and per batch element; otherwise, the free bits are still per layer, but are applied to the batch average. In both cases the batch average is returned, so the difference amounts to mean(clamp(KL)) versus clamp(mean(KL)).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `kl` | `Tensor` | The KL divergence tensor with shape (batch size, layers). | required |
| `free_bits` | `float` | The free bits value. Set to 0.0 to disable free bits. | required |
| `batch_average` | `bool` | Whether to average over the batch before clamping to `free_bits`. | `False` |
| `eps` | `float` | A small value to avoid numerical instability. | `1e-06` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The free-bits version of the KL divergence with shape (layers,). |
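The mean(clamp(KL)) versus clamp(mean(KL)) behavior described above can be sketched as follows. This is a NumPy sketch, not the library's actual implementation; in particular, treating `free_bits` values below `eps` as "free bits disabled" is an assumption about how `eps` is used.

```python
import numpy as np

def free_bits_kl(kl, free_bits, batch_average=False, eps=1e-06):
    """Sketch of free-bits KL. `kl` has shape (batch size, layers)."""
    if free_bits < eps:
        # Free bits disabled: just average over the batch (assumption).
        return kl.mean(axis=0)
    if batch_average:
        # clamp(mean(KL)): average over the batch, then clamp per layer.
        return np.clip(kl.mean(axis=0), free_bits, None)
    # mean(clamp(KL)): clamp per batch element and layer, then average.
    return np.clip(kl, free_bits, None).mean(axis=0)
```

For example, with `kl = [[0.1, 2.0], [0.3, 4.0]]` and `free_bits = 1.0`, the default mode clamps each element to at least 1.0 before averaging, so low-KL dimensions still contribute a gradient floor per sample.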

get_kl_weight(kl_annealing, kl_start, kl_annealtime, kl_weight, current_epoch)

Compute the weight of the KL loss in case of annealing.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `kl_annealing` | `bool` | Whether to use KL annealing. | required |
| `kl_start` | `int` | The epoch at which annealing starts. | required |
| `kl_annealtime` | `int` | The number of epochs over which annealing is applied. | required |
| `kl_weight` | `float` | The weight for the KL loss. If None, the weight is computed by annealing (when enabled) or defaults to 1. | required |
| `current_epoch` | `int` | The current epoch. | required |
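A minimal sketch of how such an annealing schedule can combine these parameters. This assumes a linear ramp from 0 to 1 over `kl_annealtime` epochs starting at `kl_start`; the library's actual schedule may differ.

```python
def get_kl_weight(kl_annealing, kl_start, kl_annealtime, kl_weight, current_epoch):
    """Sketch of KL-weight computation with optional linear annealing."""
    if kl_annealing:
        # Linear ramp: 0 before kl_start, 1 after kl_start + kl_annealtime.
        return min(max((current_epoch - kl_start) / kl_annealtime, 0.0), 1.0)
    # No annealing: use the given weight, falling back to 1 if unset.
    return kl_weight if kl_weight is not None else 1.0
```

For example, with `kl_start=5` and `kl_annealtime=10`, the weight is 0 up to epoch 5, reaches 0.5 at epoch 10, and saturates at 1 from epoch 15 onward.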