
Lightning Module


CAREamics Lightning module.

FCNModule

Bases: LightningModule

CAREamics Lightning module.

This class encapsulates the PyTorch model along with the training, validation, and testing logic. It is configured using an AlgorithmModel Pydantic class.

Parameters:

algorithm_config (AlgorithmModel or dict, required): Algorithm configuration.

Attributes:

model (Module): PyTorch model.
loss_func (Module): Loss function.
optimizer_name (str): Optimizer name.
optimizer_params (dict): Optimizer parameters.
lr_scheduler_name (str): Learning rate scheduler name.

configure_optimizers()

Configure optimizers and learning rate schedulers.

Returns:

Any: Optimizer and learning rate scheduler.
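In PyTorch Lightning, `configure_optimizers` may return a dictionary holding the optimizer and an `lr_scheduler` entry; a `ReduceLROnPlateau` scheduler (the default in `create_careamics_module`) additionally needs a `monitor` key naming a logged metric. The sketch below only assembles that dictionary from already-created objects. The helper name `build_optimizer_config` and the metric name `"val_loss"` are assumptions for illustration, not CAREamics code.

```python
from typing import Any


def build_optimizer_config(
    optimizer: Any, scheduler: Any, monitor: str = "val_loss"
) -> dict[str, Any]:
    """Hypothetical helper: assemble a Lightning-style optimizer config.

    "val_loss" is an assumed metric name; it must match a metric that the
    module actually logs for ReduceLROnPlateau to be stepped correctly.
    """
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            # Metric that ReduceLROnPlateau watches for a plateau
            "monitor": monitor,
        },
    }


# Placeholders stand in for torch.optim objects so the sketch runs anywhere
config = build_optimizer_config(optimizer="adam", scheduler="plateau")
print(config["lr_scheduler"]["monitor"])  # val_loss
```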

forward(x)

Forward pass.

Parameters:

x (Any, required): Input tensor.

Returns:

Any: Output tensor.

predict_step(batch, batch_idx)

Prediction step.

Parameters:

batch (Tensor, required): Input batch.
batch_idx (Any, required): Batch index.

Returns:

Any: Model output.

training_step(batch, batch_idx)

Training step.

Parameters:

batch (Tensor, required): Input batch.
batch_idx (Any, required): Batch index.

Returns:

Any: Loss value.
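For Noise2Void-style algorithms, the training loss is evaluated only at the pixels that were masked in the network input. The pure-Python sketch below illustrates that idea on flattened pixel lists. It assumes the step has access to the prediction, the original (unmasked) values, and a boolean mask; `masked_mse` is a hypothetical name, not the CAREamics loss function.

```python
def masked_mse(
    prediction: list[float], original: list[float], mask: list[bool]
) -> float:
    """Mean squared error restricted to masked pixels (N2V-style sketch)."""
    # Keep only the positions that were masked in the network input
    errors = [(p - o) ** 2 for p, o, m in zip(prediction, original, mask) if m]
    if not errors:
        raise ValueError("the mask selects no pixels")
    return sum(errors) / len(errors)


loss = masked_mse([1.0, 2.0, 3.0], [1.0, 0.0, 5.0], [False, True, True])
print(loss)  # 4.0
```

The unmasked first pixel is ignored even though it matches the target exactly, which is the point: the network is only scored where it could not see the true value.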

validation_step(batch, batch_idx)

Validation step.

Parameters:

batch (Tensor, required): Input batch.
batch_idx (Any, required): Batch index.

VAEModule

Bases: LightningModule

CAREamics Lightning module.

This class encapsulates a PyTorch model along with the training, validation, and testing logic. It is configured using a VAEAlgorithmConfig Pydantic class.

Parameters:

algorithm_config (Union[VAEAlgorithmConfig, dict], required): Algorithm configuration.

Attributes:

model (Module): PyTorch model.
loss_func (Module): Loss function.
optimizer_name (str): Optimizer name.
optimizer_params (dict): Optimizer parameters.
lr_scheduler_name (str): Learning rate scheduler name.

compute_val_psnr(model_output, target, psnr_func=scale_invariant_psnr)

Compute the PSNR for the current validation batch.

Parameters:

model_output (tuple[Tensor, dict[str, Any]], required): Model output, a tuple with the predicted mean and (optionally) logvar, and the top-down data dictionary.
target (Tensor, required): Target tensor.
psnr_func (Callable, default scale_invariant_psnr): PSNR function to use.

Returns:

list[float]: PSNR for each channel in the current batch.
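Scale-invariant PSNR compares prediction and target only after removing differences in offset and scale, which matters when the network output is not in the same intensity units as the target. The sketch below is one common formulation and not necessarily the exact `scale_invariant_psnr` used by CAREamics: it standardizes the target, zero-centres the prediction, rescales it with a least-squares factor, and uses the range of the standardized target as the data range. It handles a single channel given as a flat list; the function name is hypothetical.

```python
import math


def scale_invariant_psnr_sketch(target: list[float], prediction: list[float]) -> float:
    """Scale-invariant PSNR for one channel (illustrative formulation)."""
    n = len(target)
    mean_t = sum(target) / n
    std_t = math.sqrt(sum((t - mean_t) ** 2 for t in target) / n)
    # Standardize the target: zero mean, unit standard deviation
    gt = [(t - mean_t) / std_t for t in target]
    # Zero-centre the prediction; with both centred, the best offset is zero
    mean_p = sum(prediction) / n
    pred = [p - mean_p for p in prediction]
    # Least-squares scale alpha minimizing sum((gt - alpha * pred) ** 2)
    alpha = sum(g * p for g, p in zip(gt, pred)) / sum(p * p for p in pred)
    mse = sum((g - alpha * p) ** 2 for g, p in zip(gt, pred)) / n
    data_range = max(gt) - min(gt)
    return 10 * math.log10(data_range**2 / mse)


target = [0.0, 1.0, 2.0, 3.0, 4.0]
prediction = [0.1, 0.9, 2.2, 2.8, 4.1]
rescaled = [3.0 * p + 7.0 for p in prediction]
# An affine change of the prediction leaves the score unchanged (up to rounding)
print(math.isclose(
    scale_invariant_psnr_sketch(target, prediction),
    scale_invariant_psnr_sketch(target, rescaled),
))  # True
```

Running this once per output channel gives a list like the one `compute_val_psnr` is documented to return.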

configure_optimizers()

Configure optimizers and learning rate schedulers.

Returns:

Any: Optimizer and learning rate scheduler.

forward(x)

Forward pass.

Parameters:

x (Tensor, required): Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs.

Returns:

tuple[Tensor, dict[str, Any]]: A tuple with the output tensor and additional data from the top-down pass.

get_reconstructed_tensor(model_outputs)

Get the reconstructed tensor from the LVAE model outputs.

Parameters:

model_outputs (tuple[Tensor, dict[str, Any]], required): Model outputs, a tuple with a tensor holding the predicted mean and (optionally) logvar, and the top-down data dictionary.

Returns:

Tensor: Reconstructed tensor, i.e., the predicted mean.
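Because the first element of the model output may stack the predicted mean with an optional logvar, extracting the reconstruction amounts to keeping the mean part. The sketch below assumes the means occupy the first half of the channel axis and the logvars the second half; that layout, the `predict_logvar` flag, and the helper name are assumptions. Plain lists stand in for channels so the sketch runs without PyTorch.

```python
from typing import Any, Sequence


def get_reconstruction_sketch(channels: Sequence[Any], predict_logvar: bool) -> list[Any]:
    """Return the predicted-mean channels from a (mean[, logvar]) stack."""
    if not predict_logvar:
        # Only means were predicted: everything is the reconstruction
        return list(channels)
    if len(channels) % 2 != 0:
        raise ValueError("expected an even number of channels (means + logvars)")
    # Assumed layout: [mean_1, ..., mean_C, logvar_1, ..., logvar_C]
    return list(channels[: len(channels) // 2])


print(get_reconstruction_sketch(["mu_1", "mu_2", "lv_1", "lv_2"], predict_logvar=True))
# ['mu_1', 'mu_2']
```

On an actual tensor with channels on dimension 1, the same split under this layout assumption would be `torch.chunk(x, 2, dim=1)[0]`.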

on_validation_epoch_end()

Hook called at the end of each validation epoch.

predict_step(batch, batch_idx)

Prediction step.

Parameters:

batch (Tensor, required): Input batch.
batch_idx (Any, required): Batch index.

Returns:

Any: Model output.

reduce_running_psnr()

Reduce the running PSNR statistics and reset the running PSNR.

Returns:

Optional[float]: Running PSNR averaged over the output channels.
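Together, `compute_val_psnr` and `reduce_running_psnr` suggest an accumulate-then-reduce pattern: per-channel PSNR values are collected over the validation batches, then averaged across channels and reset at the end of the epoch. The class below is a hypothetical sketch of that pattern using a plain running mean of per-batch values; CAREamics may accumulate different statistics internally. Returning `None` when nothing was accumulated mirrors the `Optional[float]` return type.

```python
from typing import Optional


class RunningPSNRSketch:
    """Hypothetical per-channel running mean of batch PSNR values."""

    def __init__(self, n_channels: int) -> None:
        self.n_channels = n_channels
        self.reset()

    def reset(self) -> None:
        self.sums = [0.0] * self.n_channels
        self.counts = [0] * self.n_channels

    def update(self, batch_psnr: list[float]) -> None:
        # One PSNR value per output channel, e.g. from compute_val_psnr
        for c, value in enumerate(batch_psnr):
            self.sums[c] += value
            self.counts[c] += 1

    def reduce(self) -> Optional[float]:
        """Average over channels, then reset for the next epoch."""
        if any(n == 0 for n in self.counts):
            result = None
        else:
            channel_means = [s / n for s, n in zip(self.sums, self.counts)]
            result = sum(channel_means) / self.n_channels
        self.reset()
        return result


running = RunningPSNRSketch(n_channels=2)
running.update([20.0, 30.0])
running.update([22.0, 34.0])
print(running.reduce())  # 26.5
print(running.reduce())  # None (statistics were reset)
```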

set_data_stats(data_mean, data_std)

Set data mean and std for the noise model likelihood.

Parameters:

data_mean (float, required): Mean of the data.
data_std (float, required): Standard deviation of the data.
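Noise models are usually calibrated on raw intensities, while the network works on normalized data, so the likelihood needs the data mean and standard deviation to map values back before evaluating the noise model. The sketch below shows only that mapping. The helper names are hypothetical, and the assumption that the likelihood denormalizes both prediction and target is an illustration, not a statement about the CAREamics internals.

```python
def normalize(values: list[float], data_mean: float, data_std: float) -> list[float]:
    # Map raw intensities to the zero-mean, unit-std space the network works in
    return [(v - data_mean) / data_std for v in values]


def denormalize(values: list[float], data_mean: float, data_std: float) -> list[float]:
    # Map network-space values back to raw intensities for the noise model
    return [v * data_std + data_mean for v in values]


raw = [100.0, 120.0, 140.0]
normed = normalize(raw, data_mean=120.0, data_std=20.0)
print(normed)  # [-1.0, 0.0, 1.0]
print(denormalize(normed, data_mean=120.0, data_std=20.0))  # [100.0, 120.0, 140.0]
```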

training_step(batch, batch_idx)

Training step.

Parameters:

batch (tuple[Tensor, Tensor], required): Input batch, a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
batch_idx (Any, required): Batch index.

Returns:

Any: Loss value.
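The batch shape conventions for `training_step` can be checked before a batch reaches the model. The helper below is a hypothetical sketch: it verifies only the dimensionality, batch size, and channel counts documented here (one primary input plus `n_LC` lateral inputs, and `C` target channels) and leaves the spatial sizes unconstrained.

```python
def check_vae_batch_shapes(
    input_shape: tuple[int, ...],
    target_shape: tuple[int, ...],
    n_lc: int,
    n_targets: int,
) -> None:
    """Validate (B, 1 + n_LC, [Z], Y, X) inputs and (B, C, [Z], Y, X) targets."""
    for name, shape in (("input", input_shape), ("target", target_shape)):
        # 4 dimensions for 2D data (B, C, Y, X), 5 for 3D data (B, C, Z, Y, X)
        if len(shape) not in (4, 5):
            raise ValueError(f"{name} must have 4 or 5 dimensions, got {len(shape)}")
    if len(input_shape) != len(target_shape):
        raise ValueError("input and target must have the same number of dimensions")
    if input_shape[0] != target_shape[0]:
        raise ValueError("input and target batch sizes differ")
    if input_shape[1] != 1 + n_lc:
        raise ValueError(f"expected {1 + n_lc} input channels, got {input_shape[1]}")
    if target_shape[1] != n_targets:
        raise ValueError(f"expected {n_targets} target channels, got {target_shape[1]}")


# A 2D batch: 8 samples, 1 primary + 2 lateral inputs, 2 target channels
check_vae_batch_shapes((8, 3, 64, 64), (8, 2, 64, 64), n_lc=2, n_targets=2)
```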

validation_step(batch, batch_idx)

Validation step.

Parameters:

batch (tuple[Tensor, Tensor], required): Input batch, a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
batch_idx (Any, required): Batch index.

create_careamics_module(algorithm, loss, architecture, use_n2v2=False, struct_n2v_axis='none', struct_n2v_span=5, model_parameters=None, optimizer='Adam', optimizer_parameters=None, lr_scheduler='ReduceLROnPlateau', lr_scheduler_parameters=None)

Create a CAREamics Lightning module.

This function exposes the parameters used to create an AlgorithmModel instance, triggering parameter validation.
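Several dictionary parameters of `create_careamics_module` default to `None` in the signature but are documented as "by default {}", presumably to avoid a shared mutable default, with an empty dictionary substituted at call time. The sketch below reproduces that idiom together with a check of the documented `struct_n2v_axis` values. The function name is hypothetical; in CAREamics the actual validation is carried out by the Pydantic AlgorithmModel rather than by hand-written checks like these.

```python
from typing import Any, Optional

VALID_STRUCT_AXES = ("horizontal", "vertical", "none")


def resolve_module_kwargs(
    struct_n2v_axis: str = "none",
    model_parameters: Optional[dict[str, Any]] = None,
    optimizer_parameters: Optional[dict[str, Any]] = None,
    lr_scheduler_parameters: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Hypothetical helper mimicking two documented defaults/constraints."""
    if struct_n2v_axis not in VALID_STRUCT_AXES:
        raise ValueError(f"struct_n2v_axis must be one of {VALID_STRUCT_AXES}")
    return {
        "struct_n2v_axis": struct_n2v_axis,
        # None becomes a fresh empty dict on every call (no shared state)
        "model_parameters": dict(model_parameters or {}),
        "optimizer_parameters": dict(optimizer_parameters or {}),
        "lr_scheduler_parameters": dict(lr_scheduler_parameters or {}),
    }


kwargs = resolve_module_kwargs(optimizer_parameters={"lr": 1e-4})
print(kwargs["optimizer_parameters"], kwargs["model_parameters"])  # {'lr': 0.0001} {}
```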

Parameters:

algorithm (SupportedAlgorithm or str, required): Algorithm to use for training (see SupportedAlgorithm).
loss (SupportedLoss or str, required): Loss function to use for training (see SupportedLoss).
architecture (SupportedArchitecture or str, required): Model architecture to use for training (see SupportedArchitecture).
use_n2v2 (bool, default False): Whether to use N2V2 instead of Noise2Void.
struct_n2v_axis ("horizontal", "vertical", or "none", default "none"): Axis of the StructN2V mask.
struct_n2v_span (int, default 5): Span of the StructN2V mask.
model_parameters (dict, default None): Model parameters to use for training; None is treated as {}. Model parameters are defined in the relevant torch.nn.Module class or Pydantic model (see careamics.config.architectures).
optimizer (SupportedOptimizer or str, default 'Adam'): Optimizer to use for training (see SupportedOptimizer).
optimizer_parameters (dict, default None): Optimizer parameters to use for training, as defined in torch.optim; None is treated as {}.
lr_scheduler (SupportedScheduler or str, default 'ReduceLROnPlateau'): Learning rate scheduler to use for training (see SupportedScheduler).
lr_scheduler_parameters (dict, default None): Learning rate scheduler parameters to use for training, as defined in torch.optim.lr_scheduler; None is treated as {}.

Returns:

CAREamicsModule: CAREamics Lightning module.
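StructN2V extends the Noise2Void mask along one axis so that spatially correlated noise in neighbouring pixels cannot be used to predict the masked value. The sketch below lists the extra pixel offsets masked around a centre pixel for the documented `struct_n2v_axis` and `struct_n2v_span` parameters. It assumes that "horizontal" means along X, that the span counts the centre pixel and is odd, and that the centre itself is handled by the regular N2V masking; these are illustrative assumptions, and the function name is hypothetical.

```python
def struct_n2v_offsets(axis: str = "none", span: int = 5) -> list[tuple[int, int]]:
    """(dy, dx) offsets of neighbours masked in addition to the centre pixel."""
    if axis == "none":
        return []
    if span < 1 or span % 2 == 0:
        raise ValueError("span is assumed to be a positive odd number")
    half = span // 2
    steps = [d for d in range(-half, half + 1) if d != 0]  # skip the centre
    if axis == "horizontal":
        return [(0, d) for d in steps]  # along X
    if axis == "vertical":
        return [(d, 0) for d in steps]  # along Y
    raise ValueError('axis must be "horizontal", "vertical", or "none"')


print(struct_n2v_offsets("horizontal", 5))  # [(0, -2), (0, -1), (0, 1), (0, 2)]
```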