
Likelihoods


Script containing modules for defining different likelihood functions (as nn.Module).

GaussianLikelihood

Bases: LikelihoodModule

A specialized LikelihoodModule for Gaussian likelihood.

Specifically, in the LVAE model, the likelihood is defined as $p(x \mid z_1) = \mathcal{N}(x \mid \mu_{p,1}, \sigma_{p,1}^2)$.

__init__(predict_logvar=None, logvar_lowerbound=None)

Constructor.

Parameters:

  • predict_logvar (Union[Literal['pixelwise'], None], default: None ) –

    If 'pixelwise', the log-variance is predicted for each pixel; if None, no log-variance is predicted. Default is None.

  • logvar_lowerbound (Union[float, None], default: None ) –

    The lower bound for the log-variance. Default is None.
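A lower bound on the log-variance keeps $\sigma^2 = \exp(\text{logvar})$ away from zero, so the $(x-\mu)^2/\sigma^2$ term of the Gaussian log-likelihood cannot blow up. The sketch below assumes the bound is enforced by clamping; `apply_logvar_lowerbound` is a hypothetical helper, not part of this module.

```python
from typing import Optional

import torch


def apply_logvar_lowerbound(
    logvar: torch.Tensor, logvar_lowerbound: Optional[float]
) -> torch.Tensor:
    """Clamp the log-variance from below (hypothetical helper, assumed behavior)."""
    if logvar_lowerbound is None:
        # No bound configured: return the log-variance unchanged.
        return logvar
    # Values below the bound are raised to the bound; all others are kept.
    return torch.clamp(logvar, min=logvar_lowerbound)


lv = torch.tensor([-10.0, -2.0, 1.0])
clamped = apply_logvar_lowerbound(lv, -5.0)  # -10.0 is raised to -5.0
```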

distr_params(x)

Get parameters (mean, log-var) of the Gaussian distribution defined by the likelihood.

Parameters:

  • x (Tensor) –

    The input tensor to the likelihood module, i.e., the output of the LVAE 'output_layer'. Shape is: (B, 2 * C, [Z], Y, X) in case predict_logvar is not None, or (B, C, [Z], Y, X) otherwise.

forward(input_, x)

Parameters:

  • input_ (Tensor) –

    The output of the top-down pass (e.g., reconstructed image in HDN, or the unmixed images in 'Split' models).

  • x (Union[Tensor, None]) –

    The target tensor. If None, the log-likelihood is not computed.

get_mean_lv(x)

Given the output of the top-down pass, compute the mean and log-variance of the Gaussian distribution defining the likelihood.

Parameters:

  • x (Tensor) –

    The input tensor to the likelihood module, i.e., the output of the top-down pass.

Returns:

  • tuple of (torch.Tensor, optional torch.Tensor)

    The first element of the tuple is the mean, the second element is the log-variance. If the attribute predict_logvar is None then the second element will be None.
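Given the shapes stated for distr_params, splitting the top-down output into mean and log-variance amounts to halving the channel dimension. A minimal sketch, assuming the first C channels hold the mean and the last C the log-variance (`split_mean_logvar` is a hypothetical name):

```python
from typing import Optional

import torch


def split_mean_logvar(
    x: torch.Tensor, predict_logvar: Optional[str]
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Split a (B, 2*C, [Z], Y, X) tensor into mean and log-variance (illustrative)."""
    if predict_logvar is None:
        # The whole tensor is the mean; no log-variance is predicted.
        return x, None
    # Assumed channel order: mean first, log-variance second.
    mean, logvar = torch.chunk(x, 2, dim=1)
    return mean, logvar


out = torch.randn(4, 6, 64, 64)  # B=4, 2*C=6
mean, logvar = split_mean_logvar(out, "pixelwise")
print(mean.shape, logvar.shape)  # torch.Size([4, 3, 64, 64]) torch.Size([4, 3, 64, 64])
```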

log_likelihood(x, params)

Compute the Gaussian log-likelihood.

Parameters:

  • x (Tensor) –

    The target tensor. Shape is (B, C, [Z], Y, X).

  • params (dict[str, Union[Tensor, None]]) –

    The tensors obtained by chunking the output of the top-down pass, here used as parameters of the Gaussian distribution.

Returns:

  • Tensor

    The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
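The sketch below illustrates how a params dictionary with "mean" and "logvar" entries can be turned into an elementwise log-likelihood of the same shape as x, using `torch.distributions.Normal`. The fallback used when "logvar" is None (a negative squared error) is an assumption made for this example only; `gaussian_log_likelihood` is a hypothetical name.

```python
from typing import Optional

import torch
from torch.distributions import Normal


def gaussian_log_likelihood(
    x: torch.Tensor, params: dict[str, Optional[torch.Tensor]]
) -> torch.Tensor:
    """Elementwise Gaussian log-likelihood, shape (B, C, [Z], Y, X) (illustrative)."""
    mean, logvar = params["mean"], params["logvar"]
    if logvar is None:
        # Assumed fallback for this sketch: negative squared error.
        return -((x - mean) ** 2)
    # Normal expects a standard deviation: sigma = exp(logvar / 2).
    return Normal(mean, torch.exp(0.5 * logvar)).log_prob(x)
```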

LikelihoodModule

Bases: Module

The base class for all likelihood modules. It defines the fundamental structure and methods for specialized likelihood models.

forward(input_, x)

Parameters:

  • input_ (Tensor) –

    The output of the top-down pass (e.g., reconstructed image in HDN, or the unmixed images in 'Split' models).

  • x (Union[Tensor, None]) –

    The target tensor. If None, the log-likelihood is not computed.
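One way to picture the shared structure is a base class whose forward first derives distribution parameters from the top-down output and then, only if a target is given, computes the log-likelihood. This is a hypothetical skeleton, not the library source; in particular, returning a (log-likelihood, params) pair is an assumption.

```python
from typing import Optional

import torch
from torch import nn


class LikelihoodSketch(nn.Module):
    """Hypothetical skeleton of a likelihood base class."""

    def distr_params(self, x: torch.Tensor) -> dict[str, Optional[torch.Tensor]]:
        raise NotImplementedError

    def log_likelihood(self, x: torch.Tensor, params: dict) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, input_: torch.Tensor, x: Optional[torch.Tensor]):
        params = self.distr_params(input_)
        if x is None:
            # No target: skip the log-likelihood, return only the parameters.
            return None, params
        return self.log_likelihood(x, params), params


class SquaredErrorSketch(LikelihoodSketch):
    """Toy subclass: the mean is the whole input, log-likelihood is -(x - mean)^2."""

    def distr_params(self, x):
        return {"mean": x, "logvar": None}

    def log_likelihood(self, x, params):
        return -((x - params["mean"]) ** 2)
```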

NoiseModelLikelihood

Bases: LikelihoodModule

__init__(noise_model)

Constructor.

Parameters:

  • noise_model

    The noise model instance used to compute the likelihood.

forward(input_, x)

Parameters:

  • input_ (Tensor) –

    The output of the top-down pass (e.g., reconstructed image in HDN, or the unmixed images in 'Split' models).

  • x (Union[Tensor, None]) –

    The target tensor. If None, the log-likelihood is not computed.

log_likelihood(x, params)

Compute the log-likelihood given the parameters params obtained from the reconstruction tensor and the target tensor x.

Parameters:

  • x (Tensor) –

    The target tensor. Shape is (B, C, [Z], Y, X).

  • params (dict[str, Tensor]) –

    The tensors obtained from the output of the top-down pass. Here, "mean" corresponds to the whole output, while "logvar" is None.

Returns:

  • Tensor

    The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
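With a noise model, the whole top-down output is treated as the clean signal and the noise model supplies $p(x \mid s)$ elementwise. The sketch uses a toy Gaussian noise model with a `likelihood(observation, signal)` method; that class, its method, and the epsilon clamp before the log are all assumptions for illustration, not the interface of the library's noise models.

```python
import math

import torch


class ToyGaussianNoiseModel:
    """Hypothetical noise model: p(x | s) = N(x | s, sigma^2) with a fixed sigma."""

    def __init__(self, sigma: float):
        self.sigma = sigma

    def likelihood(self, observation: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
        # Elementwise probability density of the observation given the clean signal.
        z = (observation - signal) / self.sigma
        return torch.exp(-0.5 * z**2) / (self.sigma * math.sqrt(2 * math.pi))


def noise_model_log_likelihood(noise_model, x, params, eps: float = 1e-12):
    # "mean" holds the whole top-down output (the predicted signal); "logvar" is unused.
    likelihoods = noise_model.likelihood(x, params["mean"])
    # Clamp before taking the log to avoid log(0) far from the signal.
    return torch.log(torch.clamp(likelihoods, min=eps))
```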

set_data_stats(data_mean, data_std)

Set the data mean and std for denormalization.

Parameters:

  • data_mean (Union[ndarray, Tensor]) –

    Mean values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting.

  • data_std (Union[ndarray, Tensor]) –

    Standard deviation values for each channel. Will be reshaped to (1, C, 1, 1, 1) for broadcasting.
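Reshaping per-channel statistics to (1, C, 1, 1, 1) lets them broadcast against a 5-D (B, C, Z, Y, X) tensor. The sketch assumes denormalization is the usual inverse z-score, x * std + mean; `to_broadcastable` and `denormalize` are hypothetical helpers. Note that a 4-D (B, C, Y, X) tensor would not broadcast correctly against this 5-D shape.

```python
import numpy as np
import torch


def to_broadcastable(stats) -> torch.Tensor:
    # Accept an ndarray or a Tensor and reshape (C,) -> (1, C, 1, 1, 1).
    t = torch.as_tensor(stats, dtype=torch.float32)
    return t.reshape(1, -1, 1, 1, 1)


def denormalize(x: torch.Tensor, data_mean, data_std) -> torch.Tensor:
    # Inverse of (x - mean) / std, applied per channel (assumes 5-D input).
    return x * to_broadcastable(data_std) + to_broadcastable(data_mean)


x = torch.zeros(2, 3, 1, 4, 4)  # (B, C, Z, Y, X)
restored = denormalize(x, np.array([1.0, 2.0, 3.0]), torch.tensor([0.5, 0.5, 0.5]))
```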

likelihood_factory(config, noise_model=None)

Factory function for creating likelihood modules.

Returns:

  • Module

    The likelihood module.
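A factory of this kind typically dispatches on a field of the configuration. The sketch below is entirely hypothetical: the config dataclass, its `type` values ("gaussian", "nm"), and the toy module classes stand in for the real ones, whose fields are not documented here.

```python
from dataclasses import dataclass
from typing import Optional

from torch import nn


@dataclass
class ToyLikelihoodConfig:
    # Hypothetical config fields for this sketch only.
    type: str  # "gaussian" or "nm"
    predict_logvar: Optional[str] = None
    logvar_lowerbound: Optional[float] = None


class ToyGaussianLikelihood(nn.Module):
    def __init__(self, predict_logvar=None, logvar_lowerbound=None):
        super().__init__()
        self.predict_logvar = predict_logvar
        self.logvar_lowerbound = logvar_lowerbound


class ToyNoiseModelLikelihood(nn.Module):
    def __init__(self, noise_model):
        super().__init__()
        self.noise_model = noise_model


def toy_likelihood_factory(config, noise_model=None) -> nn.Module:
    if config.type == "gaussian":
        return ToyGaussianLikelihood(config.predict_logvar, config.logvar_lowerbound)
    if config.type == "nm":
        # A noise-model likelihood cannot be built without a noise model.
        if noise_model is None:
            raise ValueError("A noise model is required for a noise-model likelihood.")
        return ToyNoiseModelLikelihood(noise_model)
    raise ValueError(f"Unknown likelihood type: {config.type!r}")
```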

log_normal(x, mean, logvar)

Compute the log-probability at x of a Gaussian distribution with parameters (mean, exp(logvar)).

NOTE: In the case of LVAE, the log-likelihood formula becomes:

$$\mathbb{E}_{z_1 \sim q_\phi}\left[\log p_\theta(x \mid z_1)\right] = -\frac{1}{2}\left(\mathbb{E}_{z_1 \sim q_\phi}\left[\log 2\pi\sigma_{p,1}^2(z_1)\right] + \mathbb{E}_{z_1 \sim q_\phi}\left[\frac{(x - \mu_{p,1}(z_1))^2}{\sigma_{p,1}^2(z_1)}\right]\right)$$

Parameters:

  • x (Tensor) –

    The ground-truth tensor. Shape is (batch, channels, dim1, dim2).

  • mean (Tensor) –

    The inferred mean of the distribution. Shape is (batch, channels, dim1, dim2).

  • logvar (Tensor) –

    The inferred log-variance of the distribution. Must be a scalar or broadcastable to the shape of x.
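Expanding $\log \mathcal{N}(x \mid \mu, e^{\text{logvar}})$ gives $-\frac{1}{2}\left(\log 2\pi + \text{logvar} + (x-\mu)^2 / e^{\text{logvar}}\right)$, which is the per-sample term inside the expectations above. The sketch implements it directly under a hypothetical name, `log_normal_sketch`, and checks it against `torch.distributions.Normal.log_prob` with a log-variance that broadcasts over batch and spatial dimensions.

```python
import math

import torch
from torch.distributions import Normal


def log_normal_sketch(x: torch.Tensor, mean: torch.Tensor, logvar) -> torch.Tensor:
    """Elementwise log N(x | mean, exp(logvar)); logvar may be scalar or broadcastable."""
    logvar = torch.as_tensor(logvar, dtype=x.dtype)
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mean) ** 2 / torch.exp(logvar))


torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)
mean = torch.randn(2, 3, 8, 8)
logvar = torch.randn(1, 3, 1, 1)  # broadcast over batch and spatial dims
# Normal is parameterized by the standard deviation, exp(logvar / 2).
reference = Normal(mean, torch.exp(0.5 * logvar)).log_prob(x)
print(torch.allclose(log_normal_sketch(x, mean, logvar), reference, atol=1e-5))  # True
```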