
VAE Lightning Module


CAREamics Lightning module.

VAEModule

Bases: LightningModule

CAREamics Lightning module.

This class encapsulates a PyTorch model along with the training, validation, and testing logic. It is configured using a VAEAlgorithmConfig Pydantic class.

Parameters:

  • algorithm_config (Union[VAEAlgorithmConfig, dict]) –

    Algorithm configuration.

Attributes:

  • model (Module) –

    PyTorch model.

  • loss_func (Module) –

    Loss function.

  • optimizer_name (str) –

    Optimizer name.

  • optimizer_params (dict) –

    Optimizer parameters.

  • lr_scheduler_name (str) –

    Learning rate scheduler name.

__init__(algorithm_config)

Lightning module for CAREamics.

This class encapsulates a PyTorch model along with the training, validation, and testing logic. It is configured using a VAEAlgorithmConfig Pydantic class.

Parameters:

  • algorithm_config (Union[VAEAlgorithmConfig, dict]) –

    Algorithm configuration.

compute_val_psnr(model_output, target, psnr_func=scale_invariant_psnr)

Compute the PSNR for the current validation batch.

Parameters:

  • model_output (tuple[Tensor, dict[str, Any]]) –

    Model output, a tuple with the predicted mean and (optionally) logvar, and the top-down data dictionary.

  • target (Tensor) –

    Target tensor.

  • psnr_func (Callable, default: scale_invariant_psnr ) –

    PSNR function to use, by default scale_invariant_psnr.

Returns:

  • list[float]

    PSNR for each channel in the current batch.
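To make the per-channel PSNR concrete, here is a minimal NumPy sketch of a scale-invariant PSNR. It assumes that both arrays are mean-centred and the prediction is then rescaled by a least-squares factor before the error is measured, so a global affine change of intensities is not penalized. The helpers `scale_invariant_psnr_sketch` and `per_channel_psnr` are hypothetical illustrations, not the CAREamics `scale_invariant_psnr` implementation.

```python
import numpy as np


def scale_invariant_psnr_sketch(gt: np.ndarray, pred: np.ndarray) -> float:
    """Hypothetical scale-invariant PSNR (illustration only, not CAREamics code)."""
    gt = gt.astype(np.float64).ravel()
    pred = pred.astype(np.float64).ravel()
    data_range = gt.max() - gt.min()
    # Remove the means so that only a multiplicative factor remains to fit.
    gt_c = gt - gt.mean()
    pred_c = pred - pred.mean()
    denom = np.dot(pred_c, pred_c)
    # Least-squares scale that best maps the prediction onto the ground truth.
    alpha = np.dot(gt_c, pred_c) / denom if denom > 0 else 0.0
    mse = np.mean((gt_c - alpha * pred_c) ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(data_range**2 / mse))


def per_channel_psnr(target: np.ndarray, prediction: np.ndarray) -> list[float]:
    """One PSNR value per channel for arrays shaped (B, C, [Z], Y, X)."""
    return [
        scale_invariant_psnr_sketch(target[:, c], prediction[:, c])
        for c in range(target.shape[1])
    ]
```

Because of the rescaling step, `per_channel_psnr(gt, 2 * pred + 1)` yields the same values as `per_channel_psnr(gt, pred)`, which is the property that makes such a metric useful when network outputs live on a different intensity scale than the targets.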

configure_optimizers()

Configure optimizers and learning rate schedulers.

Returns:

  • Any

    Optimizer and learning rate scheduler.

forward(x)

Forward pass.

Parameters:

  • x (Tensor) –

    Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs.

Returns:

  • tuple[Tensor, dict[str, Any]]

    A tuple with the output tensor and additional data from the top-down pass.
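The input layout (B, (1 + n_LC), [Z], Y, X) stacks the primary input and the n_LC lateral inputs along the channel axis. The NumPy sketch below splits such a batch, assuming (hypothetically) that channel 0 holds the primary patch and channels 1..n_LC hold the lateral-context inputs; `split_lateral_inputs` is an illustrative helper, not part of the CAREamics API.

```python
import numpy as np


def split_lateral_inputs(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Split a (B, 1 + n_LC, [Z], Y, X) batch into primary and lateral parts.

    Assumption (hypothetical): channel 0 is the primary input and the
    remaining channels are the lateral-context inputs.
    """
    if x.ndim not in (4, 5):
        raise ValueError(f"expected 4D (2D data) or 5D (3D data) input, got {x.ndim}D")
    primary = x[:, :1]  # keeps the channel axis: (B, 1, [Z], Y, X)
    lateral = x[:, 1:]  # (B, n_LC, [Z], Y, X); empty channel axis when n_LC == 0
    return primary, lateral
```

For a 2D batch of shape (4, 3, 64, 64), this returns a primary array of shape (4, 1, 64, 64) and lateral inputs of shape (4, 2, 64, 64), i.e. n_LC = 2.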

get_reconstructed_tensor(model_outputs)

Get the reconstructed tensor from the LVAE model outputs.

Parameters:

  • model_outputs (tuple[Tensor, dict[str, Any]]) –

    Model outputs. It is a tuple with a tensor representing the predicted mean and (optionally) logvar, and the top-down data dictionary.

Returns:

  • Tensor

    Reconstructed tensor, i.e., the predicted mean.
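When the model also predicts a log-variance, the reconstruction has to be separated from it. The sketch below assumes a layout where the mean channels come first and the log-variance channels second along axis 1, with equal channel counts; that layout, the `predicts_logvar` flag, and the helper name are assumptions for illustration rather than the documented CAREamics behaviour.

```python
from typing import Any

import numpy as np


def get_reconstructed_sketch(
    model_outputs: tuple[np.ndarray, dict[str, Any]],
    predicts_logvar: bool,
) -> np.ndarray:
    """Return the predicted mean from (predictions, top-down data) outputs."""
    predictions, _td_data = model_outputs  # top-down dictionary is not needed here
    if not predicts_logvar:
        return predictions
    # Assumed layout along axis 1: [mean channels | logvar channels].
    n_channels = predictions.shape[1] // 2
    return predictions[:, :n_channels]
```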

on_validation_epoch_end()

Hook called by Lightning at the end of each validation epoch.

predict_step(batch, batch_idx)

Prediction step.

Parameters:

  • batch (Tensor) –

    Input batch.

  • batch_idx (Any) –

    Batch index.

Returns:

  • Any

    Model output.

reduce_running_psnr()

Reduce the running PSNR statistics and reset the running PSNR.

Returns:

  • Optional[float]

    Running PSNR averaged over the different output channels.
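The reduce-and-reset pattern can be pictured with a small accumulator: each validation batch contributes one PSNR per channel, and at reduction time the per-channel means are averaged over channels, after which the state is cleared. `RunningPSNRSketch` is a hypothetical class written for illustration; the actual attribute names and bookkeeping in `VAEModule` may differ.

```python
from typing import Optional


class RunningPSNRSketch:
    """Hypothetical per-channel running mean of batch PSNR values."""

    def __init__(self, n_channels: int) -> None:
        self.n_channels = n_channels
        self.reset()

    def reset(self) -> None:
        self.sums = [0.0] * self.n_channels
        self.counts = [0] * self.n_channels

    def update(self, batch_psnr: list[float]) -> None:
        # batch_psnr holds one value per output channel for the current batch.
        for c, value in enumerate(batch_psnr):
            self.sums[c] += value
            self.counts[c] += 1

    def reduce(self) -> Optional[float]:
        # Average within each channel, then across channels, then reset.
        per_channel = [s / n for s, n in zip(self.sums, self.counts) if n > 0]
        self.reset()
        if not per_channel:
            return None  # nothing accumulated since the last reset
        return sum(per_channel) / len(per_channel)
```

For two batches with PSNRs [30.0, 20.0] and [32.0, 22.0], the channel means are 31.0 and 21.0, so `reduce()` returns 26.0; a second call right afterwards returns None because the state was reset.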

set_data_stats(data_mean, data_std)

Set data mean and std for the noise model likelihood.

Parameters:

  • data_mean (float) –

    Mean of the data.

  • data_std (float) –

    Standard deviation of the data.
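A noise model describes the noise in raw intensity units, whereas the network works on normalized data, which is presumably why the module needs the data mean and standard deviation. The sketch below maps normalized values back with x * std + mean and evaluates a simple pixel-independent Gaussian noise model; the Gaussian model, the fixed `sigma`, and both helper names are hypothetical stand-ins for whatever noise model is actually configured.

```python
import math


def denormalize(x: float, data_mean: float, data_std: float) -> float:
    """Map a normalized value back to raw intensity units."""
    return x * data_std + data_mean


def gaussian_log_likelihood(observation: float, signal: float, sigma: float) -> float:
    """Log-density of a hypothetical Gaussian noise model p(observation | signal)."""
    return -0.5 * math.log(2 * math.pi * sigma**2) - (observation - signal) ** 2 / (
        2 * sigma**2
    )
```

Usage: with `data_mean=100.0` and `data_std=20.0`, a normalized prediction of 0.5 corresponds to a raw signal of 110.0, and the likelihood of the raw observation is evaluated against that raw signal rather than against the normalized value.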

training_step(batch, batch_idx)

Training step.

Parameters:

  • batch (tuple[Tensor, Tensor]) –

    Input batch. It is a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).

  • batch_idx (Any) –

    Batch index.

Returns:

  • Any

    Loss value.
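The batch layout described above can be checked before computing the loss. The NumPy sketch below validates an (input, target) pair and returns the number of target channels C (1 for HDN, more than 1 for muSplit/denoiSplit). It assumes the input and target share batch size and spatial dimensions; `check_batch_shapes` is a hypothetical helper, not a CAREamics function.

```python
import numpy as np


def check_batch_shapes(batch: tuple[np.ndarray, np.ndarray]) -> int:
    """Validate an (input, target) batch and return the number of target channels.

    Input:  (B, 1 + n_LC, [Z], Y, X)
    Target: (B, C, [Z], Y, X)
    """
    x, target = batch
    if x.ndim != target.ndim:
        raise ValueError("input and target must both be 2D (4 dims) or 3D (5 dims)")
    if x.shape[0] != target.shape[0]:
        raise ValueError("input and target batch sizes differ")
    # Assumption: spatial dimensions ([Z], Y, X) of input and target match.
    if x.shape[2:] != target.shape[2:]:
        raise ValueError("input and target spatial dimensions differ")
    return target.shape[1]
```

For an input of shape (8, 1, 64, 64) and a two-channel target of shape (8, 2, 64, 64), as in a splitting task without lateral inputs, the function returns 2.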

validation_step(batch, batch_idx)

Validation step.

Parameters:

  • batch (tuple[Tensor, Tensor]) –

    Input batch. It is a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).

  • batch_idx (Any) –

    Batch index.