Likelihoods
Script containing modules for defining different likelihood functions (as nn.Module).
GaussianLikelihood
Bases: LikelihoodModule
A specialized LikelihoodModule for Gaussian likelihood.
Specifically, in the LVAE model, the likelihood is defined as: p(x \mid z_1) = \mathcal{N}(x \mid \mu_{p,1}, \sigma_{p,1}^2)
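Taking the logarithm of this density gives the quantity the module evaluates. It is written per pixel here, matching the (B, C, [Z], Y, X) shape that log_likelihood returns. Writing the variance as the exponential of a log-variance follows the (mean, exp(logvar)) parameterization used by log_normal. This is a standard identity shown for orientation, not a transcription of the implementation:

```latex
\log p(x \mid z_1) = \log \mathcal{N}\left(x \mid \mu_{p,1}, \sigma_{p,1}^2\right)
  = -\frac{1}{2}\left(\log 2\pi + \log \sigma_{p,1}^2 + \frac{(x - \mu_{p,1})^2}{\sigma_{p,1}^2}\right),
\qquad \sigma_{p,1}^2 = \exp(\mathrm{logvar}).
```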
__init__(predict_logvar=None, logvar_lowerbound=None)
Constructor.
distr_params(x)
Get the parameters (mean and log-variance) of the Gaussian distribution defined by the likelihood.
Parameters:
- x (Tensor) – The input tensor to the likelihood module, i.e., the output of the LVAE 'output_layer'. Shape is (B, 2 * C, [Z], Y, X) if predict_logvar is not None, or (B, C, [Z], Y, X) otherwise.
forward(input_, x)
Parameters:
- input_ (Tensor) – The output of the top-down pass (e.g., the reconstructed image in HDN, or the unmixed images in 'Split' models).
- x (Union[Tensor, None]) – The target tensor. If None, the log-likelihood is not computed.
get_mean_lv(x)
Given the output of the top-down pass, compute the mean and log-variance of the Gaussian distribution defining the likelihood.
Parameters:
- x (Tensor) – The input tensor to the likelihood module, i.e., the output of the top-down pass.
Returns:
- tuple of (torch.Tensor, Optional[torch.Tensor]) – The first element of the tuple is the mean, the second element is the log-variance. If the attribute predict_logvar is None, the second element is None.
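The shape contract of distr_params and get_mean_lv (2 * C channels when predict_logvar is set, C otherwise) can be sketched as a channel split. Numpy stands in for torch here (`torch.chunk(x, 2, dim=1)` would play the role of `np.split`). The assumption that the first C channels hold the mean and the last C the log-variance is the sketch's, not confirmed by the source, and `"pixelwise"` is only a placeholder for "any non-None value". The dict keys "mean" and "logvar" follow the params descriptions in this page.

```python
import numpy as np
from typing import Dict, Optional, Tuple


def get_mean_lv(
    x: np.ndarray, predict_logvar: Optional[str]
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Split (B, 2*C, [Z], Y, X) into mean and log-variance, or pass x through."""
    if predict_logvar is None:
        # The whole output is the mean; there are no log-variance channels.
        return x, None
    # Assumed layout: first C channels -> mean, last C channels -> log-variance.
    mean, logvar = np.split(x, 2, axis=1)
    return mean, logvar


def distr_params(x: np.ndarray, predict_logvar: Optional[str]) -> Dict[str, Optional[np.ndarray]]:
    """Package the split tensors under the keys used by log_likelihood."""
    mean, logvar = get_mean_lv(x, predict_logvar)
    return {"mean": mean, "logvar": logvar}


# Usage: a (B=4, 2*C=6, Y=8, X=8) output with log-variance prediction enabled.
out = np.random.default_rng(0).normal(size=(4, 6, 8, 8))
params = distr_params(out, "pixelwise")
print(params["mean"].shape, params["logvar"].shape)  # (4, 3, 8, 8) (4, 3, 8, 8)
```

Because the split is along axis 1 only, the same code handles the optional Z axis unchanged.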
log_likelihood(x, params)
Compute the Gaussian log-likelihood.
Parameters:
- x (Tensor) – The target tensor. Shape is (B, C, [Z], Y, X).
- params (dict[str, Union[Tensor, None]]) – The tensors obtained by chunking the output of the top-down pass, used here as parameters of the Gaussian distribution.
Returns:
- Tensor – The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
LikelihoodModule
Bases: Module
The base class for all likelihood modules. It defines the fundamental structure and methods for specialized likelihood models.
forward(input_, x)
Parameters:
- input_ (Tensor) – The output of the top-down pass (e.g., the reconstructed image in HDN, or the unmixed images in 'Split' models).
- x (Union[Tensor, None]) – The target tensor. If None, the log-likelihood is not computed.
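The base-class contract (forward receives the top-down output and an optional target, and skips the log-likelihood when the target is None) can be sketched as a template method. This is a plain-Python stand-in, not the library's nn.Module. The assumptions that forward delegates to distr_params and log_likelihood, and that it returns the pair (log-likelihood, params), are the sketch's own. `LikelihoodSketch` and `ToyLikelihood` are hypothetical names.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple


class LikelihoodSketch(ABC):
    """Hypothetical stand-in for LikelihoodModule (no torch dependency)."""

    @abstractmethod
    def distr_params(self, input_: Any) -> Dict[str, Any]:
        """Turn the top-down output into distribution parameters."""

    @abstractmethod
    def log_likelihood(self, x: Any, params: Dict[str, Any]) -> Any:
        """Evaluate the log-likelihood of target x under params."""

    def forward(self, input_: Any, x: Optional[Any]) -> Tuple[Optional[Any], Dict[str, Any]]:
        params = self.distr_params(input_)
        if x is None:
            # No target: the log-likelihood is not computed, as documented.
            return None, params
        return self.log_likelihood(x, params), params


class ToyLikelihood(LikelihoodSketch):
    """Toy subclass: the prediction is the mean, with no log-variance."""

    def distr_params(self, input_):
        return {"mean": input_, "logvar": None}

    def log_likelihood(self, x, params):
        # Unnormalized, unit-variance Gaussian log-density (illustration only).
        return -((x - params["mean"]) ** 2)


ll, params = ToyLikelihood().forward(2.0, 3.0)
print(ll)  # -1.0
```

Keeping the None check in the base class means each subclass only implements how parameters are built and scored.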
NoiseModelLikelihood
Bases: LikelihoodModule
__init__(noise_model)
Constructor.
Parameters:
- noise_model – The noise model instance used to compute the likelihood.
forward(input_, x)
Parameters:
- input_ (Tensor) – The output of the top-down pass (e.g., the reconstructed image in HDN, or the unmixed images in 'Split' models).
- x (Union[Tensor, None]) – The target tensor. If None, the log-likelihood is not computed.
log_likelihood(x, params)
Compute the log-likelihood of the target tensor x given the parameters params obtained from the reconstruction tensor.
Parameters:
- x (Tensor) – The target tensor. Shape is (B, C, [Z], Y, X).
- params (dict[str, Tensor]) – The tensors obtained from the output of the top-down pass. Here, "mean" corresponds to the whole output, while "logvar" is None.
Returns:
- Tensor – The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
set_data_stats(data_mean, data_std)
Set the data mean and std for denormalization.
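A minimal sketch of how set_data_stats and log_likelihood could fit together, assuming the stored mean and std are used to map normalized tensors back to raw intensities (t * std + mean) before the noise model scores them. The noise-model interface `likelihood(observations, signals)` returning densities is hypothetical, as is the toy Gaussian noise model. Per the log_likelihood description, params["mean"] is the whole network output and "logvar" is unused.

```python
import numpy as np


class ToyGaussianNoiseModel:
    """Hypothetical noise model: p(obs | signal) = N(obs | signal, sigma^2)."""

    def __init__(self, sigma: float):
        self.sigma = sigma

    def likelihood(self, observations: np.ndarray, signals: np.ndarray) -> np.ndarray:
        z = (observations - signals) / self.sigma
        return np.exp(-0.5 * z**2) / (np.sqrt(2.0 * np.pi) * self.sigma)


class NoiseModelLikelihoodSketch:
    """Hypothetical stand-in for NoiseModelLikelihood."""

    def __init__(self, noise_model):
        self.noise_model = noise_model
        self.data_mean = 0.0  # identity denormalization until stats are set
        self.data_std = 1.0

    def set_data_stats(self, data_mean, data_std):
        self.data_mean = data_mean
        self.data_std = data_std

    def _denormalize(self, t: np.ndarray) -> np.ndarray:
        return t * self.data_std + self.data_mean

    def log_likelihood(self, x: np.ndarray, params: dict) -> np.ndarray:
        # "mean" is the whole top-down output; "logvar" is None and ignored.
        signal = self._denormalize(params["mean"])
        observation = self._denormalize(x)
        lik = self.noise_model.likelihood(observation, signal)
        # Clamp away from zero before the log (a numerical-safety choice of this sketch).
        return np.log(np.maximum(lik, 1e-12))


lk = NoiseModelLikelihoodSketch(ToyGaussianNoiseModel(sigma=2.0))
lk.set_data_stats(100.0, 10.0)
out = lk.log_likelihood(np.full((2, 1, 4, 4), 0.1), {"mean": np.zeros((2, 1, 4, 4)), "logvar": None})
print(out.shape)  # (2, 1, 4, 4)
```

Denormalizing before scoring matters because a noise model fitted on raw data describes noise in raw intensity units, not in the normalized space the network works in.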
likelihood_factory(config, noise_model=None)
Factory function for creating likelihood modules.
Parameters:
- config (Optional[Union[GaussianLikelihoodConfig, NMLikelihoodConfig]]) – The configuration object for the likelihood module.
- noise_model (Optional[NoiseModel], default: None) – The noise model instance used to define the NoiseModelLikelihood.
Returns:
- Module – The likelihood module.
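The factory can be sketched as a dispatch on the config type. Everything below is a hypothetical stand-in. The config fields are assumed to mirror GaussianLikelihood's constructor arguments. Returning None for a None config and raising when NMLikelihoodConfig comes without a noise_model are choices of this sketch, not documented behavior. The value "pixelwise" is a placeholder.

```python
from dataclasses import dataclass
from typing import Any, Optional, Union


@dataclass
class GaussianLikelihoodConfig:
    """Hypothetical stand-in; fields assumed to mirror GaussianLikelihood.__init__."""
    predict_logvar: Optional[str] = None
    logvar_lowerbound: Optional[float] = None


@dataclass
class NMLikelihoodConfig:
    """Hypothetical stand-in; the noise model itself is passed separately."""


class GaussianLikelihood:
    def __init__(self, predict_logvar=None, logvar_lowerbound=None):
        self.predict_logvar = predict_logvar
        self.logvar_lowerbound = logvar_lowerbound


class NoiseModelLikelihood:
    def __init__(self, noise_model):
        self.noise_model = noise_model


def likelihood_factory(
    config: Optional[Union[GaussianLikelihoodConfig, NMLikelihoodConfig]],
    noise_model: Optional[Any] = None,
):
    """Build the likelihood module matching the config type."""
    if config is None:
        # Assumption of this sketch: no config means no likelihood module.
        return None
    if isinstance(config, GaussianLikelihoodConfig):
        return GaussianLikelihood(config.predict_logvar, config.logvar_lowerbound)
    if isinstance(config, NMLikelihoodConfig):
        if noise_model is None:
            raise ValueError("NMLikelihoodConfig requires a noise_model.")
        return NoiseModelLikelihood(noise_model)
    raise ValueError(f"Unsupported likelihood config: {type(config).__name__}")


gaussian = likelihood_factory(GaussianLikelihoodConfig(predict_logvar="pixelwise"))
print(type(gaussian).__name__)  # GaussianLikelihood
```

Dispatching on the config class keeps callers unaware of which module they get, while the separate noise_model argument reflects that a fitted noise model is an object rather than a config value.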
log_normal(x, mean, logvar)
Compute the log-probability at x of a Gaussian distribution with parameters (mean, exp(logvar)).
NOTE: In the case of LVAE, the log-likelihood formula becomes:
\mathbb{E}_{z_1 \sim q_\phi}\left[\log p_\theta(x \mid z_1)\right] = -\frac{1}{2}\left(\mathbb{E}_{z_1 \sim q_\phi}\left[\log 2\pi\sigma_{p,1}^2(z_1)\right] + \mathbb{E}_{z_1 \sim q_\phi}\left[\frac{(x - \mu_{p,1}(z_1))^2}{\sigma_{p,1}^2(z_1)}\right]\right)
Parameters:
- x (Tensor) – The ground-truth tensor. Shape is (batch, channels, dim1, dim2).
- mean (Tensor) – The inferred mean of the distribution. Shape is (batch, channels, dim1, dim2).
- logvar (Tensor) – The inferred log-variance of the distribution. Must be either a scalar or broadcastable against x and mean.
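A minimal numpy sketch of log_normal and of the expectation in the LVAE note. It works in log-variance directly, using log(2πσ²) = log(2π) + logvar, so it never computes log(exp(logvar)). The expectation over z_1 ~ q_φ is approximated by averaging over S decoder outputs, one per sample of z_1. This Monte Carlo estimate is the sketch's assumption about how the expectation is evaluated, and `mc_expected_log_likelihood` is a hypothetical helper. Because log_normal is a sum of the two terms in the note, averaging it over samples estimates both expectations at once. The same expressions work on torch tensors with `torch.exp` and `torch.log`.

```python
import numpy as np


def log_normal(x, mean, logvar):
    """Elementwise log N(x | mean, exp(logvar)).

    logvar may be a scalar or any array broadcastable against x and mean.
    """
    # log(2*pi*sigma^2) = log(2*pi) + logvar, so sigma^2 is only exponentiated once.
    return -0.5 * (np.log(2.0 * np.pi) + logvar + (x - mean) ** 2 / np.exp(logvar))


def mc_expected_log_likelihood(x, means, logvars):
    """Monte Carlo estimate of E_{z_1 ~ q_phi}[log p(x | z_1)].

    means (and logvars, if not scalar) stack one decoder output per sample
    of z_1 along axis 0, e.g. shape (S, B, C, H, W); the expectation is
    replaced by the sample average.
    """
    return log_normal(x[None], means, logvars).mean(axis=0)


# Usage with made-up decoder outputs for S = 4 samples of z_1.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 1, 8, 8))
means = x[None] + 0.1 * rng.normal(size=(4, 2, 1, 8, 8))
estimate = mc_expected_log_likelihood(x, means, np.log(0.5))  # scalar logvar broadcasts
print(estimate.shape)  # (2, 1, 8, 8)
```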