Likelihoods
Script containing modules for defining different likelihood functions (as nn.Module).
GaussianLikelihood
Bases: LikelihoodModule
A specialized LikelihoodModule for Gaussian likelihood.
Specifically, in the LVAE model, the likelihood is defined as: $p(x \mid z_1) = \mathcal{N}(x \mid \mu_{p,1}, \sigma_{p,1}^2)$
distr_params(x)
Get parameters (mean, log-var) of the Gaussian distribution defined by the likelihood.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | The input tensor to the likelihood module, i.e., the output of the LVAE 'output_layer'. Shape is: (B, 2 * C, [Z], Y, X) in case | required |

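
The chunking described above can be sketched as follows. This is a minimal illustration, not the library code: NumPy arrays stand in for torch tensors, the helper name `split_mean_logvar` is hypothetical, and it assumes the mean occupies the first C channels and the log-variance the last C.

```python
import numpy as np


def split_mean_logvar(out: np.ndarray) -> dict:
    """Split a (B, 2 * C, [Z], Y, X) array into Gaussian parameters.

    Assumed layout: first C channels = mean, last C channels = log-variance.
    """
    if out.shape[1] % 2 != 0:
        raise ValueError("expected an even number of channels (2 * C)")
    mean, logvar = np.split(out, 2, axis=1)
    return {"mean": mean, "logvar": logvar}


# Batch of 4, C = 3 target channels, 2D patches of 8 x 8.
out = np.random.default_rng(0).normal(size=(4, 6, 8, 8))
params = split_mean_logvar(out)
print(params["mean"].shape, params["logvar"].shape)  # (4, 3, 8, 8) (4, 3, 8, 8)
```

Splitting along the channel axis keeps the batch and spatial axes untouched, so the 2D and 3D ([Z]) cases are handled by the same call.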
forward(input_, x)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_` | `Tensor` | The output of the top-down pass (e.g., reconstructed image in HDN, or the unmixed images in 'Split' models). | required |
| `x` | `Union[Tensor, None]` | The target tensor. If None, the log-likelihood is not computed. | required |

get_mean_lv(x)
Given the output of the top-down pass, compute the mean and log-variance of the Gaussian distribution defining the likelihood.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | The input tensor to the likelihood module, i.e., the output of the top-down pass. | required |

Returns:

| Type | Description |
|---|---|
| `tuple of (torch.tensor, optional torch.tensor)` | The first element of the tuple is the mean, the second element is the log-variance. If the attribute |

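
The return type marks the log-variance as optional. The sketch below illustrates both cases under stated assumptions: `get_mean_lv_sketch` and its boolean `predict_logvar` argument are hypothetical stand-ins for whatever attribute controls this in the real module, and NumPy arrays replace torch tensors.

```python
from typing import Optional, Tuple

import numpy as np


def get_mean_lv_sketch(
    x: np.ndarray, predict_logvar: bool
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Return (mean, logvar) from a top-down output of shape (B, C', [Z], Y, X).

    `predict_logvar` is a hypothetical flag: when False, the whole output is
    the mean and no log-variance is returned.
    """
    if not predict_logvar:
        return x, None
    # Assumed layout: first half of the channels = mean, second half = log-variance.
    mean, logvar = np.split(x, 2, axis=1)
    return mean, logvar


x = np.zeros((2, 4, 16, 16))
mean, logvar = get_mean_lv_sketch(x, predict_logvar=True)
mean_only, no_logvar = get_mean_lv_sketch(x, predict_logvar=False)
print(mean.shape, logvar.shape)    # (2, 2, 16, 16) (2, 2, 16, 16)
print(mean_only.shape, no_logvar)  # (2, 4, 16, 16) None
```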
log_likelihood(x, params)
Compute the Gaussian log-likelihood.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | The target tensor. Shape is (B, C, [Z], Y, X). | required |
| `params` | `dict[str, Union[Tensor, None]]` | The tensors obtained by chunking the output of the top-down pass, here used as parameters of the Gaussian distribution. | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | The log-likelihood tensor. Shape is (B, C, [Z], Y, X). |

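
For reference, the element-wise Gaussian log-density that this method evaluates can be written in a few lines. The sketch assumes `params` holds `"mean"` and `"logvar"` arrays of the same shape as `x`, with NumPy in place of torch; the unit-variance fallback used when `"logvar"` is `None` is an assumption for illustration only, not documented behavior.

```python
import numpy as np


def gaussian_log_likelihood(x: np.ndarray, params: dict) -> np.ndarray:
    """Element-wise log N(x | mean, exp(logvar)); the output has the shape of x."""
    mean = params["mean"]
    logvar = params.get("logvar")
    if logvar is None:
        # Assumption for illustration: fall back to unit variance (logvar = 0).
        logvar = np.zeros_like(mean)
    return -0.5 * (np.log(2 * np.pi) + logvar + (x - mean) ** 2 / np.exp(logvar))


x = np.ones((2, 3, 4, 4))
params = {"mean": np.zeros_like(x), "logvar": np.full_like(x, np.log(4.0))}
ll = gaussian_log_likelihood(x, params)
print(ll.shape)  # (2, 3, 4, 4)
```

Working with the log-variance rather than the variance keeps the variance positive without any constraint on the network output, since `exp(logvar) > 0` for every real `logvar`.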
LikelihoodModule
Bases: Module
The base class for all likelihood modules. It defines the fundamental structure and methods for specialized likelihood models.
forward(input_, x)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_` | `Tensor` | The output of the top-down pass (e.g., reconstructed image in HDN, or the unmixed images in 'Split' models). | required |
| `x` | `Union[Tensor, None]` | The target tensor. If None, the log-likelihood is not computed. | required |

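
One way to organize the documented `forward(input_, x)` contract, where the log-likelihood is skipped when `x` is `None`, is a template method: the base class drives the flow and subclasses supply the distribution-specific parts. A plain-Python sketch of that pattern follows; the class names are hypothetical, it does not subclass `nn.Module`, and returning the pair `(log_likelihood, params)` is an assumption about the output format.

```python
from typing import Optional, Tuple

import numpy as np


class LikelihoodSketch:
    """Hypothetical base class: subclasses define parameters and log-likelihood."""

    def distr_params(self, input_: np.ndarray) -> dict:
        raise NotImplementedError

    def log_likelihood(self, x: np.ndarray, params: dict) -> np.ndarray:
        raise NotImplementedError

    def forward(
        self, input_: np.ndarray, x: Optional[np.ndarray]
    ) -> Tuple[Optional[np.ndarray], dict]:
        params = self.distr_params(input_)
        if x is None:
            # No target available: skip the log-likelihood, as documented.
            return None, params
        return self.log_likelihood(x, params), params


class UnitGaussianSketch(LikelihoodSketch):
    """Toy subclass: the whole input is the mean, variance fixed to 1."""

    def distr_params(self, input_: np.ndarray) -> dict:
        return {"mean": input_}

    def log_likelihood(self, x: np.ndarray, params: dict) -> np.ndarray:
        return -0.5 * (np.log(2 * np.pi) + (x - params["mean"]) ** 2)


lik = UnitGaussianSketch()
ll, params = lik.forward(np.zeros(3), None)
print(ll)  # None
```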
NoiseModelLikelihood
Bases: LikelihoodModule
forward(input_, x)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_` | `Tensor` | The output of the top-down pass (e.g., reconstructed image in HDN, or the unmixed images in 'Split' models). | required |
| `x` | `Union[Tensor, None]` | The target tensor. If None, the log-likelihood is not computed. | required |

log_likelihood(x, params)
Compute the log-likelihood given the parameters `params`, obtained from the reconstruction tensor, and the target tensor `x`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | The target tensor. Shape is (B, C, [Z], Y, X). | required |
| `params` | `dict[str, Tensor]` | The tensors obtained from the output of the top-down pass. Here, "mean" corresponds to the whole output, while logvar is | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | The log-likelihood tensor. Shape is (B, C, [Z], Y, X). |

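
A hedged sketch of a noise-model log-likelihood: the prediction (`params["mean"]`, i.e., the whole output) and the target are mapped back to the original intensity range and scored pixel-wise by a noise model. Everything here is assumed for illustration: the function names, the `noise_model_likelihood(observation, signal)` callable returning per-pixel densities, the denormalization step, and the toy fixed-sigma Gaussian used in place of a real noise model.

```python
import numpy as np


def nm_log_likelihood(x, params, noise_model_likelihood, data_mean, data_std, eps=1e-10):
    """Sketch of a noise-model log-likelihood (all names hypothetical).

    Assumes `x` and `params["mean"]` are normalized and that
    `noise_model_likelihood(observation, signal)` returns per-pixel densities.
    """
    signal = params["mean"] * data_std + data_mean  # denormalized prediction
    observation = x * data_std + data_mean           # denormalized target
    likelihood = noise_model_likelihood(observation, signal)
    return np.log(likelihood + eps)  # eps avoids log(0) for vanishing densities


def toy_gaussian_nm(observation, signal, sigma=2.0):
    """Toy stand-in for a real noise model: Gaussian noise with fixed sigma."""
    z = (observation - signal) / sigma
    return np.exp(-0.5 * z**2) / (sigma * np.sqrt(2 * np.pi))


x = np.zeros((1, 1, 4, 4))
params = {"mean": np.zeros_like(x)}
ll = nm_log_likelihood(x, params, toy_gaussian_nm, data_mean=100.0, data_std=10.0)
print(ll.shape)  # (1, 1, 4, 4)
```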
set_data_stats(data_mean, data_std)
Set the data mean and std for denormalization.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `data_mean` | `Union[ndarray, Tensor]` | Mean values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting. | required |
| `data_std` | `Union[ndarray, Tensor]` | Standard deviation values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting. | required |

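
The reshape to (1, C, 1, 1, 1) lets per-channel statistics broadcast over a 5D (B, C, Z, Y, X) array. A minimal NumPy sketch of that step and of the denormalization it enables follows; the helper name `reshape_stats` is hypothetical.

```python
import numpy as np


def reshape_stats(data_mean, data_std):
    """Reshape per-channel statistics to (1, C, 1, 1, 1) for broadcasting."""
    mean = np.asarray(data_mean, dtype=np.float32).reshape(1, -1, 1, 1, 1)
    std = np.asarray(data_std, dtype=np.float32).reshape(1, -1, 1, 1, 1)
    return mean, std


mean, std = reshape_stats([100.0, 200.0], [10.0, 20.0])
x_norm = np.zeros((4, 2, 1, 8, 8), dtype=np.float32)  # (B, C, Z, Y, X) with Z = 1
x_denorm = x_norm * std + mean  # each channel gets its own mean and std
print(mean.shape, x_denorm.shape)  # (1, 2, 1, 1, 1) (4, 2, 1, 8, 8)
```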
likelihood_factory(config, noise_model=None)
Factory function for creating likelihood modules.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `Optional[Union[GaussianLikelihoodConfig, NMLikelihoodConfig]]` | The configuration object for the likelihood module. | required |
| `noise_model` | `Optional[NoiseModel]` | The noise model instance used to define the | `None` |

Returns:

| Type | Description |
|---|---|
| `Module` | The likelihood module. |

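
A factory like this typically dispatches on the configuration type. The sketch below shows that pattern with hypothetical stand-in classes (`GaussianCfg`, `NMCfg`, `GaussianLik`, `NMLik`) rather than the real configuration and module classes; returning `None` when `config` is `None`, and raising when a noise-model configuration arrives without a noise model, are assumptions rather than documented behavior.

```python
from dataclasses import dataclass
from typing import Any, Optional, Union


@dataclass
class GaussianCfg:  # hypothetical stand-in for GaussianLikelihoodConfig
    predict_logvar: bool = False


@dataclass
class NMCfg:  # hypothetical stand-in for NMLikelihoodConfig
    pass


class GaussianLik:
    def __init__(self, cfg: GaussianCfg) -> None:
        self.cfg = cfg


class NMLik:
    def __init__(self, cfg: NMCfg, noise_model: Any) -> None:
        self.cfg = cfg
        self.noise_model = noise_model


def likelihood_factory_sketch(
    config: Optional[Union[GaussianCfg, NMCfg]], noise_model: Optional[Any] = None
):
    if config is None:
        return None  # assumption: no likelihood configured
    if isinstance(config, GaussianCfg):
        return GaussianLik(config)
    if isinstance(config, NMCfg):
        if noise_model is None:
            raise ValueError("a noise model is required for a noise-model likelihood")
        return NMLik(config, noise_model)
    raise TypeError(f"unsupported config type: {type(config).__name__}")


lik = likelihood_factory_sketch(GaussianCfg(predict_logvar=True))
print(type(lik).__name__)  # GaussianLik
```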
log_normal(x, mean, logvar)
Compute the log-probability at x of a Gaussian distribution
with parameters (mean, exp(logvar)).
NOTE: In the case of LVAE, the log-likelihood formula becomes:
$$
\mathbb{E}_{z_1 \sim q_\phi}\left[\log p_\theta(x \mid z_1)\right] = -\frac{1}{2}\left(\mathbb{E}_{z_1 \sim q_\phi}\left[\log 2\pi\sigma_{p,1}^2(z_1)\right] + \mathbb{E}_{z_1 \sim q_\phi}\left[\frac{(x - \mu_{p,1}(z_1))^2}{\sigma_{p,1}^2(z_1)}\right]\right)
$$
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | The ground-truth tensor. Shape is (batch, channels, dim1, dim2). | required |
| `mean` | `Tensor` | The inferred mean of the distribution. Shape is (batch, channels, dim1, dim2). | required |
| `logvar` | `Tensor` | The inferred log-variance of the distribution. Must be a scalar or broadcastable against `x` and `mean`. | required |

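
In practice, the expectations over $z_1 \sim q_\phi$ in the NOTE above are approximated by averaging over samples. The sketch restates the log-density in NumPy (it is not the library function) and averages it over K hypothetical decoder outputs, one (mean, log-variance) pair per sample of $z_1$; the sample count and the synthetic means and variances are made up for illustration.

```python
import numpy as np


def log_normal(x, mean, logvar):
    """Element-wise log N(x | mean, exp(logvar)) (NumPy restatement)."""
    return -0.5 * (np.log(2 * np.pi) + logvar + (x - mean) ** 2 / np.exp(logvar))


rng = np.random.default_rng(0)
x = np.ones((1, 1, 4, 4))  # ground truth, (batch, channels, dim1, dim2)

# Hypothetical decoder outputs for K samples z_1 ~ q_phi, stacked on a leading axis.
K = 8
means = rng.normal(loc=1.0, scale=0.1, size=(K, 1, 1, 4, 4))
logvars = np.full((K, 1, 1, 4, 4), np.log(0.25))

# Monte Carlo estimate of E_{z_1 ~ q_phi}[log p_theta(x | z_1)], per pixel.
mc_estimate = log_normal(x, means, logvars).mean(axis=0)
print(mc_estimate.shape)  # (1, 1, 4, 4)
```

Averaging the per-sample log-densities (rather than the densities) matches the formula, where the expectation sits outside the logarithm.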