Callbacks

Callbacks module.

CareamicsCheckpointInfo

Bases: Callback

Callback to save CAREamics configuration in Lightning checkpoints.

This callback automatically stores CAREamics version, experiment name, and training configuration in the checkpoint file for reproducibility.

Parameters:

Name Type Description Default
careamics_version str

Version of CAREamics used for training.

required
experiment_name str

Name of the experiment.

required
training_config TrainingConfig

Training configuration to store in checkpoint.

required

Attributes:

Name Type Description
careamics_version str

Version of CAREamics used for training.

experiment_name str

Name of the experiment.

training_config TrainingConfig

Training configuration to store in checkpoint.

on_save_checkpoint(trainer, pl_module, checkpoint)

Lightning hook called when saving a checkpoint.

Adds CAREamics configuration to the checkpoint dictionary.

Parameters:

Name Type Description Default
trainer Trainer

Lightning trainer instance.

required
pl_module LightningModule

Lightning module being trained.

required
checkpoint dict[str, Any]

Checkpoint dictionary to modify.

required
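
As a rough illustration of what this hook does, the sketch below uses a stand-in class rather than a real Lightning callback, so it runs without Lightning installed. The class name `CheckpointInfoSketch`, the key names written into the dictionary, and the plain-dict training configuration are assumptions made for illustration; the keys CAREamics actually uses may differ.

```python
from typing import Any


class CheckpointInfoSketch:
    """Stand-in for a checkpoint-info callback (hypothetical, not the CAREamics class)."""

    def __init__(
        self,
        careamics_version: str,
        experiment_name: str,
        training_config: dict[str, Any],
    ) -> None:
        self.careamics_version = careamics_version
        self.experiment_name = experiment_name
        self.training_config = training_config

    def on_save_checkpoint(
        self, trainer: Any, pl_module: Any, checkpoint: dict[str, Any]
    ) -> None:
        # Mutate the checkpoint dictionary in place; the key name is an assumption.
        checkpoint["careamics_info"] = {
            "version": self.careamics_version,
            "experiment_name": self.experiment_name,
            "training_config": self.training_config,
        }


callback = CheckpointInfoSketch("0.0.0", "demo", {"num_epochs": 10})
checkpoint: dict[str, Any] = {"state_dict": {}}
callback.on_save_checkpoint(trainer=None, pl_module=None, checkpoint=checkpoint)
```

The existing checkpoint entries are left untouched; the hook only adds an entry.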

DataStatsCallback

Bases: Callback

Callback to update model's data statistics from datamodule.

This callback ensures that the model has access to the data statistics (mean, std) calculated by the datamodule before training starts.

setup(trainer, module, stage)

Called when trainer is setting up.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer.

required
module LightningModule

Lightning module.

required
stage str

Current stage (fit, validate, test, or predict).

required

HyperParametersCallback

Bases: Callback

Callback that saves the CAREamics configuration as hyperparameters in the model.

This allows the configuration to be saved as a dictionary in the checkpoints and subsequently loaded in a CAREamist instance.

Parameters:

Name Type Description Default
config Configuration

CAREamics configuration to be saved as hyperparameter in the model.

required

Attributes:

Name Type Description
config Configuration

CAREamics configuration to be saved as hyperparameter in the model.

on_train_start(trainer, pl_module)

Update the hyperparameters of the model with the configuration on train start.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer, unused.

required
pl_module LightningModule

PyTorch Lightning module.

required
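
A minimal sketch of the idea, using stand-in objects so that it runs without Lightning or CAREamics installed. `ModuleStandIn`, its `hparams` dictionary, `HyperParametersSketch`, and the plain-dict configuration are hypothetical; how CAREamics serializes its `Configuration` is not shown here.

```python
from typing import Any


class ModuleStandIn:
    """Hypothetical stand-in for a LightningModule exposing `hparams`."""

    def __init__(self) -> None:
        self.hparams: dict[str, Any] = {}


class HyperParametersSketch:
    """Stand-in callback that copies a configuration into the module's hparams."""

    def __init__(self, config: dict[str, Any]) -> None:
        self.config = config

    def on_train_start(self, trainer: Any, pl_module: ModuleStandIn) -> None:
        # The trainer is unused; only the module's hyperparameters are updated.
        pl_module.hparams.update(self.config)


module = ModuleStandIn()
sketch = HyperParametersSketch({"experiment_name": "demo", "batch_size": 2})
sketch.on_train_start(trainer=None, pl_module=module)
```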

PredictionStoppedException

Bases: Exception

Exception raised when prediction is stopped by external signal.

PredictionWriterCallback

Bases: BasePredictionWriter

A PyTorch Lightning callback to save predictions.

Parameters:

Name Type Description Default
write_strategy WriteStrategy

A strategy for writing predictions.

required
dirpath Path or str

The path to the directory where prediction outputs will be saved. If dirpath is not absolute it is assumed to be relative to current working directory.

"predictions"

Attributes:

Name Type Description
write_strategy WriteStrategy

A strategy for writing predictions.

dirpath pathlib.Path, default="predictions"

The path to the directory where prediction outputs will be saved. If dirpath is not absolute it is assumed to be relative to current working directory.

writing_predictions bool

Whether writing predictions is turned on.
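
The relative-path rule for dirpath can be illustrated with the standard library. This is a sketch of the stated behavior, not the callback's actual code, and `resolve_dirpath` is a hypothetical helper name.

```python
from pathlib import Path
from typing import Union


def resolve_dirpath(dirpath: Union[Path, str] = "predictions") -> Path:
    """Return an absolute path, treating relative paths as relative to the cwd."""
    path = Path(dirpath)
    if not path.is_absolute():
        # Relative paths are anchored at the current working directory.
        path = Path.cwd() / path
    return path


default_dir = resolve_dirpath()
absolute_dir = resolve_dirpath(Path.cwd() / "outputs")
```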

from_write_func_params(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None, dirpath='predictions') classmethod

Initialize a PredictionWriterCallback from write function parameters.

This will automatically create a WriteStrategy to be passed to the initialization of PredictionWriterCallback.

Parameters:

Name Type Description Default
write_type ('tiff', 'custom')

The format to save predictions as; 'custom' selects a user-provided write function.

required
tiled bool

Whether the prediction will be tiled or not.

required
write_func WriteFunc

If a known write_type is selected this argument is ignored. For a custom write_type a function to save the data must be passed. See the notes under create_write_strategy.

None
write_extension str

If a known write_type is selected this argument is ignored. For a custom write_type an extension to save the data with must be passed.

None
write_func_kwargs dict of {str: any}

Additional keyword arguments to be passed to the save function.

None
dirpath Path or str

The path to the directory where prediction outputs will be saved. If dirpath is not absolute it is assumed to be relative to current working directory.

"predictions"

Returns:

Type Description
PredictionWriterCallback

Callback for writing predictions.

setup(trainer, pl_module, stage)

Create the prediction output directory when predict begins.

Called when fit, validate, test, predict, or tune begins.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer.

required
pl_module LightningModule

PyTorch Lightning module.

required
stage str

Stage of training e.g. 'predict', 'fit', 'validate'.

required

write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)

Write predictions at the end of a batch.

How predictions are written is determined by the write_strategy attribute.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer.

required
pl_module LightningModule

PyTorch Lightning module.

required
prediction Any

Prediction outputs of batch.

required
batch_indices sequence of Any

Batch indices.

required
batch Any

Input batch.

required
batch_idx int

Batch index.

required
dataloader_idx int

Dataloader index.

required

ProgressBarCallback

Bases: TQDMProgressBar

Progress bar for training and validation steps.

get_metrics(trainer, pl_module)

Override this to customize the metrics displayed in the progress bar.

Parameters:

Name Type Description Default
trainer Trainer

The trainer object.

required
pl_module LightningModule

The LightningModule object, unused.

required

Returns:

Type Description
dict

A dictionary with the metrics to display in the progress bar.

init_test_tqdm()

Override this to customize the tqdm bar for testing.

Returns:

Type Description
tqdm

A tqdm bar.

init_train_tqdm()

Override this to customize the tqdm bar for training.

Returns:

Type Description
tqdm

A tqdm bar.

init_validation_tqdm()

Override this to customize the tqdm bar for validation.

Returns:

Type Description
tqdm

A tqdm bar.

StopPredictionCallback

Bases: Callback

PyTorch Lightning callback to stop prediction based on external condition.

This callback monitors a user-provided stop condition at the start of each prediction batch. When the condition is met, the callback stops the trainer and raises PredictionStoppedException to interrupt the prediction loop.

Parameters:

Name Type Description Default
stop_condition Callable[[], bool]

A callable that returns True when prediction should stop. The callable is invoked at the start of each prediction batch.

required

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Check stop condition at the start of each prediction batch.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer instance.

required
pl_module LightningModule

Lightning module being used for prediction.

required
batch Any

Current batch of data.

required
batch_idx int

Index of the current batch.

required
dataloader_idx int

Index of the dataloader, by default 0.

0

Raises:

Type Description
PredictionStoppedException

If stop_condition() returns True.
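
One way to supply a stop condition is a `threading.Event` that another thread (for example a GUI) sets. The sketch below imitates the check with stand-ins so that it runs without Lightning: `TrainerStandIn`, `StopPredictionSketch`, and the locally defined exception are hypothetical, and stopping the trainer through a `should_stop` flag is an assumption.

```python
import threading
from typing import Any, Callable


class PredictionStoppedSketch(Exception):
    """Local stand-in for PredictionStoppedException."""


class TrainerStandIn:
    """Hypothetical stand-in for a Lightning trainer."""

    should_stop = False


class StopPredictionSketch:
    """Stand-in callback that checks a stop condition before each batch."""

    def __init__(self, stop_condition: Callable[[], bool]) -> None:
        self.stop_condition = stop_condition

    def on_predict_batch_start(
        self,
        trainer: Any,
        pl_module: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # Stop the trainer, then raise to interrupt the prediction loop.
        if self.stop_condition():
            trainer.should_stop = True
            raise PredictionStoppedSketch


stop_event = threading.Event()
callback = StopPredictionSketch(stop_condition=stop_event.is_set)
trainer = TrainerStandIn()

processed = []
try:
    for batch_idx in range(5):
        if batch_idx == 2:
            stop_event.set()  # simulate an external stop request
        callback.on_predict_batch_start(trainer, None, None, batch_idx)
        processed.append(batch_idx)
except PredictionStoppedSketch:
    pass  # the caller decides how to handle an interrupted prediction
```

Because the condition is checked before each batch, batches processed before the stop request are unaffected and the batch at which the request is seen is skipped.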

create_write_strategy(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None)

Create a write strategy from convenient parameters.

Parameters:

Name Type Description Default
write_type ('tiff', 'custom')

The format to save predictions as; 'custom' selects a user-provided write function.

required
tiled bool

Whether the prediction will be tiled or not.

required
write_func WriteFunc

If a known write_type is selected this argument is ignored. For a custom write_type a function to save the data must be passed. See notes below.

None
write_extension str

If a known write_type is selected this argument is ignored. For a custom write_type an extension to save the data with must be passed.

None
write_func_kwargs dict of {str: any}

Additional keyword arguments to be passed to the save function.

None

Returns:

Type Description
WriteStrategy

A strategy for writing predictions.

Notes

The write_func function signature must match that of the example below:

def write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...

The write_func_kwargs are passed to write_func as follows:

write_func(file_path=file_path, img=img, **write_func_kwargs)
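
A sketch of a custom write function that follows this signature, using only the standard library. `write_rows_as_text` and its `delimiter` keyword are hypothetical, and the image is treated as any 2D sequence of rows (a NumPy array iterates the same way). The final call imitates how the keyword arguments are forwarded.

```python
import tempfile
from pathlib import Path
from typing import Any


def write_rows_as_text(
    file_path: Path, img: Any, *args: Any, delimiter: str = ",", **kwargs: Any
) -> None:
    """Write a 2D array-like to a text file, one row per line."""
    lines = [delimiter.join(str(value) for value in row) for row in img]
    Path(file_path).write_text("\n".join(lines) + "\n")


write_func_kwargs = {"delimiter": ";"}
img = [[1, 2], [3, 4]]

with tempfile.TemporaryDirectory() as tmp_dir:
    file_path = Path(tmp_dir) / "prediction.txt"
    # The keyword arguments are unpacked into the call.
    write_rows_as_text(file_path=file_path, img=img, **write_func_kwargs)
    written = file_path.read_text()
```

With a function like this, write_type would be 'custom', and a matching write_extension (here presumably ".txt") would be passed alongside it.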