LVAE Prediction
Module containing PyTorch implementations for obtaining predictions from an LVAE.
lvae_predict_mmse_tiled_batch(model, likelihood_obj, input, mmse_count)
Generate the MMSE (minimum mean squared error) prediction, for a given input.
This is calculated from the mean of multiple single sample predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | LadderVAE | Trained LVAE model. | required |
| likelihood_obj | LikelihoodModule | Instance of a likelihood class. | required |
| input | torch.tensor \| tuple of (torch.tensor, Any, ...) | Input to generate prediction for. This can include auxiliary inputs such as | required |
| mmse_count | int | Number of samples to generate to calculate MMSE (minimum mean squared error). | required |
Returns:
| Type | Description |
|---|---|
| tuple of (tuple of (torch.Tensor[Any], Any, ...)) | A tuple of 3 elements. The first element contains the MMSE prediction, the second contains the standard deviation of the samples used to create the MMSE prediction. Finally the last element contains the log-variance of the likelihood, this will be |
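
The averaging step described above (the MMSE prediction as the mean of several single-sample predictions, plus the standard deviation of those samples) can be sketched as follows. This is a minimal illustration, not the module's implementation: `mmse_from_samples` and `noisy_identity` are hypothetical names, `noisy_identity` stands in for one stochastic forward pass of a trained model, and the use of PyTorch's default (unbiased) standard deviation is an assumption.

```python
import torch


def mmse_from_samples(sample_fn, x, mmse_count):
    """Average `mmse_count` stochastic predictions of `sample_fn(x)`.

    `sample_fn` is a hypothetical stand-in for a single-sample predictor.
    Returns the per-pixel mean (MMSE estimate) and standard deviation.
    """
    # Stack the samples along a new leading axis: (mmse_count, S, C, Y, X).
    samples = torch.stack([sample_fn(x) for _ in range(mmse_count)], dim=0)
    mmse = samples.mean(dim=0)
    # Default torch.std is the unbiased estimator; it is NaN for mmse_count == 1.
    std = samples.std(dim=0)
    return mmse, std


def noisy_identity(x):
    """Toy stochastic predictor: the input plus Gaussian noise (std 0.1)."""
    return x + 0.1 * torch.randn_like(x)


torch.manual_seed(0)
x = torch.zeros(2, 1, 8, 8)  # (S, C, Y, X)
mmse, std = mmse_from_samples(noisy_identity, x, mmse_count=64)
print(mmse.shape, std.shape)
```

Averaging more samples lowers the variance of the MMSE estimate, at the cost of one extra forward pass per sample.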
lvae_predict_single_sample(model, likelihood_obj, input)
Generate a single sample prediction from an LVAE model, for a given input.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | LadderVAE | Trained LVAE model. | required |
| likelihood_obj | LikelihoodModule | Instance of a likelihood class. | required |
| input | tensor | Input to generate prediction for. Expected shape is (S, C, Y, X). | required |
Returns:
| Type | Description |
|---|---|
| tuple of (torch.tensor, optional torch.tensor) | The first element is the sample prediction, and the second element is the log-variance. The log-variance will be None if |
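
Because the second returned element may be `None`, downstream code has to handle both cases. The sketch below shows one way to turn an optional log-variance into a standard deviation using the identity std = exp(0.5 * log_var). The helper name `logvar_to_std` is hypothetical and is not part of this module.

```python
import math
from typing import Optional

import torch


def logvar_to_std(log_var: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Convert an optional log-variance tensor to a standard deviation.

    Hypothetical helper: passes `None` through unchanged, otherwise applies
    std = exp(0.5 * log_var) elementwise.
    """
    if log_var is None:
        return None
    return torch.exp(0.5 * log_var)


# A log-variance of log(4) corresponds to a variance of 4 and a std of 2.
log_var = torch.full((1, 1, 4, 4), math.log(4.0))
std_from_logvar = logvar_to_std(log_var)
print(std_from_logvar[0, 0, 0, 0].item())
```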
lvae_predict_tiled_batch(model, likelihood_obj, input)
Generate a single sample prediction from an LVAE model, for a given input.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | LadderVAE | Trained LVAE model. | required |
| likelihood_obj | LikelihoodModule | Instance of a likelihood class. | required |
| input | torch.tensor \| tuple of (torch.tensor, Any, ...) | Input to generate prediction for. This can include auxiliary inputs such as | required |
Returns:
| Type | Description |
|---|---|
| tuple of ((torch.tensor, Any, ...), optional tuple of (torch.tensor, Any, ...)) | The first element is the sample prediction, and the second element is the log-variance. The log-variance will be None if |
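
The `input` type above is either a bare tensor or a tuple whose first element is a tensor followed by auxiliary items. A caller that wants to treat both forms uniformly could separate them as in this sketch. `split_input` is a hypothetical helper, and the dictionary used as an auxiliary item is made up for illustration; neither reflects the module's internals.

```python
import torch


def split_input(inp):
    """Split an input into (tensor, auxiliary items).

    Hypothetical helper: assumes that, for tuple inputs, the tensor comes
    first and everything after it is auxiliary.
    """
    if isinstance(inp, torch.Tensor):
        return inp, ()
    tensor, *aux = inp
    return tensor, tuple(aux)


batch = torch.zeros(2, 1, 8, 8)
made_up_aux = {"note": "illustrative auxiliary item"}

# Both forms yield the same tensor; only the auxiliary part differs.
tensor_a, aux_a = split_input(batch)
tensor_b, aux_b = split_input((batch, made_up_aux))
print(len(aux_a), len(aux_b))
```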