Lightning Module
Lightning Module for LadderVAE.
LadderVAELight
Bases: LightningModule
global_step (property)
The current global step count.
get_kl_divergence_loss(topdown_layer_data_dict, kl_key='kl')
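Computes the KL divergence loss from `topdown_layer_data_dict`, reading the per-layer values stored under `kl_key`. Each entry `kl[i]` (one per layer `i`) has length `batch_size`, so the collected KL values have shape `(batch_size, layers)`.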
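The shape convention above can be illustrated with a minimal, framework-free sketch. It assumes (hypothetically) that `topdown_layer_data_dict[kl_key]` is a list with one entry per layer, each holding `batch_size` values; the helper name `stack_layer_kl` is illustrative and not part of the module, and the sketch does not show how the module reduces the result to a scalar loss. With tensors, stacking the per-layer vectors along a new last dimension (e.g. `torch.stack(kl, dim=1)`) yields the same `(batch_size, layers)` layout.

```python
from typing import Dict, List


def stack_layer_kl(
    topdown_layer_data_dict: Dict[str, List[List[float]]], kl_key: str = "kl"
) -> List[List[float]]:
    """Arrange per-layer KL values into a (batch_size, layers) nested list.

    Hypothetical layout: topdown_layer_data_dict[kl_key][i] holds the KL
    values of layer i, one value per sample (length batch_size).
    """
    kl = topdown_layer_data_dict[kl_key]
    batch_size = len(kl[0])
    if any(len(layer) != batch_size for layer in kl):
        raise ValueError("every kl[i] must have length batch_size")
    # zip(*kl) walks all layers in parallel, so each tuple holds one sample's
    # KL across layers: the (layers, batch_size) input becomes (batch_size, layers).
    return [list(sample) for sample in zip(*kl)]


data = {"kl": [[0.5, 1.0, 1.5], [0.1, 0.2, 0.3]]}  # 2 layers, batch_size 3
kl = stack_layer_kl(data)
print(len(kl), len(kl[0]))  # 3 2
```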
get_kl_divergence_loss_usplit(topdown_layer_data_dict)
get_kl_weight()
The KL loss can be weighted differently depending on whether an annealing procedure is used. This method computes the weight applied to the KL loss when annealing is enabled.
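Linear warm-up is one common KL annealing scheme, and the sketch below assumes it purely for illustration; the document does not state which schedule `get_kl_weight` actually uses. The function name `linear_kl_weight` and the parameters `kl_start` (step at which the ramp begins) and `kl_annealtime` (number of steps needed to reach full weight) are hypothetical.

```python
def linear_kl_weight(
    global_step: int,
    kl_start: int,
    kl_annealtime: int,
    annealing: bool = True,
) -> float:
    """Return a KL weight that ramps linearly from 0 to 1 (hypothetical schedule)."""
    if not annealing:
        # Without annealing, the KL term keeps its full weight.
        return 1.0
    if kl_annealtime <= 0:
        raise ValueError("kl_annealtime must be positive")
    # Fraction of the warm-up completed since kl_start, clamped to [0, 1].
    progress = (global_step - kl_start) / kl_annealtime
    return min(1.0, max(0.0, progress))


for step in (0, 60, 200):
    print(step, linear_kl_weight(step, kl_start=10, kl_annealtime=100))
```

Clamping keeps the weight at 0 before `kl_start` and at 1 once the warm-up is over, so the schedule can be queried at any step without special cases.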
get_reconstruction_loss(reconstruction, target, input, splitting_mask=None, return_predicted_img=False, likelihood_obj=None)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `reconstruction` | `Tensor` | | required |
| `target` | `Tensor` | | required |
| `input` | `Tensor` | | required |
| `splitting_mask` | `Tensor` | A boolean tensor that indicates which items to keep for the reconstruction loss computation. If `None`, all items are kept. | `None` |
| `return_predicted_img` | `bool` | | `False` |
| `likelihood_obj` | `LikelihoodModule` | | `None` |
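How a boolean `splitting_mask` restricts the loss to selected items can be sketched as follows. This is an illustration only: mean squared error stands in for the likelihood-based loss the real method may compute via `likelihood_obj`, each item is reduced to a single float for brevity, treating `None` as "keep every item" is an assumption, and the name `masked_mse` is hypothetical.

```python
from typing import Optional, Sequence


def masked_mse(
    reconstruction: Sequence[float],
    target: Sequence[float],
    splitting_mask: Optional[Sequence[bool]] = None,
) -> float:
    """Mean squared error over the items whose mask entry is True."""
    if splitting_mask is None:
        # Assumption: no mask means every item contributes to the loss.
        splitting_mask = [True] * len(target)
    if not (len(reconstruction) == len(target) == len(splitting_mask)):
        raise ValueError("reconstruction, target and mask must have equal length")
    # Keep only the squared errors of the items selected by the mask.
    kept = [
        (r - t) ** 2
        for r, t, keep in zip(reconstruction, target, splitting_mask)
        if keep
    ]
    if not kept:
        raise ValueError("splitting_mask selects no items")
    return sum(kept) / len(kept)


recon = [1.0, 2.0, 3.0]
tgt = [1.0, 0.0, 3.0]
print(masked_mse(recon, tgt))                        # averages over all 3 items
print(masked_mse(recon, tgt, [False, True, False]))  # only the middle item
```

Averaging over the kept items (rather than over all items) keeps the loss scale independent of how many items the mask removes; whether the real method normalizes this way is not stated in the document.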
increment_global_step()
Increments the global step counter by 1.
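A `global_step` property paired with an `increment_global_step()` method suggests a counter that the module advances itself. The plain-Python sketch below shows that pattern; the class name `StepCounter` and the attribute `_global_step` are hypothetical, and the document does not say whether the real counter is stored this way or delegates to the Lightning trainer.

```python
class StepCounter:
    """Hypothetical stand-in for a module that tracks its own global step."""

    def __init__(self) -> None:
        self._global_step = 0  # hypothetical private storage for the counter

    @property
    def global_step(self) -> int:
        """Read-only view of the current step count."""
        return self._global_step

    def increment_global_step(self) -> None:
        """Advance the counter by exactly one step."""
        self._global_step += 1


counter = StepCounter()
for _ in range(3):
    counter.increment_global_step()
print(counter.global_step)  # 3
```

Exposing the count through a property without a setter means callers can read it but can only change it through `increment_global_step()`.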