Lightning
CAREamics PyTorch Lightning modules.
CAREModule
Bases: LightningModule
CAREamics PyTorch Lightning module for CARE algorithm.
Parameters:
- algorithm_config (CAREAlgorithm, N2NAlgorithm, or dict) – Configuration for the CARE algorithm, either as a CAREAlgorithm/N2NAlgorithm instance or a dictionary.
__init__(algorithm_config)
Instantiate CARE Module.
Parameters:
- algorithm_config (CAREAlgorithm, N2NAlgorithm, or dict) – Configuration for the CARE algorithm, either as a CAREAlgorithm/N2NAlgorithm instance or a dictionary.
configure_optimizers()
forward(x)
Forward pass.
Parameters:
- x (Tensor) – Input tensor.
Returns:
- Tensor – Model output tensor.
on_fit_start()
On fit start hook for CARE module.
Check that training and validation target data have been supplied.
predict_step(batch, batch_idx)
Prediction step for CARE module.
Parameters:
- batch (ImageRegionData or (ImageRegionData, ImageRegionData)) – Either the input data alone, or a tuple containing the input data and the target data.
- batch_idx (int) – The index of the current batch in the prediction loop.
Returns:
- ImageRegionData – The output batch containing the predictions.
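Because the prediction batch may be either a single ImageRegionData or an (input, target) tuple, it helps to normalize it first. A minimal sketch, assuming ImageRegionData is not itself a tuple subclass; split_batch is a hypothetical helper, not part of CAREamics:

```python
from typing import Any, Optional, Tuple


def split_batch(batch: Any) -> Tuple[Any, Optional[Any]]:
    """Return (input, target); target is None when only input was supplied.

    Hypothetical helper. Assumes the data objects themselves are not tuples,
    so a tuple can only mean an (input, target) pair.
    """
    if isinstance(batch, tuple):
        if len(batch) != 2:
            raise ValueError(f"Expected an (input, target) tuple, got length {len(batch)}.")
        return batch[0], batch[1]
    # A non-tuple batch carries only the input data.
    return batch, None
```

Inside a prediction step, the target (when present) could then be kept for metrics while only the input is passed through the model.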
training_step(batch, batch_idx)
Training step for CARE module.
Parameters:
- batch ((ImageRegionData, ImageRegionData)) – A tuple containing the input data and the target data.
- batch_idx (int) – The index of the current batch in the training loop.
Returns:
- Tensor – The loss value computed for the current batch.
validation_step(batch, batch_idx)
Validation step for CARE module.
Parameters:
- batch ((ImageRegionData, ImageRegionData)) – A tuple containing the input data and the target data.
- batch_idx (int) – The index of the current batch in the validation loop.
CareamicsDataModule
Bases: LightningDataModule
Data module for CAREamics datasets.
Parameters:
- data_config (DataConfig or dict) – Pydantic model for CAREamics data configuration, or an equivalent dictionary.
- train_data (Any, default: None) – Training data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- train_data_target (Any, default: None) – Training data target. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- train_data_mask (Any, default: None) – Training data mask, an optional mask used to filter regions of the data during training, such as large areas of background. The mask should be a binary image in which 1 indicates that a pixel should be included in the training data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- val_data (Any, default: None) – Validation data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- val_data_target (Any, default: None) – Validation data target. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- pred_data (Any, default: None) – Prediction data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- pred_data_target (Any, default: None) – Prediction data target, which may be used for calculating metrics. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- model_constraints (ModelConstraints | None, default: None) – If provided, the data module will validate that the prediction data shape is compatible with the model constraints.
- loading (ReadFuncLoading | ImageStackLoading | None, default: None) – The type of loading used for custom data. ReadFuncLoading uses a simple function that loads full images into memory. ImageStackLoading is for custom chunked or memory-mapped next-generation file formats, enabling single patches to be read from disk one at a time. If the data type is not custom, loading should be None.
Attributes:
- config (DataConfig) – Pydantic model for CAREamics data configuration.
- data_type (str) – Type of data, one of SupportedData.
- batch_size (int) – Batch size for the dataloaders.
Raises:
- ValueError – If none of train_data, val_data or pred_data is provided.
- ValueError – If input and target data types are not consistent.
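One simple way to produce the binary train_data_mask is an intensity threshold. The sketch below assumes that thresholding is adequate for separating foreground from background in the data at hand and that the mask has the same shape as the corresponding training image; make_foreground_mask is a hypothetical helper.

```python
import numpy as np


def make_foreground_mask(image: np.ndarray, threshold: float) -> np.ndarray:
    """Return a binary mask: 1 where a pixel should be used for training."""
    # Pixels brighter than the threshold are treated as foreground (kept);
    # the rest are treated as background (excluded).
    return (image > threshold).astype(np.uint8)


image = np.array([[0.0, 0.1, 5.0], [4.0, 0.2, 6.0]], dtype=np.float32)
mask = make_foreground_mask(image, threshold=1.0)
```

A mask built this way would then be passed as train_data_mask alongside train_data.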
__init__(data_config, *, train_data=None, train_data_target=None, train_data_mask=None, val_data=None, val_data_target=None, pred_data=None, pred_data_target=None, model_constraints=None, loading=None)
__init__(data_config: DataConfig | dict[str, Any], *, train_data: InputVar | None = None, train_data_target: InputVar | None = None, train_data_mask: InputVar | None = None, val_data: InputVar | None = None, val_data_target: InputVar | None = None, pred_data: InputVar | None = None, pred_data_target: InputVar | None = None, model_constraints: ModelConstraints | None = None, loading: ReadFuncLoading | None = None) -> None
__init__(data_config: DataConfig | dict[str, Any], *, train_data: Any | None = None, train_data_target: Any | None = None, train_data_mask: Any | None = None, val_data: Any | None = None, val_data_target: Any | None = None, pred_data: Any | None = None, pred_data_target: Any | None = None, model_constraints: ModelConstraints | None = None, loading: ImageStackLoading = ...) -> None
Data module for Careamics dataset initialization.
Create a lightning datamodule that handles creating datasets for training, validation, and prediction.
Parameters:
- data_config (DataConfig or dict) – Pydantic model for CAREamics data configuration, or an equivalent dictionary.
- train_data (Any, default: None) – Training data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- train_data_target (Any, default: None) – Training data target. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- train_data_mask (Any, default: None) – Training data mask, an optional mask used to filter regions of the data during training, such as large areas of background. The mask should be a binary image in which 1 indicates that a pixel should be included in the training data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- val_data (Any, default: None) – Validation data. If not provided, data_config.n_val_patches patches will be selected from the training data for validation. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- val_data_target (Any, default: None) – Validation data target. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- pred_data (Any, default: None) – Prediction data. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- pred_data_target (Any, default: None) – Prediction data target, which may be used for calculating metrics. If custom loading is provided, it can be any type; otherwise, it must be a pathlib.Path, str or numpy.ndarray, a sequence of these, or None.
- model_constraints (ModelConstraints, default: None) – If provided, the data module will validate input and target channels and spatial shapes against the model constraints.
- loading (ReadFuncLoading | ImageStackLoading | None, default: None) – The type of loading used for custom data. ReadFuncLoading uses a simple function that loads full images into memory. ImageStackLoading is for custom chunked or memory-mapped next-generation file formats, enabling single patches to be read from disk one at a time. If the data type is not custom, loading should be None.
predict_dataloader()
Create a dataloader for prediction.
Returns:
- DataLoader – Prediction dataloader.
setup(stage)
Set up datasets.
Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict. Creates the datasets for the given stage.
Parameters:
- stage (str) – The stage to set up datasets for. One of 'fit', 'validate', 'test', or 'predict'.
Raises:
- NotImplementedError – If stage is not one of "fit", "validate" or "predict".
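A minimal sketch of the stage handling described for setup, using a hypothetical class that only records which datasets it would create; the real CareamicsDataModule builds actual datasets, and its internal structure is not shown here.

```python
from typing import Callable, Dict, List


class StageSetupSketch:
    """Hypothetical stand-in that records which datasets setup() would build."""

    def __init__(self) -> None:
        self.created: List[str] = []

    def setup(self, stage: str) -> None:
        handlers: Dict[str, Callable[[], None]] = {
            # 'fit' covers both training and validation.
            "fit": lambda: self.created.extend(["train", "val"]),
            "validate": lambda: self.created.append("val"),
            "predict": lambda: self.created.append("predict"),
        }
        if stage not in handlers:
            # Any other stage, including 'test', is rejected.
            raise NotImplementedError(f"Stage '{stage}' is not supported.")
        handlers[stage]()
```

A dispatch table keeps the supported stages in one place, so the error condition listed under Raises follows directly from the table's keys.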
train_dataloader()
Create a dataloader for training.
Returns:
- DataLoader – Training dataloader.
val_dataloader()
Create a dataloader for validation.
Returns:
- DataLoader – Validation dataloader.
ConfigSaverCallback
Bases: Callback
Callback to save CAREamics configuration in Lightning checkpoints.
This callback automatically stores CAREamics version, experiment name, and training configuration in the checkpoint file for reproducibility.
Parameters:
- careamics_version (str) – Version of CAREamics used for training.
- experiment_name (str) – Name of the experiment.
- training_config (TrainingConfig) – Training configuration to store in the checkpoint.
Attributes:
- careamics_version (str) – Version of CAREamics used for training.
- experiment_name (str) – Name of the experiment.
- training_config (TrainingConfig) – Training configuration to store in the checkpoint.
__init__(careamics_version, experiment_name, training_config)
Initialize the callback.
Parameters:
- careamics_version (str) – Version of CAREamics used for training.
- experiment_name (str) – Name of the experiment.
- training_config (TrainingConfig) – Training configuration to store in the checkpoint.
on_save_checkpoint(trainer, pl_module, checkpoint)
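on_save_checkpoint is a standard Lightning Callback hook that receives the checkpoint dictionary before it is written to disk, so entries added to it end up in the checkpoint file. The sketch below mirrors the documented role of ConfigSaverCallback with a plain class; the key names and the use of a dict instead of a TrainingConfig are assumptions made for illustration.

```python
from typing import Any, Dict


class ConfigSaverSketch:
    """Hypothetical stand-in mirroring ConfigSaverCallback's documented role."""

    def __init__(
        self,
        careamics_version: str,
        experiment_name: str,
        training_config: Dict[str, Any],
    ) -> None:
        self.careamics_version = careamics_version
        self.experiment_name = experiment_name
        self.training_config = training_config

    def on_save_checkpoint(
        self, trainer: Any, pl_module: Any, checkpoint: Dict[str, Any]
    ) -> None:
        # Mutating the dict in place is how extra metadata reaches the file.
        # The key names below are illustrative assumptions.
        checkpoint["careamics_version"] = self.careamics_version
        checkpoint["experiment_name"] = self.experiment_name
        checkpoint["training_config"] = dict(self.training_config)
```

Storing a copy of the configuration keeps the saved entry independent of later changes to the callback's own attribute.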
DataStatsCallback
Bases: Callback
Callback to update the model's data statistics from the datamodule.
This callback ensures that the model has access to the data statistics (mean, std) calculated by the datamodule before training starts.
setup(trainer, module, stage)
Called when the trainer is setting up.
Parameters:
- trainer (Trainer) – PyTorch Lightning trainer.
- module (LightningModule) – Lightning module.
- stage (str) – Current stage (fit, validate, test, or predict).
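The statistics are described only as mean and std. As a rough illustration, per-channel statistics could be computed as below, assuming data laid out as (S, C, Y, X) and statistics taken per channel; this is not necessarily how the CAREamics datamodule computes them, and per_channel_stats is a hypothetical helper.

```python
from typing import List, Tuple

import numpy as np


def per_channel_stats(data: np.ndarray) -> Tuple[List[float], List[float]]:
    """Return per-channel (means, stds) for data shaped (S, C, Y, X)."""
    # Reduce over every axis except the channel axis (axis 1).
    axes = tuple(i for i in range(data.ndim) if i != 1)
    return data.mean(axis=axes).tolist(), data.std(axis=axes).tolist()
```

For large datasets, a streaming computation over patches would avoid holding all data in memory at once.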
ImageStackLoading
dataclass
Loading specification for a custom image stack loader (chunked / memory-mapped).
image_stack_loader
instance-attribute
A function that loads image data to a sequence of ImageStack objects.
image_stack_loader_kwargs = None
class-attribute
instance-attribute
Additional keyword arguments to pass to the image_stack_loader alongside the source of the image data.
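ImageStackLoading targets formats where single patches are read from disk. As an illustration of that idea only (not of the ImageStack interface, which is not documented here), numpy can memory-map a .npy file so that slicing reads just the requested region; read_patch is a hypothetical helper.

```python
import tempfile
from pathlib import Path

import numpy as np


def read_patch(path: Path, y: int, x: int, size: int) -> np.ndarray:
    """Read one (size, size) patch from a 2D .npy file without loading it all."""
    # mmap_mode="r" maps the file read-only; only the sliced region is read.
    image = np.load(path, mmap_mode="r")
    # np.array(...) copies the slice into an ordinary in-memory array.
    return np.array(image[y:y + size, x:x + size])


# Example: write a small array to a temporary .npy file and read one patch.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "image.npy"
    full = np.arange(100, dtype=np.float32).reshape(10, 10)
    np.save(path, full)
    patch = read_patch(path, y=2, x=3, size=4)
```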
MicroSplitDataModule
Bases: LightningDataModule
Lightning DataModule for MicroSplit-style datasets.
Matches the interface of TrainDataModule, but internally uses the original MicroSplit dataset logic.
Parameters:
- data_config (MicroSplitDataConfig) – Configuration for the MicroSplit dataset.
- train_data (str) – Path to the training data directory.
- val_data (str, default: None) – Path to the validation data directory.
- train_data_target (str, default: None) – Path to the training target data.
- val_data_target (str, default: None) – Path to the validation target data.
- read_source_func (Callable, default: None) – Function to read source data.
- extension_filter (str, default: '') – File extension filter.
- val_percentage (float, default: 0.1) – Fraction of the data to use for validation (0.1 corresponds to 10%).
- val_minimum_split (int, default: 5) – Minimum number of samples in the validation split.
- use_in_memory (bool, default: True) – Whether to use an in-memory dataset.
__init__(data_config, train_data, val_data=None, train_data_target=None, val_data_target=None, read_source_func=None, extension_filter='', val_percentage=0.1, val_minimum_split=5, use_in_memory=True)
Initialize MicroSplitDataModule.
Parameters:
- data_config (MicroSplitDataConfig) – Configuration for the MicroSplit dataset.
- train_data (str) – Path to the training data directory.
- val_data (str, default: None) – Path to the validation data directory.
- train_data_target (str, default: None) – Path to the training target data.
- val_data_target (str, default: None) – Path to the validation target data.
- read_source_func (Callable, default: None) – Function to read source data.
- extension_filter (str, default: '') – File extension filter.
- val_percentage (float, default: 0.1) – Fraction of the data to use for validation (0.1 corresponds to 10%).
- val_minimum_split (int, default: 5) – Minimum number of samples in the validation split.
- use_in_memory (bool, default: True) – Whether to use an in-memory dataset.
get_data_stats()
train_dataloader()
Create a dataloader for training.
Returns:
- DataLoader – Training dataloader.
val_dataloader()
Create a dataloader for validation.
Returns:
- DataLoader – Validation dataloader.
N2VModule
Bases: LightningModule
CAREamics PyTorch Lightning module for N2V algorithm.
Parameters:
- algorithm_config (N2VAlgorithm or dict) – Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a dictionary.
__init__(algorithm_config)
Instantiate N2VModule.
Parameters:
- algorithm_config (N2VAlgorithm or dict) – Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a dictionary.
configure_optimizers()
forward(x)
Forward pass.
Parameters:
- x (Tensor) – Input tensor.
Returns:
- Tensor – Model output tensor.
on_fit_start()
On fit start hook for N2V module.
predict_step(batch, batch_idx)
Prediction step for N2V model.
Parameters:
- batch (ImageRegionData or (ImageRegionData, ImageRegionData)) – A tuple containing the input data and optionally the target data.
- batch_idx (int) – The index of the current batch in the prediction loop.
Returns:
- ImageRegionData – The output batch containing the predictions.
training_step(batch, batch_idx)
Training step for N2V model.
Parameters:
- batch (ImageRegionData or (ImageRegionData, ImageRegionData)) – Either the input data alone, or a tuple containing the input data and the target data.
- batch_idx (int) – The index of the current batch in the training loop.
Returns:
- Tensor – The loss value for the current training step.
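N2V trains without clean targets by hiding a subset of pixels and scoring the prediction only at those positions. A minimal sketch of that blind-spot idea, assuming a 2D image and replacement of each masked pixel by a random in-bounds neighbour from its 3x3 window; CAREamics' own pixel manipulation strategies and loss may differ in detail, and n2v_mask and masked_mse are hypothetical helpers.

```python
import numpy as np


def n2v_mask(image: np.ndarray, n_masked: int, rng: np.random.Generator):
    """Return (manipulated image, boolean mask) for a 2D image of at least 2x2."""
    h, w = image.shape
    manipulated = image.copy()
    mask = np.zeros(image.shape, dtype=bool)
    # Choose distinct pixel positions to hide.
    flat = rng.choice(h * w, size=n_masked, replace=False)
    for y, x in zip(*np.unravel_index(flat, (h, w))):
        # Candidate neighbours in the 3x3 window, excluding the pixel itself
        # and any position outside the image.
        neighbours = [
            (y + dy, x + dx)
            for dy in (-1, 0, 1)
            for dx in (-1, 0, 1)
            if (dy, dx) != (0, 0) and 0 <= y + dy < h and 0 <= x + dx < w
        ]
        ny, nx = neighbours[rng.integers(len(neighbours))]
        manipulated[y, x] = image[ny, nx]
        mask[y, x] = True
    return manipulated, mask


def masked_mse(prediction: np.ndarray, target: np.ndarray, mask: np.ndarray) -> float:
    """Mean squared error computed only at the masked (blind-spot) pixels."""
    return float(np.mean((prediction[mask] - target[mask]) ** 2))
```

The network would see the manipulated image, while the loss compares its prediction with the original image at the masked pixels, so it cannot simply copy the input value at those positions.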
validation_step(batch, batch_idx)
Validation step for N2V model.
Parameters:
- batch (ImageRegionData or (ImageRegionData, ImageRegionData)) – Either the input data alone, or a tuple containing the input data and the target data.
- batch_idx (int) – The index of the current batch in the validation loop.
PredictionStoppedException
Bases: Exception
Exception raised when prediction is stopped by external signal.
ProgressBarCallback
Bases: TQDMProgressBar
Progress bar for training and validation steps.
get_metrics(trainer, pl_module)
Override this to customize the metrics displayed in the progress bar.
Parameters:
- trainer (Trainer) – The trainer object.
- pl_module (LightningModule) – The LightningModule object, unused.
Returns:
- dict – A dictionary with the metrics to display in the progress bar.
init_test_tqdm()
Override this to customize the tqdm bar for testing.
Returns:
- tqdm – A tqdm bar.
init_train_tqdm()
Override this to customize the tqdm bar for training.
Returns:
- tqdm – A tqdm bar.
init_validation_tqdm()
Override this to customize the tqdm bar for validation.
Returns:
- tqdm – A tqdm bar.
ReadFuncLoading
dataclass
Loading specification using a custom read function.
extension_filter = ''
class-attribute
instance-attribute
A filter for finding source files using glob-style pattern matching. For example, to select files with the extension .npy, use the filter "*.npy".
read_kwargs = None
class-attribute
instance-attribute
Additional keyword arguments to pass to the read_source_func alongside the file path to the image data.
read_source_func
instance-attribute
A function for reading image data to numpy arrays.
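A minimal sketch of a read function and a glob-style extension filter working together, assuming the read function is called with the file path plus read_kwargs as keyword arguments; read_npy is a hypothetical example, and the exact call signature CAREamics uses should be checked against its source.

```python
import tempfile
from pathlib import Path
from typing import Any

import numpy as np


def read_npy(file_path: Path, **kwargs: Any) -> np.ndarray:
    """Hypothetical read_source_func: load one .npy file as a numpy array.

    Keyword arguments (playing the role of read_kwargs) go to numpy.load.
    """
    return np.load(file_path, **kwargs)


with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    np.save(folder / "a.npy", np.ones((4, 4)))
    np.save(folder / "b.npy", np.zeros((4, 4)))
    (folder / "notes.txt").write_text("not an image")
    # Glob-style filter, as with extension_filter="*.npy".
    files = sorted(folder.glob("*.npy"))
    names = [f.name for f in files]
    images = [read_npy(f) for f in files]
```

With the documented fields, such a function could be wrapped as ReadFuncLoading(read_source_func=read_npy, extension_filter="*.npy").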
StopPredictionCallback
Bases: Callback
PyTorch Lightning callback to stop prediction based on external condition.
This callback monitors a user-provided stop condition at the start of each prediction batch. When the condition is met, the callback stops the trainer and raises PredictionStoppedException to interrupt the prediction loop.
Parameters:
- stop_condition (Callable[[], bool]) – A callable that returns True when prediction should stop. The callable is invoked at the start of each prediction batch.
__init__(stop_condition)
on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Check stop condition at the start of each prediction batch.
Parameters:
- trainer (Trainer) – PyTorch Lightning trainer instance.
- pl_module (LightningModule) – Lightning module being used for prediction.
- batch (Any) – Current batch of data.
- batch_idx (int) – Index of the current batch.
- dataloader_idx (int, default: 0) – Index of the dataloader.
Raises:
- PredictionStoppedException – If stop_condition() returns True.
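Any zero-argument callable returning a bool can serve as stop_condition; threading.Event.is_set is a convenient choice when cancellation comes from another thread, such as a GUI. The simulated loop below is a sketch of the documented check-then-raise behaviour, not the callback itself; run_batches and the local exception class are stand-ins.

```python
import threading
from typing import List


class PredictionStoppedException(Exception):
    """Local stand-in for the exception documented above."""


def run_batches(n_batches: int, cancel_after: int) -> List[int]:
    """Simulate a prediction loop that checks a stop condition before each batch."""
    stop_event = threading.Event()
    # Event.is_set matches the documented type Callable[[], bool].
    stop_condition = stop_event.is_set
    processed: List[int] = []
    try:
        for batch_idx in range(n_batches):
            if stop_condition():
                raise PredictionStoppedException(f"Stopped before batch {batch_idx}.")
            processed.append(batch_idx)
            if batch_idx + 1 == cancel_after:
                # Stands in for, e.g., a "Cancel" button setting the event.
                stop_event.set()
    except PredictionStoppedException:
        pass
    return processed
```

With the real callback, the same event could be wired in as StopPredictionCallback(stop_condition=stop_event.is_set), passed in the Trainer's callbacks list, with PredictionStoppedException caught around trainer.predict(...).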
VAEModule
Bases: LightningModule
CAREamics Lightning module.
This class encapsulates a PyTorch model along with the training, validation, and testing logic. It is configured using an AlgorithmModel Pydantic class.
Attributes:
- model (Module) – PyTorch model.
- loss_func (Module) – Loss function.
- optimizer_name (str) – Optimizer name.
- optimizer_params (dict) – Optimizer parameters.
- lr_scheduler_name (str) – Learning rate scheduler name.
__init__(algorithm_config)
compute_val_psnr(model_output, target, psnr_func=scale_invariant_psnr)
Compute the PSNR for the current validation batch.
Parameters:
- model_output (tuple[Tensor, dict[str, Any]]) – Model output, a tuple with the predicted mean and (optionally) logvar, and the top-down data dictionary.
- target (Tensor) – Target tensor.
- psnr_func (Callable, default: scale_invariant_psnr) – PSNR function to use.
Returns:
configure_optimizers()
Configure optimizers and learning rate schedulers.
Returns:
- Any – Optimizer and learning rate scheduler.
forward(x)
get_reconstructed_tensor(model_outputs)
Get the reconstructed tensor from the LVAE model outputs.
Parameters:
- model_outputs (tuple[Tensor, dict[str, Any]]) – Model outputs. It is a tuple with a tensor representing the predicted mean and (optionally) logvar, and the top-down data dictionary.
Returns:
- Tensor – Reconstructed tensor, i.e., the predicted mean.
on_validation_epoch_end()
Validation epoch end.
predict_step(batch, batch_idx)
reduce_running_psnr()
Reduce the running PSNR statistics and reset the running PSNR.
Returns:
- Optional[float] – Running PSNR averaged over the different output channels.
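compute_val_psnr and reduce_running_psnr together suggest a per-batch, per-channel PSNR that is accumulated and later averaged. A minimal sketch under stated assumptions: PSNR uses the target's value range, scale invariance is obtained here by a least-squares affine fit (swapped in as one common approach; CAREamics' scale_invariant_psnr may normalize differently), and averaging is done within each channel and then across channels. All names are hypothetical.

```python
from typing import List, Optional

import numpy as np


def psnr(prediction: np.ndarray, target: np.ndarray) -> float:
    """PSNR in dB, using the target's value range as the data range."""
    data_range = float(target.max() - target.min())
    mse = float(np.mean((prediction - target) ** 2))
    if mse == 0.0:
        return float("inf")
    return float(10.0 * np.log10(data_range**2 / mse))


def affine_invariant_psnr(prediction: np.ndarray, target: np.ndarray) -> float:
    """PSNR after a least-squares fit a * prediction + b to the target."""
    p = prediction.ravel().astype(np.float64)
    t = target.ravel().astype(np.float64)
    # Solve min over (a, b) of ||a * p + b - t||^2.
    design = np.stack([p, np.ones_like(p)], axis=1)
    (a, b), *_ = np.linalg.lstsq(design, t, rcond=None)
    return psnr(a * prediction + b, target)


class RunningPSNR:
    """Accumulate per-channel PSNR values across batches."""

    def __init__(self, n_channels: int) -> None:
        self.n_channels = n_channels
        self.values: List[List[float]] = [[] for _ in range(n_channels)]

    def update(self, channel: int, value: float) -> None:
        self.values[channel].append(value)

    def reduce(self) -> Optional[float]:
        """Average within each channel, then across channels, then reset."""
        per_channel = [sum(v) / len(v) for v in self.values if v]
        self.values = [[] for _ in range(self.n_channels)]
        if not per_channel:
            return None
        return sum(per_channel) / len(per_channel)
```

Returning None when nothing was accumulated matches the Optional[float] return type documented for reduce_running_psnr.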
set_data_stats(data_mean, data_std)
training_step(batch, batch_idx)
Training step.
Parameters:
- batch (tuple[Tensor, Tensor]) – Input batch. It is a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
- batch_idx (Any) – Batch index.
Returns:
- Any – Loss value.
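The batch layout above can be checked before a step runs. A minimal sketch operating on shape tuples, assuming that input and target share the same spatial dimensions (the docs state only the channel layout); check_lvae_batch is a hypothetical helper.

```python
from typing import Tuple


def check_lvae_batch(
    input_shape: Tuple[int, ...], target_shape: Tuple[int, ...], n_lc: int
) -> None:
    """Validate (B, 1 + n_LC, [Z], Y, X) inputs against (B, C, [Z], Y, X) targets."""
    if len(input_shape) not in (4, 5):
        raise ValueError(f"Input must be 4D (2D data) or 5D (3D data), got {len(input_shape)}D.")
    if input_shape[1] != 1 + n_lc:
        raise ValueError(f"Expected {1 + n_lc} input channels, got {input_shape[1]}.")
    if len(target_shape) != len(input_shape) or target_shape[0] != input_shape[0]:
        raise ValueError("Target must have the same rank and batch size as the input.")
    # Assumption: input and target share the same spatial dimensions.
    if target_shape[2:] != input_shape[2:]:
        raise ValueError("Target spatial shape must match the input spatial shape.")
```

For a 2D muSplit-style batch with two lateral inputs and two target channels, shapes such as (4, 3, 64, 64) and (4, 2, 64, 64) would pass.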
validation_step(batch, batch_idx)
Validation step.
Parameters:
- batch (tuple[Tensor, Tensor]) – Input batch. It is a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
- batch_idx (Any) – Batch index.
create_microsplit_predict_datamodule(pred_data, tile_size, batch_size=1, num_channels=2, depth3D=1, grid_size=None, multiscale_count=None, data_stats=None, tiling_mode=TilingMode.ShiftBoundary, read_source_func=None, extension_filter='', dataloader_params=None, **dataset_kwargs)
Create a MicroSplitPredictDataModule for microSplit-style prediction datasets.
Parameters:
- pred_data (str or Path or ndarray) – Prediction data; can be a path to a folder, a file, or a numpy array.
- tile_size (tuple) – Size of one tile of data.
- batch_size (int, default: 1) – Batch size for the prediction dataloader.
- num_channels (int, default: 2) – Number of channels in the input.
- depth3D (int, default: 1) – Number of slices in 3D.
- grid_size (tuple, default: None) – Grid size for patch extraction.
- multiscale_count (int, default: None) – Number of LC scales.
- data_stats (tuple, default: None) – Data statistics.
- tiling_mode (TilingMode, default: ShiftBoundary) – Tiling mode for patch extraction.
- read_source_func (Callable, default: None) – Function to read the source data.
- extension_filter (str, default: '') – File extension filter.
- dataloader_params (dict, default: None) – Parameters for the prediction dataloader.
- **dataset_kwargs – Additional arguments passed to MicroSplitDataConfig.
Returns:
- MicroSplitPredictDataModule – Configured MicroSplitPredictDataModule instance.
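ShiftBoundary is not explained here. Assuming it means that a tile which would overhang the image edge is shifted inward, so every tile lies fully inside the image, 1D tile positions could be computed as below; step stands in for the stride between tiles, a role grid_size may play (also an assumption). tile_starts is a hypothetical helper.

```python
from typing import List


def tile_starts(length: int, tile: int, step: int) -> List[int]:
    """1D tile start positions; the last tile is shifted to end at the boundary."""
    if tile > length:
        raise ValueError("Tile is larger than the image.")
    starts = list(range(0, length - tile + 1, step))
    # If the regular grid leaves pixels uncovered, add a tile flush with the edge.
    if starts[-1] + tile < length:
        starts.append(length - tile)
    return starts
```

For a 10-pixel axis with 4-pixel tiles and step 4, this yields starts 0, 4 and 6: the last tile overlaps its neighbour instead of being padded or trimmed.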
create_microsplit_train_datamodule(train_data, patch_size, batch_size, val_data=None, num_channels=2, depth3D=1, grid_size=None, multiscale_count=None, tiling_mode=TilingMode.ShiftBoundary, extension_filter='', val_percentage=0.1, val_minimum_split=5, use_in_memory=True, transforms=None, train_dataloader_params=None, val_dataloader_params=None, **dataset_kwargs)
Create a MicroSplitDataModule for MicroSplit-style datasets.
Parameters:
- train_data (str) – Path to the training data.
- patch_size (tuple) – Size of one patch of data.
- batch_size (int) – Batch size for the dataloaders.
- val_data (str, default: None) – Path to the validation data.
- num_channels (int, default: 2) – Number of channels in the input.
- depth3D (int, default: 1) – Number of slices in 3D.
- grid_size (tuple, default: None) – Grid size for patch extraction.
- multiscale_count (int, default: None) – Number of LC scales.
- tiling_mode (TilingMode, default: ShiftBoundary) – Tiling mode for patch extraction.
- extension_filter (str, default: '') – File extension filter.
- val_percentage (float, default: 0.1) – Fraction of the training data to use for validation (0.1 corresponds to 10%).
- val_minimum_split (int, default: 5) – Minimum number of patches/files in the validation split.
- use_in_memory (bool, default: True) – Whether to use an in-memory dataset if possible.
- transforms (list, default: None) – List of transforms to apply.
- train_dataloader_params (dict, default: None) – Parameters for the training dataloader.
- val_dataloader_params (dict, default: None) – Parameters for the validation dataloader.
- **dataset_kwargs – Additional arguments passed to DatasetConfig.
Returns:
- MicroSplitDataModule – Configured MicroSplitDataModule instance.
load_config_from_checkpoint(checkpoint_path)
Load a CAREamics config from a checkpoint.
Some fields are populated with defaults if they are missing from the checkpoint, namely version, training_config and experiment_name.
The default for experiment_name is "loaded_from_<checkpoint_filename>".
Parameters:
- checkpoint_path (Path) – Path to the PyTorch Lightning checkpoint file.
Returns:
- Configuration – A CAREamics configuration object.
Raises:
- ValueError – If required information is not found in the checkpoint.
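A minimal sketch of the default-filling behaviour, operating on the dictionary a Lightning checkpoint contains once loaded. Lightning stores saved hyperparameters under "hyper_parameters"; that CAREamics keeps the algorithm configuration there, the other key names, and the use of the file stem for <checkpoint_filename> are all assumptions. config_fields_from_checkpoint is a hypothetical helper.

```python
from pathlib import Path
from typing import Any, Dict


def config_fields_from_checkpoint(
    checkpoint: Dict[str, Any],
    checkpoint_path: Path,
    default_version: str,
    default_training_config: Dict[str, Any],
) -> Dict[str, Any]:
    """Collect configuration fields from a loaded checkpoint dict, filling defaults."""
    hparams = checkpoint.get("hyper_parameters", {})
    if "algorithm_config" not in hparams:
        # Without the algorithm configuration there is nothing to rebuild.
        raise ValueError("No algorithm configuration found in the checkpoint.")
    return {
        "algorithm_config": hparams["algorithm_config"],
        "version": checkpoint.get("careamics_version", default_version),
        # Assumption: <checkpoint_filename> is the file name without extension.
        "experiment_name": checkpoint.get(
            "experiment_name", f"loaded_from_{checkpoint_path.stem}"
        ),
        "training_config": checkpoint.get("training_config", default_training_config),
    }
```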
load_module_from_checkpoint(checkpoint_path)
Load a trained CAREamics module from checkpoint.
Automatically detects the algorithm type from the checkpoint and loads the appropriate module with trained weights.
Parameters:
- checkpoint_path (Path) – Path to the PyTorch Lightning checkpoint file.
Returns:
- CAREamicsModule – Lightning module with loaded weights.
Raises:
- ValueError – If the algorithm type cannot be determined from the checkpoint.
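Detecting the algorithm type amounts to mapping a stored algorithm name to a module class. A minimal sketch with stand-in classes: the algorithm names "care", "n2n" and "n2v", the hparams layout, and the stub classes are assumptions, and weight loading is omitted. The mapping reflects that CAREModule accepts both CARE and N2N configurations.

```python
from typing import Any, Dict, Type


class _ModuleStub:
    """Stand-in for a Lightning module built from an algorithm config."""

    def __init__(self, algorithm_config: Dict[str, Any]) -> None:
        self.algorithm_config = algorithm_config


class CAREModuleStub(_ModuleStub):
    pass


class N2VModuleStub(_ModuleStub):
    pass


# CARE and N2N share one module; N2V has its own.
MODULE_FOR_ALGORITHM: Dict[str, Type[_ModuleStub]] = {
    "care": CAREModuleStub,
    "n2n": CAREModuleStub,
    "n2v": N2VModuleStub,
}


def module_from_hparams(hparams: Dict[str, Any]) -> _ModuleStub:
    """Pick the module class from the stored algorithm name (weights not loaded)."""
    algorithm_config = hparams.get("algorithm_config", {})
    algorithm = algorithm_config.get("algorithm")
    if algorithm not in MODULE_FOR_ALGORITHM:
        raise ValueError(f"Cannot determine the algorithm type from {algorithm!r}.")
    return MODULE_FOR_ALGORITHM[algorithm](algorithm_config)
```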