Lightning Module
Lightning Module for LadderVAE.
LadderVAELight
Bases: LightningModule
global_step
property
Global step, i.e., the training-step counter that increment_global_step() advances by 1.
__init__(config, data_mean, data_std, target_ch)
Here we will do the following:
- initialize the model (from the LadderVAE class);
- initialize the parameters related to training and the loss.
NOTE: Some of the model attributes are defined in the model object itself, while others are defined here. All attributes related to training and the loss that were already defined in the model object are redefined here as Lightning module attributes (e.g., self.some_attr = model.some_attr). Attributes related to the model itself are treated as model attributes (e.g., self.model.some_attr).
NOTE: HC stands for hard-coded attribute.
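The attribute convention described in the notes can be pictured with a small, framework-free sketch. `ToyLadderVAE` and `ToyLadderVAELight` are hypothetical stand-ins (not the real LadderVAE or LightningModule), and `some_attr` / `n_layers` are illustrative names only.

```python
class ToyLadderVAE:
    """Hypothetical stand-in for the LadderVAE model object."""

    def __init__(self):
        self.some_attr = "elbo"  # illustrative training/loss attribute
        self.n_layers = 3        # illustrative model attribute


class ToyLadderVAELight:
    """Hypothetical sketch of how attributes are split between module and model."""

    def __init__(self):
        self.model = ToyLadderVAE()
        # Training/loss attribute already defined on the model is
        # redefined as a module attribute: self.some_attr = model.some_attr
        self.some_attr = self.model.some_attr


light = ToyLadderVAELight()
print(light.some_attr)       # training/loss attribute, read from the module
print(light.model.n_layers)  # model attribute, read through self.model
```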
get_kl_divergence_loss(topdown_layer_data_dict, kl_key='kl')
kl[i] for each layer i has length batch_size, so the resulting kl has shape (batch_size, layers).
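The shape bookkeeping above can be sketched in PyTorch as follows. This assumes `topdown_layer_data_dict[kl_key]` is a list with one 1-D tensor of shape `(batch_size,)` per layer; the final reduction (sum over layers, mean over the batch) is an assumption for illustration, not necessarily what `get_kl_divergence_loss` does. Both function names are hypothetical.

```python
import torch


def stack_layer_kl(topdown_layer_data_dict, kl_key="kl"):
    """Stack per-layer KL terms into a (batch_size, layers) tensor (sketch)."""
    # Each kl[i] is assumed to be a tensor of shape (batch_size,) for layer i.
    kl_per_layer = topdown_layer_data_dict[kl_key]
    # Stacking along dim=1 puts layers in the second dimension.
    return torch.stack(kl_per_layer, dim=1)


def kl_loss_sketch(topdown_layer_data_dict, kl_key="kl"):
    """Hypothetical reduction: sum over layers, then average over the batch."""
    kl = stack_layer_kl(topdown_layer_data_dict, kl_key)
    return kl.sum(dim=1).mean()


layer_data = {"kl": [torch.rand(4) for _ in range(3)]}  # batch_size=4, 3 layers
print(stack_layer_kl(layer_data).shape)  # torch.Size([4, 3])
```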
get_kl_divergence_loss_usplit(topdown_layer_data_dict)
get_kl_weight()
The KL loss can be weighted depending on whether an annealing procedure is used. This function computes the weight of the KL loss when annealing is enabled.
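One common annealing choice is a linear warm-up of the KL weight. The sketch below assumes that schedule; the parameter names `kl_annealing`, `kl_start`, `kl_annealtime` and `max_weight` are hypothetical and are not taken from the LadderVAE config.

```python
def kl_weight_sketch(global_step, kl_annealing=True, kl_start=0,
                     kl_annealtime=10, max_weight=1.0):
    """Hypothetical linear KL annealing schedule (assumes kl_annealtime > 0)."""
    if not kl_annealing:
        # No annealing: the KL term always gets its full weight.
        return max_weight
    # Fraction of the warm-up period completed, clipped to [0, 1].
    progress = (global_step - kl_start) / kl_annealtime
    return max_weight * min(1.0, max(0.0, progress))


for step in (0, 5, 10, 20):
    print(step, kl_weight_sketch(step))
```

Sigmoid or cyclical schedules are equally plausible; only the ramp function would change.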
get_reconstruction_loss(reconstruction, target, input, splitting_mask=None, return_predicted_img=False, likelihood_obj=None)
Parameters:
- reconstruction (Tensor)
- target (Tensor)
- input (Tensor)
- splitting_mask (Tensor, default: None) – A boolean tensor that indicates which items to keep for the reconstruction loss computation. If None, all elements of the items are considered (i.e., the mask is all True).
- return_predicted_img (bool, default: False)
- likelihood_obj (LikelihoodModule, default: None)
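The effect of `splitting_mask` can be sketched as follows, assuming the mask is a per-item boolean tensor over the batch dimension. As a hypothetical simplification, a per-item squared error stands in for the negative log-likelihood that `likelihood_obj` would compute; the function name is also hypothetical.

```python
import torch


def masked_reconstruction_loss_sketch(reconstruction, target, splitting_mask=None):
    """Hypothetical sketch: average a per-item loss over the items kept by the mask."""
    # Squared error stands in for the likelihood-based loss; one value per item.
    per_item = ((reconstruction - target) ** 2).flatten(start_dim=1).mean(dim=1)
    if splitting_mask is None:
        # No mask given: keep every item (mask is all True).
        splitting_mask = torch.ones_like(per_item, dtype=torch.bool)
    return per_item[splitting_mask].mean()


recon = torch.zeros(3, 1, 2, 2)
target = torch.stack([torch.full((1, 2, 2), float(v)) for v in (0, 1, 2)])
print(masked_reconstruction_loss_sketch(recon, target))
print(masked_reconstruction_loss_sketch(recon, target, torch.tensor([True, True, False])))
```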
increment_global_step()
Increments global step by 1.
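A read-only `global_step` property paired with an explicit increment method can be sketched like this. The class name and the private backing field `_global_step` are assumptions for illustration only.

```python
class StepCounterSketch:
    """Hypothetical sketch: read-only global_step backed by a private counter."""

    def __init__(self):
        self._global_step = 0  # assumed private backing field

    @property
    def global_step(self):
        """Current global step."""
        return self._global_step

    def increment_global_step(self):
        """Increment the global step by 1."""
        self._global_step += 1


counter = StepCounterSketch()
counter.increment_global_step()
counter.increment_global_step()
print(counter.global_step)
```

In stock PyTorch Lightning, `LightningModule.global_step` is read from the attached Trainer; a counter like this one is instead advanced explicitly by the module's own code.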