# Lightning Module

CAREamics Lightning module.
## FCNModule

Bases: `LightningModule`

CAREamics Lightning module.

This class encapsulates the PyTorch model along with the training, validation,
and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `algorithm_config` | `AlgorithmModel` or `dict` | Algorithm configuration. | required |
Attributes:

| Name | Type | Description |
|---|---|---|
| `model` | `Module` | PyTorch model. |
| `loss_func` | `Module` | Loss function. |
| `optimizer_name` | `str` | Optimizer name. |
| `optimizer_params` | `dict` | Optimizer parameters. |
| `lr_scheduler_name` | `str` | Learning rate scheduler name. |
### configure_optimizers()

Configure optimizers and learning rate schedulers.

Returns:

| Type | Description |
|---|---|
| `Any` | Optimizer and learning rate scheduler. |
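To illustrate the pattern behind `configure_optimizers()`, here is a minimal, stdlib-only sketch of resolving an optimizer and scheduler from their configured names. The registries and return shape are illustrative stand-ins; the real CAREamics implementation instantiates `torch.optim` classes from the Pydantic configuration.

```python
# Illustrative registries mapping configured names to (stand-in) targets.
# In CAREamics these would be actual torch.optim / lr_scheduler classes.
OPTIMIZERS = {"Adam": "torch.optim.Adam", "SGD": "torch.optim.SGD"}
SCHEDULERS = {"ReduceLROnPlateau": "torch.optim.lr_scheduler.ReduceLROnPlateau"}


def configure_optimizers(optimizer_name, optimizer_params,
                         lr_scheduler_name, lr_scheduler_params):
    """Resolve names to targets and return a Lightning-style mapping."""
    if optimizer_name not in OPTIMIZERS:
        raise ValueError(f"Unknown optimizer: {optimizer_name!r}")
    if lr_scheduler_name not in SCHEDULERS:
        raise ValueError(f"Unknown scheduler: {lr_scheduler_name!r}")
    return {
        "optimizer": (OPTIMIZERS[optimizer_name], optimizer_params),
        "lr_scheduler": (SCHEDULERS[lr_scheduler_name], lr_scheduler_params),
    }


config = configure_optimizers("Adam", {"lr": 1e-3}, "ReduceLROnPlateau", {})
```

The name-based lookup is what lets the module be configured entirely from strings in the algorithm configuration, with unknown names rejected up front.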
### forward(x)

Forward pass.

### predict_step(batch, batch_idx)

Prediction step.

### training_step(batch, batch_idx)

Training step.
### validation_step(batch, batch_idx)

Validation step.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Tensor` | Input batch. | required |
| `batch_idx` | `Any` | Batch index. | required |
## VAEModule

Bases: `LightningModule`

CAREamics Lightning module.

This class encapsulates a PyTorch model along with the training, validation,
and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `algorithm_config` | `Union[VAEAlgorithmConfig, dict]` | Algorithm configuration. | required |
Attributes:

| Name | Type | Description |
|---|---|---|
| `model` | `Module` | PyTorch model. |
| `loss_func` | `Module` | Loss function. |
| `optimizer_name` | `str` | Optimizer name. |
| `optimizer_params` | `dict` | Optimizer parameters. |
| `lr_scheduler_name` | `str` | Learning rate scheduler name. |
### compute_val_psnr(model_output, target, psnr_func=scale_invariant_psnr)

Compute the PSNR for the current validation batch.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_output` | `tuple[Tensor, dict[str, Any]]` | Model output, a tuple with the predicted mean and (optionally) logvar, and the top-down data dictionary. | required |
| `target` | `Tensor` | Target tensor. | required |
| `psnr_func` | `Callable` | PSNR function to use. | `scale_invariant_psnr` |
Returns:

| Type | Description |
|---|---|
| `list[float]` | PSNR for each channel in the current batch. |
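As a rough intuition for the default `psnr_func`, here is a simplified, stdlib-only sketch of a scale-invariant PSNR: fit the least-squares scale α that best maps the prediction onto the target, then compute PSNR against the target's dynamic range. The actual `scale_invariant_psnr` in CAREamics may differ in details (e.g. mean-centering, range handling), so treat this as a conceptual sketch only.

```python
import math


def scale_invariant_psnr(pred, target):
    """PSNR after rescaling pred by the least-squares scale alpha.

    Assumes pred is not an exact scaled copy of target (mse > 0).
    """
    # alpha minimizing sum((t - alpha * p)^2)
    num = sum(p * t for p, t in zip(pred, target))
    den = sum(p * p for p in pred)
    alpha = num / den
    mse = sum((t - alpha * p) ** 2 for p, t in zip(pred, target)) / len(pred)
    data_range = max(target) - min(target)
    return 10.0 * math.log10(data_range ** 2 / mse)


# A slightly noisy half-scale prediction of the target [1, 2, 3]:
psnr = scale_invariant_psnr([0.5, 1.05, 1.5], [1.0, 2.0, 3.0])
```

Because the metric optimizes away any global scale, multiplying the prediction by a constant leaves the score unchanged, which is why it suits VAE outputs whose absolute scale is not pinned to the target.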
### configure_optimizers()

Configure optimizers and learning rate schedulers.

Returns:

| Type | Description |
|---|---|
| `Any` | Optimizer and learning rate scheduler. |
### forward(x)

Forward pass.
### get_reconstructed_tensor(model_outputs)

Get the reconstructed tensor from the LVAE model outputs.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_outputs` | `tuple[Tensor, dict[str, Any]]` | Model outputs. It is a tuple with a tensor representing the predicted mean and (optionally) logvar, and the top-down data dictionary. | required |
Returns:

| Type | Description |
|---|---|
| `Tensor` | Reconstructed tensor, i.e., the predicted mean. |
### on_validation_epoch_end()

Validation epoch end.

### predict_step(batch, batch_idx)

Prediction step.
### reduce_running_psnr()

Reduce the running PSNR statistics and reset the running PSNR.

Returns:

| Type | Description |
|---|---|
| `Optional[float]` | Running PSNR averaged over the different output channels. |
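The reduce-then-reset contract described above can be sketched with a minimal, hypothetical per-channel tracker (the class and function names here are illustrative stand-ins, not CAREamics API): average the running PSNR over output channels, return `None` if any channel has no statistics yet, and reset each tracker for the next epoch.

```python
class RunningPSNR:
    """Accumulates PSNR values for one output channel over an epoch."""

    def __init__(self):
        self.values = []

    def update(self, psnr):
        self.values.append(psnr)

    def get(self):
        # None until at least one batch has been recorded
        return sum(self.values) / len(self.values) if self.values else None

    def reset(self):
        self.values = []


def reduce_running_psnr(channel_stats):
    """Average running PSNR over channels, then reset each tracker."""
    means = [s.get() for s in channel_stats]
    if any(m is None for m in means):
        return None
    for s in channel_stats:
        s.reset()
    return sum(means) / len(means)
```

Resetting inside the reduction keeps the statistics scoped to a single validation epoch.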
### set_data_stats(data_mean, data_std)

Set the data statistics (mean and standard deviation).
### training_step(batch, batch_idx)

Training step.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `tuple[Tensor, Tensor]` | Input batch: a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit). | required |
| `batch_idx` | `Any` | Batch index. | required |
Returns:

| Type | Description |
|---|---|
| `Any` | Loss value. |
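The batch contract above (input with `1 + n_LC` channels, target with `C` channels, matching batch and spatial dimensions) can be made concrete with a small stdlib-only check. Plain tuples stand in for `torch.Tensor.shape`; the helper name is illustrative, not CAREamics API.

```python
def check_batch_shapes(input_shape, target_shape, n_lc, n_target_channels):
    """Validate the (B, (1 + n_LC), [Z], Y, X) / (B, C, [Z], Y, X) contract."""
    b_in, c_in, *spatial_in = input_shape
    b_t, c_t, *spatial_t = target_shape
    assert b_in == b_t, "input and target batch sizes must match"
    assert c_in == 1 + n_lc, "input must have 1 + n_LC channels"
    assert c_t == n_target_channels, "target must have C channels"
    assert spatial_in == spatial_t, "spatial dims ([Z], Y, X) must match"
    return True


# denoiSplit-style 2D example: batch of 8, 2 lateral-context inputs,
# 2 unmixed target channels, 64x64 patches.
ok = check_batch_shapes((8, 3, 64, 64), (8, 2, 64, 64),
                        n_lc=2, n_target_channels=2)
```

The optional `[Z]` dimension falls out naturally: a 3D batch simply has one more entry in the spatial tail of both shapes.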
### validation_step(batch, batch_idx)

Validation step.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `tuple[Tensor, Tensor]` | Input batch: a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit). | required |
| `batch_idx` | `Any` | Batch index. | required |
## create_careamics_module(algorithm, loss, architecture, use_n2v2=False, struct_n2v_axis='none', struct_n2v_span=5, model_parameters=None, optimizer='Adam', optimizer_parameters=None, lr_scheduler='ReduceLROnPlateau', lr_scheduler_parameters=None)

Create a CAREamics Lightning module.

This function exposes the parameters used to create an `AlgorithmModel` instance, triggering parameter validation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `algorithm` | `SupportedAlgorithm` or `str` | Algorithm to use for training (see `SupportedAlgorithm`). | required |
| `loss` | `SupportedLoss` or `str` | Loss function to use for training (see `SupportedLoss`). | required |
| `architecture` | `SupportedArchitecture` or `str` | Model architecture to use for training (see `SupportedArchitecture`). | required |
| `use_n2v2` | `bool` | Whether to use N2V2 or Noise2Void. | `False` |
| `struct_n2v_axis` | `"horizontal"`, `"vertical"`, or `"none"` | Axis of the StructN2V mask. | `"none"` |
| `struct_n2v_span` | `int` | Span of the StructN2V mask. | `5` |
| `model_parameters` | `dict` | Model parameters to use for training, by default `{}`. Model parameters are defined in the relevant architecture. | `None` |
| `optimizer` | `SupportedOptimizer` or `str` | Optimizer to use for training, by default "Adam" (see `SupportedOptimizer`). | `'Adam'` |
| `optimizer_parameters` | `dict` | Optimizer parameters to use for training. | `None` |
| `lr_scheduler` | `SupportedScheduler` or `str` | Learning rate scheduler to use for training, by default "ReduceLROnPlateau" (see `SupportedScheduler`). | `'ReduceLROnPlateau'` |
| `lr_scheduler_parameters` | `dict` | Learning rate scheduler parameters to use for training. | `None` |
Returns:

| Type | Description |
|---|---|
| `CAREamicsModule` | CAREamics Lightning module. |
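To show what "triggering parameter validation" means in practice, here is a hypothetical, stdlib-only sketch of the up-front string checks: each argument is validated against its supported set before any model is built. The sets, function name, and error message are illustrative only; CAREamics performs this validation through its Pydantic configuration models.

```python
# Illustrative stand-ins for the SupportedAlgorithm / SupportedLoss /
# SupportedArchitecture enums (not the real, complete sets).
SUPPORTED_ALGORITHMS = {"n2v", "care", "n2n"}
SUPPORTED_LOSSES = {"n2v", "mae", "mse"}
SUPPORTED_ARCHITECTURES = {"UNet"}


def validate_module_args(algorithm, loss, architecture):
    """Reject unsupported values before building anything."""
    for value, allowed, name in (
        (algorithm, SUPPORTED_ALGORITHMS, "algorithm"),
        (loss, SUPPORTED_LOSSES, "loss"),
        (architecture, SUPPORTED_ARCHITECTURES, "architecture"),
    ):
        if value not in allowed:
            raise ValueError(f"Unsupported {name}: {value!r}")
    return {"algorithm": algorithm, "loss": loss, "architecture": architecture}


cfg = validate_module_args("n2v", "n2v", "UNet")
```

Failing fast on configuration errors, rather than deep inside training, is the main benefit of funneling module creation through this function.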