Lightning Module

Lightning Module for LadderVAE.

LadderVAELight

Bases: LightningModule

global_step property

Current value of the global step counter, which is advanced by increment_global_step().

get_kl_divergence_loss(topdown_layer_data_dict, kl_key='kl')

Each entry kl[i] (one per hierarchy layer i) has length batch_size, so the resulting kl tensor has shape (batch_size, layers).
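The shape bookkeeping described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: it assumes topdown_layer_data_dict[kl_key] is a list with one entry per layer, each holding batch_size KL values, and the reduction in the hypothetical helper reduce_kl (sum over layers, then mean over the batch) is an illustrative choice only.

```python
from typing import Dict, List


def stack_kl(topdown_layer_data_dict: Dict[str, List[List[float]]],
             kl_key: str = "kl") -> List[List[float]]:
    """Stack per-layer KL values into a (batch_size, layers) nested list.

    Assumption: topdown_layer_data_dict[kl_key][i] holds the KL values of
    layer i for every sample in the batch (length batch_size).
    """
    per_layer = topdown_layer_data_dict[kl_key]
    if not per_layer:
        raise ValueError("expected at least one layer")
    batch_size = len(per_layer[0])
    # Every layer must report one KL value per sample.
    if any(len(layer) != batch_size for layer in per_layer):
        raise ValueError("all layers must have length batch_size")
    # Transpose: rows become samples, columns become layers.
    return [list(sample) for sample in zip(*per_layer)]


def reduce_kl(kl: List[List[float]]) -> float:
    """Hypothetical reduction: sum over layers, then average over the batch."""
    per_sample = [sum(row) for row in kl]
    return sum(per_sample) / len(per_sample)


# Three layers, batch of two samples.
data = {"kl": [[0.5, 1.5], [1.0, 2.0], [0.0, 1.0]]}
kl = stack_kl(data)      # shape (batch_size=2, layers=3)
loss = reduce_kl(kl)
```

Transposing the per-layer lists is the plain-Python analogue of stacking tensors along a new last dimension, which is one way to arrive at the (batch_size, layers) layout.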

get_kl_divergence_loss_usplit(topdown_layer_data_dict)

get_kl_weight()

The KL loss can be weighted depending on whether an annealing procedure is used. This method computes the weight applied to the KL loss when annealing is enabled.
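As an illustration only, a common annealing scheme is a linear warm-up: the KL weight stays at 0 until a start step, then grows linearly to its maximum over a fixed number of steps. The sketch below assumes that schedule; the function name and the parameters kl_annealing, kl_start, kl_annealtime and max_weight are hypothetical and are not taken from LadderVAELight.

```python
def get_kl_weight_sketch(global_step: int,
                         kl_annealing: bool,
                         kl_start: int = 0,
                         kl_annealtime: int = 10,
                         max_weight: float = 1.0) -> float:
    """Hypothetical linear KL warm-up schedule.

    Without annealing, the full weight is used from the first step.
    With annealing, the weight ramps from 0 to max_weight between
    kl_start and kl_start + kl_annealtime, then stays constant.
    """
    if not kl_annealing:
        return max_weight
    if kl_annealtime <= 0:
        raise ValueError("kl_annealtime must be positive")
    progress = (global_step - kl_start) / kl_annealtime
    # Clamp progress to [0, 1] so the weight never leaves [0, max_weight].
    return max_weight * min(1.0, max(0.0, progress))


weight_mid = get_kl_weight_sketch(global_step=5, kl_annealing=True)
```

Tying the weight to the global step (rather than the epoch) makes the schedule independent of dataset size.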

get_reconstruction_loss(reconstruction, target, input, splitting_mask=None, return_predicted_img=False, likelihood_obj=None)

Parameters:

reconstruction (Tensor): required.
target (Tensor): required.
input (Tensor): required.
splitting_mask (Tensor, default None): A boolean tensor that indicates which items to keep for the reconstruction loss computation. If None, all elements of the items are considered (i.e., the mask is all True).
return_predicted_img (bool, default False).
likelihood_obj (LikelihoodModule, default None).
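To show how splitting_mask restricts the loss, here is a minimal plain-Python sketch. It is not the library's implementation: it substitutes a mean squared error for the likelihood that likelihood_obj would provide, works on flat lists instead of tensors, and omits the input argument. The name masked_reconstruction_loss is hypothetical.

```python
from typing import List, Optional, Tuple, Union


def masked_reconstruction_loss(
    reconstruction: List[float],
    target: List[float],
    splitting_mask: Optional[List[bool]] = None,
    return_predicted_img: bool = False,
) -> Union[float, Tuple[float, List[float]]]:
    """Hypothetical masked MSE standing in for the likelihood-based loss."""
    if splitting_mask is None:
        # None means every element is kept (an all-True mask).
        splitting_mask = [True] * len(target)
    kept = [(r, t) for r, t, m in zip(reconstruction, target, splitting_mask) if m]
    if not kept:
        raise ValueError("splitting_mask selects no elements")
    # Average the squared error over the kept elements only.
    loss = sum((r - t) ** 2 for r, t in kept) / len(kept)
    if return_predicted_img:
        return loss, reconstruction
    return loss


full = masked_reconstruction_loss([1.0, 2.0, 3.0], [1.0, 0.0, 3.0])
masked = masked_reconstruction_loss([1.0, 2.0, 3.0], [1.0, 0.0, 3.0],
                                    splitting_mask=[True, False, True])
```

Dividing by the number of kept elements, rather than the total, keeps the loss scale comparable across masks of different density.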

increment_global_step()

Increments global step by 1.
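The relationship between the global_step property and increment_global_step() can be sketched as a read-only property over a private counter. This is a hypothetical stand-in, assuming a private attribute named _global_step; it is not LadderVAELight's actual code and omits everything else a LightningModule does.

```python
class StepCounterSketch:
    """Hypothetical stand-in for the global-step bookkeeping of LadderVAELight."""

    def __init__(self) -> None:
        self._global_step = 0  # assumed private counter, starts at zero

    @property
    def global_step(self) -> int:
        # Read-only view: callers cannot assign to global_step directly.
        return self._global_step

    def increment_global_step(self) -> None:
        # Advance the counter by exactly one step.
        self._global_step += 1


counter = StepCounterSketch()
counter.increment_global_step()
counter.increment_global_step()
```

Exposing the counter through a property without a setter means the only way to change it is through increment_global_step().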