LVAE Prediction
Module containing PyTorch implementations for obtaining predictions from an LVAE.
lvae_predict_mmse_tiled_batch(model, likelihood_obj, input, mmse_count)
Generate the MMSE (minimum mean squared error) prediction for a given input.
This is calculated as the mean of multiple single-sample predictions.
Parameters:
- model (LadderVAE) – Trained LVAE model.
- likelihood_obj (LikelihoodModule) – Instance of a likelihood class.
- input (torch.Tensor | tuple of (torch.Tensor, Any, ...)) – Input to generate the prediction for. This can include auxiliary inputs such as TileInformation, but the model input is always the first item of the tuple. Expected shape of the model input is (S, C, Y, X).
- mmse_count (int) – Number of samples to generate to calculate the MMSE (minimum mean squared error) prediction.
Returns:
- tuple of (tuple of (torch.Tensor[Any], Any, ...)) – A tuple of 3 elements. The first element contains the MMSE prediction, the second contains the standard deviation of the samples used to create the MMSE prediction, and the last element contains the log-variance of the likelihood, which will be None if likelihood.predict_logvar is None. Any auxiliary data included in the input will also be included with each of the MMSE prediction, the standard deviation, and the log-variance.
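The MMSE estimate above reduces to averaging repeated stochastic predictions. The following is a minimal sketch of that aggregation step only, assuming a hypothetical callable `sample_fn` that stands in for one single-sample LVAE prediction; it is not the library's implementation, and auxiliary-data handling is omitted. It uses the population standard deviation (dividing by N); whether the actual function divides by N or N - 1 is not stated here.

```python
import torch


def mmse_from_samples(sample_fn, x, mmse_count):
    """Estimate the MMSE prediction as the mean of `mmse_count` samples.

    `sample_fn` is a hypothetical callable standing in for one stochastic
    forward pass (e.g. a single LVAE sample); it maps x -> prediction.
    """
    if mmse_count < 1:
        raise ValueError("mmse_count must be at least 1")
    # Draw samples and stack them along a new leading axis: (N, S, C, Y, X)
    samples = torch.stack([sample_fn(x) for _ in range(mmse_count)], dim=0)
    # MMSE estimate: per-pixel average over the N samples
    mmse = samples.mean(dim=0)
    # Per-pixel spread of the samples (population std, divides by N)
    std = ((samples - mmse) ** 2).mean(dim=0).sqrt()
    return mmse, std


# Example with a hypothetical toy predictor: input plus Gaussian noise
x = torch.zeros(4, 1, 16, 16)
mmse, std = mmse_from_samples(lambda t: t + torch.randn_like(t), x, mmse_count=10)
# mmse.shape → torch.Size([4, 1, 16, 16])
```

Stacking before reducing keeps the code short; for large inputs or many samples, a running mean and variance would avoid holding all N samples in memory at once.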
lvae_predict_single_sample(model, likelihood_obj, input)
Generate a single-sample prediction from an LVAE model for a given input.
Parameters:
- model (LadderVAE) – Trained LVAE model.
- likelihood_obj (LikelihoodModule) – Instance of a likelihood class.
- input (torch.Tensor) – Input to generate the prediction for. Expected shape is (S, C, Y, X).
Returns:
- tuple of (torch.Tensor, optional torch.Tensor) – The first element is the sample prediction, and the second element is the log-variance. The log-variance will be None if model.predict_logvar is None.
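A single-sample prediction amounts to one forward pass followed by splitting the network output into a prediction and an optional log-variance. The sketch below illustrates that flow with hypothetical stand-ins: `ToyModel` replaces the trained `LadderVAE`, and `ToyLikelihood.mean_and_logvar` replaces whatever method the real `LikelihoodModule` exposes. The convention that the log-variance occupies the second half of the output channels is an assumption made for illustration only.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for a trained LVAE: (S, C_in, Y, X) -> (S, C_out, Y, X)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


class ToyLikelihood:
    """Hypothetical likelihood: separates the prediction from the log-variance."""

    def __init__(self, predict_logvar=None):
        self.predict_logvar = predict_logvar

    def mean_and_logvar(self, output):
        if self.predict_logvar is None:
            # No log-variance predicted: the whole output is the prediction
            return output, None
        # Assumed layout: first half of channels = prediction, second half = log-variance
        mean, logvar = output.chunk(2, dim=1)
        return mean, logvar


def predict_single_sample(model, likelihood_obj, x):
    """Return one sample prediction and its log-variance (or None)."""
    model.eval()           # switch off training-only behavior
    with torch.no_grad():  # gradients are not needed for inference
        output = model(x)
    return likelihood_obj.mean_and_logvar(output)
```

For an input of shape (S, 1, Y, X), `ToyModel(1, 2)` paired with `ToyLikelihood(predict_logvar=True)` yields a prediction and a log-variance of shape (S, 1, Y, X) each, while `ToyModel(1, 1)` with `ToyLikelihood()` yields the prediction and None.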
lvae_predict_tiled_batch(model, likelihood_obj, input)
Generate a single-sample prediction from an LVAE model for a given input, which may include auxiliary data.
Parameters:
- model (LadderVAE) – Trained LVAE model.
- likelihood_obj (LikelihoodModule) – Instance of a likelihood class.
- input (torch.Tensor | tuple of (torch.Tensor, Any, ...)) – Input to generate the prediction for. This can include auxiliary inputs such as TileInformation, but the model input is always the first item of the tuple. Expected shape of the model input is (S, C, Y, X).
Returns:
- tuple of ((torch.Tensor, Any, ...), optional tuple of (torch.Tensor, Any, ...)) – The first element is the sample prediction, and the second element is the log-variance. The log-variance will be None if model.predict_logvar is None. Any auxiliary data included in the input will also be included with both the sample prediction and the log-variance.
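The tiled-batch variant differs mainly in how auxiliary inputs are carried through. A minimal sketch of that bookkeeping, assuming a hypothetical `predict_fn` that returns `(prediction, logvar_or_None)` for the model input: only the first tuple element is passed to the model, and the remaining items are appended unchanged to both outputs. Returning a one-element tuple for a bare tensor input is a design choice of this sketch, not documented behavior.

```python
import torch


def predict_with_aux(predict_fn, batch):
    """Run `predict_fn` on the model input and re-attach any auxiliary items.

    `predict_fn` is a hypothetical callable x -> (prediction, logvar_or_None),
    e.g. a single-sample LVAE prediction. `batch` is either a tensor or a tuple
    whose first element is the model input, shape (S, C, Y, X).
    """
    if isinstance(batch, torch.Tensor):
        x, aux = batch, ()
    else:
        x, *rest = batch
        aux = tuple(rest)
    prediction, logvar = predict_fn(x)
    # Auxiliary data (e.g. tile information) travels unchanged with each output
    prediction_out = (prediction, *aux)
    logvar_out = None if logvar is None else (logvar, *aux)
    return prediction_out, logvar_out


# Example: a plain list stands in for TileInformation (hypothetical placeholder)
x = torch.zeros(2, 1, 4, 4)
tile_info = ["tile0", "tile1"]
pred, logvar = predict_with_aux(lambda t: (t + 1, t - 1), (x, tile_info))
```

Because the auxiliary items are re-attached by reference rather than copied, downstream code such as tile stitching receives the same objects that accompanied the input.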