Lightning

CAREamics PyTorch Lightning modules.

CAREModule

Bases: LightningModule

CAREamics PyTorch Lightning module for CARE algorithm.

Parameters:

  • algorithm_config (CAREAlgorithm, N2NAlgorithm, or dict) –

    Configuration for the CARE algorithm, either as a CAREAlgorithm/N2NAlgorithm instance or a dictionary.

__init__(algorithm_config)

Instantiate the CARE module.

Parameters:

  • algorithm_config (CAREAlgorithm, N2NAlgorithm, or dict) –

    Configuration for the CARE algorithm, either as a CAREAlgorithm/N2NAlgorithm instance or a dictionary.

configure_optimizers()

Configure optimizer and learning rate scheduler.

Returns:

  • dict[str, Any]

    A dictionary containing the optimizer and learning rate scheduler.

forward(x)

Forward pass.

Parameters:

  • x (Tensor) –

    Input tensor.

Returns:

  • Tensor

    Model output tensor.

on_fit_start()

Hook called at the start of fit for the CARE module.

Check that training and validation target data have been supplied.

predict_step(batch, batch_idx)

Prediction step for CARE module.

Parameters:

Returns:

training_step(batch, batch_idx)

Training step for CARE module.

Parameters:

  • batch ((ImageRegionData, ImageRegionData)) –

    A tuple containing the input data and the target data.

  • batch_idx (int) –

    The index of the current batch in the training loop.

Returns:

  • Tensor

    The loss value computed for the current batch.

validation_step(batch, batch_idx)

Validation step for CARE module.

Parameters:

  • batch ((ImageRegionData, ImageRegionData)) –

    A tuple containing the input data and the target data.

  • batch_idx (int) –

    The index of the current batch in the validation loop.

CareamicsDataModule

Bases: LightningDataModule

Data module for CAREamics datasets.

Parameters:

  • data_config (DataConfig) –

    Pydantic model for CAREamics data configuration.

  • train_data (Any, default: None ) –

    Training data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • train_data_target (Any, default: None ) –

    Training data target. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • train_data_mask (Any, default: None ) –

    Training data mask: an optional mask used to filter out regions of the data during training, such as large areas of background. The mask should be a binary image in which 1 indicates that a pixel should be included in the training data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • val_data (Any, default: None ) –

    Validation data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • val_data_target (Any, default: None ) –

    Validation data target. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • pred_data (Any, default: None ) –

    Prediction data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • pred_data_target (Any, default: None ) –

    Prediction data target, which may be used for calculating metrics. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • model_constraints (ModelConstraints | None, default: None ) –

    If provided, the data module will validate that the prediction data shape is compatible with the model constraints.

  • loading (ReadFuncLoading | ImageStackLoading | None, default: None ) –

    The type of loading used for custom data. ReadFuncLoading uses a simple function that loads full images into memory. ImageStackLoading is for custom chunked or memory-mapped next-generation file formats, enabling single patches to be read from disk at a time. If the data type is not custom, loading should be None.

Attributes:

  • config (DataConfig) –

    Pydantic model for CAREamics data configuration.

  • data_type (str) –

    Type of data, one of SupportedData.

  • batch_size (int) –

    Batch size for the dataloaders.

Raises:

  • ValueError

    If none of train_data, val_data or pred_data is provided.

  • ValueError

    If input and target data types are not consistent.

__init__(data_config, *, train_data=None, train_data_target=None, train_data_mask=None, val_data=None, val_data_target=None, pred_data=None, pred_data_target=None, model_constraints=None, loading=None)

__init__(data_config: DataConfig | dict[str, Any], *, train_data: InputVar | None = None, train_data_target: InputVar | None = None, train_data_mask: InputVar | None = None, val_data: InputVar | None = None, val_data_target: InputVar | None = None, pred_data: InputVar | None = None, pred_data_target: InputVar | None = None, model_constraints: ModelConstraints | None = None, loading: ReadFuncLoading | None = None) -> None
__init__(data_config: DataConfig | dict[str, Any], *, train_data: Any | None = None, train_data_target: Any | None = None, train_data_mask: Any | None = None, val_data: Any | None = None, val_data_target: Any | None = None, pred_data: Any | None = None, pred_data_target: Any | None = None, model_constraints: ModelConstraints | None = None, loading: ImageStackLoading = ...) -> None

Initialize the CAREamics data module.

Create a Lightning data module that handles creating datasets for training, validation, and prediction.

Parameters:

  • data_config (DataConfig or dict) –

    Pydantic model for CAREamics data configuration, or an equivalent dictionary.

  • train_data (Any, default: None ) –

    Training data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • train_data_target (Any, default: None ) –

    Training data target. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • train_data_mask (Any, default: None ) –

    Training data mask: an optional mask used to filter out regions of the data during training, such as large areas of background. The mask should be a binary image in which 1 indicates that a pixel should be included in the training data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • val_data (Any, default: None ) –

    Validation data. If not provided, data_config.n_val_patches patches will be selected from the training data for validation. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • val_data_target (Any, default: None ) –

    Validation data target. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • pred_data (Any, default: None ) –

    Prediction data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • pred_data_target (Any, default: None ) –

    Prediction data target, which may be used for calculating metrics. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str, numpy.ndarray or a sequence of these, or None.

  • model_constraints (ModelConstraints, default: None ) –

    If provided, the data module will validate input and target channels and spatial shapes against the model constraints.

  • loading (ReadFuncLoading | ImageStackLoading | None, default: None ) –

    The type of loading used for custom data. ReadFuncLoading uses a simple function that loads full images into memory. ImageStackLoading is for custom chunked or memory-mapped next-generation file formats, enabling single patches to be read from disk at a time. If the data type is not custom, loading should be None.
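The train_data_mask description (1 means "include this pixel") can be illustrated with a small sketch. It assumes, purely for illustration, that a non-overlapping patch is kept when the fraction of mask pixels equal to 1 reaches a threshold; the helper names mask_coverage and select_patches and the min_coverage threshold are hypothetical and are not part of the CAREamics API.

```python
# Hypothetical sketch: how a binary mask (1 = include, 0 = exclude) could be
# used to discard background patches. Not the CAREamics implementation.
from typing import List, Sequence, Tuple

Mask2D = Sequence[Sequence[int]]


def mask_coverage(mask: Mask2D, y: int, x: int, size: int) -> float:
    """Fraction of pixels equal to 1 in the size x size patch at (y, x)."""
    included = sum(
        mask[yy][xx] for yy in range(y, y + size) for xx in range(x, x + size)
    )
    return included / (size * size)


def select_patches(
    mask: Mask2D, size: int, min_coverage: float = 0.5  # hypothetical threshold
) -> List[Tuple[int, int]]:
    """Return top-left corners of non-overlapping patches that pass the mask."""
    height, width = len(mask), len(mask[0])
    corners = []
    for y in range(0, height - size + 1, size):
        for x in range(0, width - size + 1, size):
            if mask_coverage(mask, y, x, size) >= min_coverage:
                corners.append((y, x))
    return corners


# 4x4 mask: left half is foreground (1), right half is background (0).
example_mask = [[1, 1, 0, 0] for _ in range(4)]
kept = select_patches(example_mask, size=2)
print(kept)  # [(0, 0), (2, 0)]
```

Only the two patches lying entirely in the foreground half survive; how strictly the real data module applies the mask is not specified here.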

predict_dataloader()

Create a dataloader for prediction.

Returns:

  • DataLoader

    Prediction dataloader.

setup(stage)

Set up datasets.

Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict. Creates the datasets for a given stage.

Parameters:

  • stage (str) –

    The stage to set up datasets for. One of 'fit', 'validate', 'test', or 'predict'.

Raises:

train_dataloader()

Create a dataloader for training.

Returns:

  • DataLoader

    Training dataloader.

val_dataloader()

Create a dataloader for validation.

Returns:

  • DataLoader

    Validation dataloader.

ConfigSaverCallback

Bases: Callback

Callback to save CAREamics configuration in Lightning checkpoints.

This callback automatically stores CAREamics version, experiment name, and training configuration in the checkpoint file for reproducibility.

Parameters:

  • careamics_version (str) –

    Version of CAREamics used for training.

  • experiment_name (str) –

    Name of the experiment.

  • training_config (TrainingConfig) –

    Training configuration to store in checkpoint.

Attributes:

  • careamics_version (str) –

    Version of CAREamics used for training.

  • experiment_name (str) –

    Name of the experiment.

  • training_config (TrainingConfig) –

    Training configuration to store in checkpoint.

__init__(careamics_version, experiment_name, training_config)

Initialize the callback.

Parameters:

  • careamics_version (str) –

    Version of CAREamics used for training.

  • experiment_name (str) –

    Name of the experiment.

  • training_config (TrainingConfig) –

    Training configuration to store in checkpoint.

on_save_checkpoint(trainer, pl_module, checkpoint)

Lightning hook called when saving a checkpoint.

Adds CAREamics configuration to the checkpoint dictionary.

Parameters:

  • trainer (Trainer) –

    Lightning trainer instance.

  • pl_module (LightningModule) –

    Lightning module being trained.

  • checkpoint (dict[str, Any]) –

    Checkpoint dictionary to modify.
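The on_save_checkpoint hook works by mutating the checkpoint dictionary that Lightning passes in, before it is written to disk. The sketch below shows that pattern with a plain-Python stand-in (no Lightning import); the class name ConfigSaverSketch and the "careamics_info" key are hypothetical, and the real callback may store its fields under different keys.

```python
# Hypothetical sketch of the on_save_checkpoint pattern: the hook returns None
# and adds entries to the checkpoint dict in place. Key names are assumptions.
from typing import Any, Dict


class ConfigSaverSketch:
    """Plain-Python stand-in for a Lightning Callback."""

    def __init__(
        self, careamics_version: str, experiment_name: str,
        training_config: Dict[str, Any],
    ) -> None:
        self.careamics_version = careamics_version
        self.experiment_name = experiment_name
        self.training_config = training_config

    def on_save_checkpoint(
        self, trainer: Any, pl_module: Any, checkpoint: Dict[str, Any]
    ) -> None:
        # Add reproducibility metadata next to the existing checkpoint content.
        checkpoint["careamics_info"] = {  # hypothetical key
            "version": self.careamics_version,
            "experiment_name": self.experiment_name,
            "training_config": dict(self.training_config),
        }


checkpoint = {"state_dict": {}, "epoch": 3}
saver = ConfigSaverSketch("0.0.0-example", "my_experiment", {"num_epochs": 100})
saver.on_save_checkpoint(trainer=None, pl_module=None, checkpoint=checkpoint)
```

Because the metadata travels inside the checkpoint file, a loader can later recover the experiment name and training configuration without a separate config file.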

DataStatsCallback

Bases: Callback

Callback to update the model's data statistics from the data module.

This callback ensures that the model has access to the data statistics (mean, std) calculated by the datamodule before training starts.

setup(trainer, module, stage)

Called when the trainer is setting up.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning trainer.

  • module (LightningModule) –

    Lightning module.

  • stage (str) –

    Current stage (fit, validate, test, or predict).
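A minimal sketch of the hand-off this callback describes, using stand-in objects instead of a real Trainer, data module and LightningModule. It relies on the fact that a Lightning trainer exposes its data module as trainer.datamodule; the method names get_data_stats and set_data_stats are borrowed from MicroSplitDataModule and VAEModule on this page, but whether DataStatsCallback calls exactly these is an assumption, as are the toy statistics.

```python
# Hypothetical sketch: copy statistics from the data module to the model
# during setup. ToyDataModule, ToyModule and DataStatsSketch are stand-ins.
from types import SimpleNamespace
from typing import Any, Dict, Tuple


class ToyDataModule:
    def get_data_stats(self) -> Tuple[Dict[str, float], Dict[str, float]]:
        # (data_mean, data_std) for input and target; toy values.
        return {"input": 10.0, "target": 12.0}, {"input": 2.0, "target": 3.0}


class ToyModule:
    def __init__(self) -> None:
        self.data_mean = None
        self.data_std = None

    def set_data_stats(self, data_mean: Any, data_std: Any) -> None:
        self.data_mean, self.data_std = data_mean, data_std


class DataStatsSketch:
    def setup(self, trainer: Any, module: Any, stage: str) -> None:
        # In this sketch, statistics are copied only for the "fit" stage.
        if stage == "fit":
            data_mean, data_std = trainer.datamodule.get_data_stats()
            module.set_data_stats(data_mean, data_std)


trainer = SimpleNamespace(datamodule=ToyDataModule())
fit_module, predict_module = ToyModule(), ToyModule()
DataStatsSketch().setup(trainer, fit_module, stage="fit")
DataStatsSketch().setup(trainer, predict_module, stage="predict")
```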

ImageStackLoading dataclass

Loading specification for a custom image stack loader (chunked or memory-mapped).

image_stack_loader instance-attribute

A function that loads image data to a sequence of ImageStack objects.

image_stack_loader_kwargs = None class-attribute instance-attribute

Additional keyword arguments to pass to the image_stack_loader alongside the source of the image data.

MicroSplitDataModule

Bases: LightningDataModule

Lightning DataModule for MicroSplit-style datasets.

Matches the interface of TrainDataModule, but internally uses the original MicroSplit dataset logic.

Parameters:

  • data_config (MicroSplitDataConfig) –

    Configuration for the MicroSplit dataset.

  • train_data (str) –

    Path to training data directory.

  • val_data (str, default: None ) –

    Path to validation data directory.

  • train_data_target (str, default: None ) –

    Path to training target data.

  • val_data_target (str, default: None ) –

    Path to validation target data.

  • read_source_func (Callable, default: None ) –

    Function to read source data.

  • extension_filter (str, default: '' ) –

    File extension filter.

  • val_percentage (float, default: 0.1 ) –

    Percentage of data to use for validation, by default 0.1.

  • val_minimum_split (int, default: 5 ) –

    Minimum number of samples for validation split, by default 5.

  • use_in_memory (bool, default: True ) –

    Whether to use in-memory dataset, by default True.

__init__(data_config, train_data, val_data=None, train_data_target=None, val_data_target=None, read_source_func=None, extension_filter='', val_percentage=0.1, val_minimum_split=5, use_in_memory=True)

Initialize MicroSplitDataModule.

Parameters:

  • data_config (MicroSplitDataConfig) –

    Configuration for the MicroSplit dataset.

  • train_data (str) –

    Path to training data directory.

  • val_data (str, default: None ) –

    Path to validation data directory.

  • train_data_target (str, default: None ) –

    Path to training target data.

  • val_data_target (str, default: None ) –

    Path to validation target data.

  • read_source_func (Callable, default: None ) –

    Function to read source data.

  • extension_filter (str, default: '' ) –

    File extension filter.

  • val_percentage (float, default: 0.1 ) –

    Percentage of data to use for validation, by default 0.1.

  • val_minimum_split (int, default: 5 ) –

    Minimum number of samples for validation split, by default 5.

  • use_in_memory (bool, default: True ) –

    Whether to use in-memory dataset, by default True.
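When no val_data is given, val_percentage and val_minimum_split together determine how much training data is held out. The exact rule used by MicroSplitDataModule is not documented here, so the sketch below assumes one plausible reading: take val_percentage of the items, but never fewer than val_minimum_split. The function split_train_val is hypothetical.

```python
# Hypothetical sketch of a percentage-plus-minimum validation split.
# This is one plausible reading, not the MicroSplitDataModule implementation.
import random
from typing import List, Sequence, Tuple


def split_train_val(
    items: Sequence[str],
    val_percentage: float = 0.1,
    val_minimum_split: int = 5,
    seed: int = 0,
) -> Tuple[List[str], List[str]]:
    n_val = max(val_minimum_split, int(val_percentage * len(items)))
    if n_val >= len(items):
        raise ValueError(
            f"Cannot hold out {n_val} validation items from {len(items)}."
        )
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # reproducible shuffle
    return shuffled[n_val:], shuffled[:n_val]


files = [f"image_{i:03d}.tif" for i in range(100)]
train_files, val_files = split_train_val(files)   # 10% of 100 -> 10 items
_, small_val = split_train_val(files[:20])        # 10% of 20 -> 2, raised to 5
```

The minimum guards against a validation set that is too small to give a meaningful loss when the training data is small.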

get_data_stats()

Get data statistics.

Returns:

  • tuple[dict, dict]

    A tuple containing two dictionaries:

      - data_mean: mean values for input and target
      - data_std: standard deviation values for input and target

train_dataloader()

Create a dataloader for training.

Returns:

  • DataLoader

    Training dataloader.

val_dataloader()

Create a dataloader for validation.

Returns:

  • DataLoader

    Validation dataloader.

N2VModule

Bases: LightningModule

CAREamics PyTorch Lightning module for N2V algorithm.

Parameters:

  • algorithm_config (N2VAlgorithm or dict) –

    Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a dictionary.

__init__(algorithm_config)

Instantiate N2VModule.

Parameters:

  • algorithm_config (N2VAlgorithm or dict) –

    Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a dictionary.

configure_optimizers()

Configure optimizer and learning rate scheduler.

Returns:

  • dict[str, Any]

    A dictionary containing the optimizer and learning rate scheduler.

forward(x)

Forward pass.

Parameters:

  • x (Tensor) –

    Input tensor.

Returns:

  • Tensor

    Model output tensor.

on_fit_start()

Hook called at the start of fit for the N2V module.

predict_step(batch, batch_idx)

Prediction step for N2V model.

Parameters:

Returns:

training_step(batch, batch_idx)

Training step for N2V model.

Parameters:

Returns:

  • Tensor

    The loss value for the current training step.

validation_step(batch, batch_idx)

Validation step for N2V model.

Parameters:

PredictionStoppedException

Bases: Exception

Exception raised when prediction is stopped by external signal.

ProgressBarCallback

Bases: TQDMProgressBar

Progress bar for training and validation steps.

get_metrics(trainer, pl_module)

Override this to customize the metrics displayed in the progress bar.

Parameters:

  • trainer (Trainer) –

    The trainer object.

  • pl_module (LightningModule) –

    The LightningModule object, unused.

Returns:

  • dict

    A dictionary with the metrics to display in the progress bar.

init_test_tqdm()

Override this to customize the tqdm bar for testing.

Returns:

  • tqdm

    A tqdm bar.

init_train_tqdm()

Override this to customize the tqdm bar for training.

Returns:

  • tqdm

    A tqdm bar.

init_validation_tqdm()

Override this to customize the tqdm bar for validation.

Returns:

  • tqdm

    A tqdm bar.

ReadFuncLoading dataclass

Loading specification using a custom read function.

extension_filter = '' class-attribute instance-attribute

A filter for finding source files using glob-style pattern matching. For example, to select files with the extension .npy one should use the filter "*.npy".

read_kwargs = None class-attribute instance-attribute

Additional keyword arguments to pass to the read_source_func alongside the file path to the image data.

read_source_func instance-attribute

A function for reading image data to numpy arrays.
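The three ReadFuncLoading fields fit together as follows: files are selected with the glob-style extension_filter, and each path is passed to read_source_func together with read_kwargs. The sketch uses a hypothetical stand-in dataclass, ReadFuncLoadingSketch, with the same three documented fields rather than importing CAREamics; read_npy and load_folder are also illustrative names, and treating an empty filter as "match everything" is an assumption of this sketch only.

```python
# Sketch of a custom read function plus glob filter, in the shape that
# ReadFuncLoading describes. ReadFuncLoadingSketch is a hypothetical stand-in.
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import numpy as np


def read_npy(file_path: Path, mmap_mode: Optional[str] = None) -> np.ndarray:
    """Read one file into a numpy array; extra kwargs come from read_kwargs."""
    return np.load(file_path, mmap_mode=mmap_mode)


@dataclass
class ReadFuncLoadingSketch:
    read_source_func: Callable[..., np.ndarray]
    extension_filter: str = ""
    read_kwargs: Optional[Dict[str, Any]] = None


def load_folder(folder: Path, spec: ReadFuncLoadingSketch) -> List[np.ndarray]:
    pattern = spec.extension_filter or "*"  # empty filter: match all (sketch only)
    kwargs = spec.read_kwargs or {}
    return [spec.read_source_func(f, **kwargs) for f in sorted(folder.glob(pattern))]


spec = ReadFuncLoadingSketch(read_source_func=read_npy, extension_filter="*.npy")
with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    np.save(folder / "a.npy", np.zeros((4, 4)))
    np.save(folder / "b.npy", np.ones((4, 4)))
    (folder / "notes.txt").write_text("not an image")  # skipped by "*.npy"
    arrays = load_folder(folder, spec)
```

With the real class, the same read_npy function and "*.npy" filter would be passed as ReadFuncLoading fields.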

StopPredictionCallback

Bases: Callback

PyTorch Lightning callback to stop prediction based on external condition.

This callback monitors a user-provided stop condition at the start of each prediction batch. When the condition is met, the callback stops the trainer and raises PredictionStoppedException to interrupt the prediction loop.

Parameters:

  • stop_condition (Callable[[], bool]) –

    A callable that returns True when prediction should stop. The callable is invoked at the start of each prediction batch.
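Any zero-argument callable returning a bool works as stop_condition; a bound method such as threading.Event().is_set is a convenient choice when another thread (for example a GUI cancel button) requests the stop. The sketch below simulates the documented behaviour, checking the condition at the start of every batch and raising to interrupt the loop, with hypothetical stand-ins (StopRequested, run_prediction) rather than Lightning or CAREamics classes.

```python
# Hypothetical simulation: check a stop condition before each batch and raise
# an exception to interrupt the loop. Not the CAREamics implementation.
import threading
from typing import Callable, List


class StopRequested(Exception):
    """Stand-in for PredictionStoppedException."""


def run_prediction(
    batches: List[int],
    stop_condition: Callable[[], bool],
    on_batch: Callable[[int], None],
) -> List[int]:
    outputs = []
    for batch_idx, batch in enumerate(batches):
        if stop_condition():  # checked before the batch is processed
            raise StopRequested(f"Stopped before batch {batch_idx}")
        on_batch(batch_idx)
        outputs.append(batch * 2)  # placeholder "prediction"
    return outputs


stop_event = threading.Event()


def request_stop_after_second_batch(batch_idx: int) -> None:
    if batch_idx == 1:
        stop_event.set()  # e.g. triggered from another thread in a real app


stopped, message = False, ""
try:
    run_prediction([10, 20, 30, 40], stop_event.is_set, request_stop_after_second_batch)
except StopRequested as err:
    stopped, message = True, str(err)
```

With the real callback, the same callable would be passed as StopPredictionCallback(stop_condition=stop_event.is_set).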

__init__(stop_condition)

Initialize the callback with a stop condition.

Parameters:

  • stop_condition (Callable[[], bool]) –

    Function that returns True when prediction should stop.

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Check stop condition at the start of each prediction batch.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning trainer instance.

  • pl_module (LightningModule) –

    Lightning module being used for prediction.

  • batch (Any) –

    Current batch of data.

  • batch_idx (int) –

    Index of the current batch.

  • dataloader_idx (int, default: 0 ) –

    Index of the dataloader, by default 0.

Raises:

VAEModule

Bases: LightningModule

CAREamics Lightning module.

This class encapsulates a PyTorch model along with the training, validation, and testing logic. It is configured using an AlgorithmModel Pydantic class.

Parameters:

  • algorithm_config (Union[VAEAlgorithmConfig, dict]) –

    Algorithm configuration.

Attributes:

  • model (Module) –

    PyTorch model.

  • loss_func (Module) –

    Loss function.

  • optimizer_name (str) –

    Optimizer name.

  • optimizer_params (dict) –

    Optimizer parameters.

  • lr_scheduler_name (str) –

    Learning rate scheduler name.

__init__(algorithm_config)

Lightning module for CAREamics.

This class encapsulates a PyTorch model along with the training, validation, and testing logic. It is configured using an AlgorithmModel Pydantic class.

Parameters:

  • algorithm_config (Union[AlgorithmModel, dict]) –

    Algorithm configuration.

compute_val_psnr(model_output, target, psnr_func=scale_invariant_psnr)

Compute the PSNR for the current validation batch.

Parameters:

  • model_output (tuple[Tensor, dict[str, Any]]) –

    Model output, a tuple with the predicted mean and (optionally) logvar, and the top-down data dictionary.

  • target (Tensor) –

    Target tensor.

  • psnr_func (Callable, default: scale_invariant_psnr ) –

    PSNR function to use, by default scale_invariant_psnr.

Returns:

  • list[float]

    PSNR for each channel in the current batch.
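To show what "scale invariant" means for the default psnr_func, here is a pure-Python sketch of one common formulation: standardize the target, fit the zero-mean prediction to it with a least-squares scale factor, then compute PSNR using the standardized target's dynamic range. It is not guaranteed to match the scale_invariant_psnr shipped with CAREamics line by line; the function name is hypothetical.

```python
# Sketch of a scale-invariant PSNR for a single channel: an affine rescaling
# of the prediction (scale and offset) does not change the score.
import math
from statistics import fmean, pstdev
from typing import Sequence


def scale_invariant_psnr_sketch(target: Sequence[float], pred: Sequence[float]) -> float:
    std = pstdev(target)
    mean_t, mean_p = fmean(target), fmean(pred)
    gt = [(t - mean_t) / std for t in target]  # zero mean, unit std
    p = [v - mean_p for v in pred]             # zero-mean prediction
    # Least-squares scale alpha minimizing sum((gt - alpha * p) ** 2).
    denom = sum(q * q for q in p)
    alpha = sum(g * q for g, q in zip(gt, p)) / denom if denom else 0.0
    mse = fmean([(g - alpha * q) ** 2 for g, q in zip(gt, p)])
    data_range = (max(target) - min(target)) / std
    if mse == 0.0:
        return math.inf
    return 10.0 * math.log10(data_range ** 2 / mse)


target = [0.0, 1.0, 2.0, 3.0, 4.0]
rescaled = [2.0 * t + 5.0 for t in target]  # same structure, new scale/offset
noisy = [0.0, 1.5, 1.5, 3.5, 4.0]
psnr_rescaled = scale_invariant_psnr_sketch(target, rescaled)
psnr_noisy = scale_invariant_psnr_sketch(target, noisy)
```

Since compute_val_psnr returns one value per channel, a multi-channel version would apply such a function to each channel slice separately.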

configure_optimizers()

Configure optimizers and learning rate schedulers.

Returns:

  • Any

    Optimizer and learning rate scheduler.

forward(x)

Forward pass.

Parameters:

  • x (Tensor) –

    Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs.

Returns:

  • tuple[Tensor, dict[str, Any]]

    A tuple with the output tensor and additional data from the top-down pass.

get_reconstructed_tensor(model_outputs)

Get the reconstructed tensor from the LVAE model outputs.

Parameters:

  • model_outputs (tuple[Tensor, dict[str, Any]]) –

    Model outputs. It is a tuple with a tensor representing the predicted mean and (optionally) logvar, and the top-down data dictionary.

Returns:

  • Tensor

    Reconstructed tensor, i.e., the predicted mean.

on_validation_epoch_end()

Hook called at the end of each validation epoch.

predict_step(batch, batch_idx)

Prediction step.

Parameters:

  • batch (Tensor) –

    Input batch.

  • batch_idx (Any) –

    Batch index.

Returns:

  • Any

    Model output.

reduce_running_psnr()

Reduce the running PSNR statistics and reset the running PSNR.

Returns:

  • Optional[float]

    Running PSNR averaged over the different output channels.
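A minimal sketch of a running-PSNR accumulator consistent with this description: per-channel values (one list entry per channel, as compute_val_psnr returns) are collected batch by batch, then averaged within each channel, averaged over channels, and reset. The class RunningPSNR and its methods are illustrative names, not CAREamics internals; returning None when nothing was collected mirrors the Optional[float] return type.

```python
# Hypothetical running-PSNR accumulator: collect, reduce over channels, reset.
from statistics import fmean
from typing import List, Optional


class RunningPSNR:
    def __init__(self, n_channels: int) -> None:
        self.values: List[List[float]] = [[] for _ in range(n_channels)]

    def update(self, batch_psnr: List[float]) -> None:
        # batch_psnr holds one PSNR value per output channel.
        for channel, value in enumerate(batch_psnr):
            self.values[channel].append(value)

    def reduce(self) -> Optional[float]:
        # Average over batches within each channel, then over channels.
        per_channel = [fmean(v) for v in self.values if v]
        self.values = [[] for _ in self.values]  # reset for the next epoch
        return fmean(per_channel) if per_channel else None


tracker = RunningPSNR(n_channels=2)
tracker.update([30.0, 20.0])
tracker.update([32.0, 22.0])
epoch_psnr = tracker.reduce()   # channel means 31.0 and 21.0 -> 26.0
after_reset = tracker.reduce()  # nothing collected -> None
```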

set_data_stats(data_mean, data_std)

Set data mean and std for the noise model likelihood.

Parameters:

  • data_mean (float) –

    Mean of the data.

  • data_std (float) –

    Standard deviation of the data.

training_step(batch, batch_idx)

Training step.

Parameters:

  • batch (tuple[Tensor, Tensor]) –

    Input batch. It is a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).

  • batch_idx (Any) –

    Batch index.

Returns:

  • Any

    Loss value.

validation_step(batch, batch_idx)

Validation step.

Parameters:

  • batch (tuple[Tensor, Tensor]) –

    Input batch. It is a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).

  • batch_idx (Any) –

    Batch index.

create_microsplit_predict_datamodule(pred_data, tile_size, batch_size=1, num_channels=2, depth3D=1, grid_size=None, multiscale_count=None, data_stats=None, tiling_mode=TilingMode.ShiftBoundary, read_source_func=None, extension_filter='', dataloader_params=None, **dataset_kwargs)

Create a MicroSplitPredictDataModule for MicroSplit-style prediction datasets.

Parameters:

  • pred_data (str or Path or ndarray) –

    Prediction data, can be a path to a folder, a file or a numpy array.

  • tile_size (tuple) –

    Size of one tile of data.

  • batch_size (int, default: 1 ) –

    Batch size for prediction dataloader.

  • num_channels (int, default: 2 ) –

    Number of channels in the input.

  • depth3D (int, default: 1 ) –

    Number of slices in 3D.

  • grid_size (tuple, default: None ) –

    Grid size for patch extraction.

  • multiscale_count (int, default: None ) –

    Number of LC scales.

  • data_stats (tuple, default: None ) –

    Data statistics, by default None.

  • tiling_mode (TilingMode, default: ShiftBoundary ) –

    Tiling mode for patch extraction.

  • read_source_func (Callable, default: None ) –

    Function to read the source data.

  • extension_filter (str, default: '' ) –

    File extension filter.

  • dataloader_params (dict, default: None ) –

    Parameters for prediction dataloader.

  • **dataset_kwargs

    Additional arguments passed to MicroSplitDataConfig.

Returns:

create_microsplit_train_datamodule(train_data, patch_size, batch_size, val_data=None, num_channels=2, depth3D=1, grid_size=None, multiscale_count=None, tiling_mode=TilingMode.ShiftBoundary, extension_filter='', val_percentage=0.1, val_minimum_split=5, use_in_memory=True, transforms=None, train_dataloader_params=None, val_dataloader_params=None, **dataset_kwargs)

Create a MicroSplitDataModule for MicroSplit-style datasets.

Parameters:

  • train_data (str) –

    Path to training data.

  • patch_size (tuple) –

    Size of one patch of data.

  • batch_size (int) –

    Batch size for dataloaders.

  • val_data (str, default: None ) –

    Path to validation data.

  • num_channels (int, default: 2 ) –

    Number of channels in the input.

  • depth3D (int, default: 1 ) –

    Number of slices in 3D.

  • grid_size (tuple, default: None ) –

    Grid size for patch extraction.

  • multiscale_count (int, default: None ) –

    Number of LC scales.

  • tiling_mode (TilingMode, default: ShiftBoundary ) –

    Tiling mode for patch extraction.

  • extension_filter (str, default: '' ) –

    File extension filter.

  • val_percentage (float, default: 0.1 ) –

    Percentage of training data to use for validation.

  • val_minimum_split (int, default: 5 ) –

    Minimum number of patches/files for validation split.

  • use_in_memory (bool, default: True ) –

    Use in-memory dataset if possible.

  • transforms (list, default: None ) –

    List of transforms to apply.

  • train_dataloader_params (dict, default: None ) –

    Parameters for training dataloader.

  • val_dataloader_params (dict, default: None ) –

    Parameters for validation dataloader.

  • **dataset_kwargs

    Additional arguments passed to DatasetConfig.

Returns:

load_config_from_checkpoint(checkpoint_path)

Load a CAREamics config from a checkpoint.

If missing, the following fields are populated with defaults: version, training_config and experiment_name.

The default for experiment_name will be "loaded_from_<checkpoint_filename>".

Parameters:

  • checkpoint_path (Path) –

    Path to the PyTorch Lightning checkpoint file.

Returns:

Raises:

  • ValueError

    If certain required information is not found in the checkpoint.
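The documented fallback for experiment_name, "loaded_from_<checkpoint_filename>", can be sketched as below. Whether the real function uses the file stem (as assumed here) or the full file name is not stated, and fill_missing_fields is a hypothetical helper that only shows the experiment_name default; the defaults for version and training_config are not documented on this page.

```python
# Sketch of the experiment_name fallback. Using Path.stem (no extension) is an
# assumption; the real function may use the full file name.
from pathlib import Path
from typing import Any, Dict


def default_experiment_name(checkpoint_path: Path) -> str:
    return f"loaded_from_{Path(checkpoint_path).stem}"


def fill_missing_fields(config: Dict[str, Any], checkpoint_path: Path) -> Dict[str, Any]:
    # Hypothetical helper: only fills experiment_name if it is absent.
    filled = dict(config)
    filled.setdefault("experiment_name", default_experiment_name(checkpoint_path))
    return filled


toy_config = {"algorithm_config": {}}  # toy content, for illustration only
cfg = fill_missing_fields(toy_config, Path("runs/last.ckpt"))
kept = fill_missing_fields({"experiment_name": "my_run"}, Path("runs/last.ckpt"))
```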

load_module_from_checkpoint(checkpoint_path)

Load a trained CAREamics module from checkpoint.

Automatically detects the algorithm type from the checkpoint and loads the appropriate module with trained weights.

Parameters:

  • checkpoint_path (Path) –

    Path to the PyTorch Lightning checkpoint file.

Returns:

  • CAREamicsModule

    Lightning module with loaded weights.

Raises:

  • ValueError

    If the algorithm type cannot be determined from the checkpoint.
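"Automatically detects the algorithm type" amounts to reading an algorithm identifier from the checkpoint and dispatching to the matching module class, raising ValueError otherwise. The sketch below shows that dispatch with stand-in classes; the checkpoint layout (hyper_parameters, then algorithm_config, then algorithm) and the registry keys are assumptions for illustration only. The CARE-like entry is shared by "care" and "n2n" because CAREModule accepts both CAREAlgorithm and N2NAlgorithm configurations.

```python
# Hypothetical dispatch sketch: checkpoint layout and registry keys are assumed.
from typing import Any, Dict, Type


class CARELike:
    """Stand-in for CAREModule."""


class N2VLike:
    """Stand-in for N2VModule."""


MODULE_REGISTRY: Dict[str, Type] = {
    "care": CARELike,
    "n2n": CARELike,  # CAREModule also accepts N2N configurations
    "n2v": N2VLike,
}


def module_class_from_checkpoint(checkpoint: Dict[str, Any]) -> Type:
    try:
        algorithm = checkpoint["hyper_parameters"]["algorithm_config"]["algorithm"]
    except (KeyError, TypeError) as err:
        raise ValueError("Algorithm type not found in checkpoint.") from err
    if algorithm not in MODULE_REGISTRY:
        raise ValueError(f"Unsupported algorithm type: {algorithm!r}")
    return MODULE_REGISTRY[algorithm]


ckpt = {"hyper_parameters": {"algorithm_config": {"algorithm": "n2v"}}, "state_dict": {}}
module_cls = module_class_from_checkpoint(ckpt)
```

A registry keeps the detection logic in one place, so supporting a new algorithm only requires adding an entry.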