Callbacks
Callbacks module.
ConfigSaverCallback
Bases: Callback
Callback to save CAREamics configuration in Lightning checkpoints.
This callback automatically stores the CAREamics version, experiment name, and training configuration in the checkpoint file for reproducibility.
Parameters:
- careamics_version (str) – Version of CAREamics used for training.
- experiment_name (str) – Name of the experiment.
- training_config (TrainingConfig) – Training configuration to store in the checkpoint.
Attributes:
- careamics_version (str) – Version of CAREamics used for training.
- experiment_name (str) – Name of the experiment.
- training_config (TrainingConfig) – Training configuration to store in the checkpoint.
__init__(careamics_version, experiment_name, training_config)
Initialize the callback.
Parameters:
- careamics_version (str) – Version of CAREamics used for training.
- experiment_name (str) – Name of the experiment.
- training_config (TrainingConfig) – Training configuration to store in the checkpoint.
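on_save_checkpoint(trainer, pl_module, checkpoint)
Add the CAREamics version, experiment name, and training configuration to the checkpoint dictionary when a checkpoint is saved.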
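As a rough illustration, the hook can be sketched without Lightning as a method that adds entries to the checkpoint dictionary Lightning passes in (Lightning hands the same dict to every `on_save_checkpoint` hook before writing it to disk). The key names (`careamics_info`, `version`, ...), the `TrainingConfigStub` dataclass, and the version string are hypothetical stand-ins; the real callback may use different keys or serialize the configuration differently.

```python
from dataclasses import asdict, dataclass
from typing import Any


@dataclass
class TrainingConfigStub:
    """Hypothetical stand-in for CAREamics' TrainingConfig."""

    num_epochs: int = 100
    batch_size: int = 8


class ConfigSaverSketch:
    """Minimal sketch of a callback that stores metadata in a checkpoint dict."""

    def __init__(self, careamics_version: str, experiment_name: str,
                 training_config: TrainingConfigStub) -> None:
        self.careamics_version = careamics_version
        self.experiment_name = experiment_name
        self.training_config = training_config

    def on_save_checkpoint(self, trainer: Any, pl_module: Any,
                           checkpoint: dict[str, Any]) -> None:
        # The checkpoint dict is mutated in place, so the added entries end up
        # in the saved file. The key names here are hypothetical.
        checkpoint["careamics_info"] = {
            "version": self.careamics_version,
            "experiment_name": self.experiment_name,
            "training_config": asdict(self.training_config),
        }


ckpt: dict[str, Any] = {"state_dict": {}}
saver = ConfigSaverSketch("0.1.0", "n2v_run", TrainingConfigStub())
saver.on_save_checkpoint(None, None, ckpt)
print(ckpt["careamics_info"]["experiment_name"])  # n2v_run
```

Storing a plain dictionary (rather than the config object itself) keeps the checkpoint loadable with `torch.load` even where the config class is unavailable; whether CAREamics does exactly this is not stated here.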
DataStatsCallback
Bases: Callback
Callback that updates the model's data statistics from the datamodule.
This callback ensures that the model has access to the data statistics (mean and std) calculated by the datamodule before training starts.
setup(trainer, module, stage)
Called when the trainer is setting up.
Parameters:
- trainer (Trainer) – PyTorch Lightning trainer.
- module (LightningModule) – Lightning module.
- stage (str) – Current stage (fit, validate, test, or predict).
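The hand-off can be sketched as follows, assuming the datamodule exposes a `get_data_stats()` method returning `(mean, std)` and the model exposes `set_data_stats(mean, std)`; both method names and the `ModuleStub` class are hypothetical, and `SimpleNamespace` stands in for the Lightning trainer and datamodule. The sketch copies the statistics whenever a datamodule is attached; the real callback may restrict this to particular stages.

```python
from types import SimpleNamespace
from typing import Any, Optional


class DataStatsSketch:
    """Sketch: copy normalization statistics from the datamodule to the model."""

    def setup(self, trainer: Any, module: Any, stage: str) -> None:
        # Lightning exposes the attached datamodule as `trainer.datamodule`.
        datamodule = getattr(trainer, "datamodule", None)
        if datamodule is None:
            return  # nothing to copy, e.g. when plain dataloaders are used
        # Hypothetical method names on both sides of the hand-off.
        mean, std = datamodule.get_data_stats()
        module.set_data_stats(mean, std)


class ModuleStub:
    """Hypothetical model stand-in that just stores the statistics."""

    def __init__(self) -> None:
        self.mean: Optional[list[float]] = None
        self.std: Optional[list[float]] = None

    def set_data_stats(self, mean: list[float], std: list[float]) -> None:
        self.mean, self.std = mean, std


trainer = SimpleNamespace(
    datamodule=SimpleNamespace(get_data_stats=lambda: ([0.5], [0.25]))
)
model = ModuleStub()
DataStatsSketch().setup(trainer, model, stage="fit")
print(model.mean, model.std)  # [0.5] [0.25]
```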
PredictionStoppedException
Bases: Exception
Exception raised when prediction is stopped by an external signal.
PredictionWriterCallback
Bases: BasePredictionWriter
PyTorch Lightning callback to save predictions.
A WriteStrategy must be provided at instantiation or later via set_writing_strategy. This allows the callback to be passed to the Lightning Trainer before the writing strategy (e.g. tiling or file type) is known.
By default, the prediction writer is enabled, but it can be disabled by passing enable_writing=False at instantiation or by calling enable_writing(False). Writing can then be re-enabled by calling enable_writing(True).
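The lifecycle described above (strategy optional at construction, supplied later, writing toggled on and off) can be sketched in plain Python. This is a simplified sketch, not the CAREamics implementation: `WriteStrategyLike`, its `write_batch` method, and `RecordingStrategy` are hypothetical, `set_writing_strategy` takes a ready-made strategy instead of `write_type`/`tiled`/... arguments, and `write_on_batch_end` takes far fewer arguments than the real hook.

```python
from typing import Any, Optional, Protocol


class WriteStrategyLike(Protocol):
    """Hypothetical minimal interface for a write strategy."""

    def write_batch(self, prediction: Any, batch_idx: int) -> None: ...


class PredictionWriterSketch:
    def __init__(self, write_strategy: Optional[WriteStrategyLike] = None,
                 enable_writing: bool = True) -> None:
        self.write_strategy = write_strategy  # may legitimately be None for now
        self._is_enabled = enable_writing

    @property
    def is_enabled(self) -> bool:
        return self._is_enabled

    def enable_writing(self, enable_writing: bool) -> None:
        self._is_enabled = enable_writing

    def set_writing_strategy(self, write_strategy: WriteStrategyLike) -> None:
        # Simplified: the documented method builds the strategy itself.
        self.write_strategy = write_strategy

    def write_on_batch_end(self, prediction: Any, batch_idx: int) -> None:
        if not self._is_enabled:
            return  # writing switched off: nothing is written
        if self.write_strategy is None:
            raise ValueError("A write strategy must be set before writing.")
        # Delegate: the strategy decides how (and whether tiled) to write.
        self.write_strategy.write_batch(prediction, batch_idx)


class RecordingStrategy:
    """Hypothetical strategy that records what it would have written."""

    def __init__(self) -> None:
        self.written: list[tuple[Any, int]] = []

    def write_batch(self, prediction: Any, batch_idx: int) -> None:
        self.written.append((prediction, batch_idx))


writer = PredictionWriterSketch()        # no strategy yet: allowed
strategy = RecordingStrategy()
writer.set_writing_strategy(strategy)    # strategy supplied later
writer.write_on_batch_end("pred0", 0)
writer.enable_writing(False)
writer.write_on_batch_end("pred1", 1)    # skipped while disabled
writer.enable_writing(True)
writer.write_on_batch_end("pred2", 2)
print(strategy.written)  # [('pred0', 0), ('pred2', 2)]
```

Delegating to a strategy object keeps the callback itself independent of file format and tiling, which is why it can be registered with the Trainer before those choices are made.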
Parameters:
- dirpath (Path or str, default: "predictions") – The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.
- write_strategy (WriteStrategy or None, default: None) – A strategy for writing predictions.
- enable_writing (bool, default: True) – Whether writing predictions should be enabled by default.
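The relative-path rule for dirpath can be expressed with pathlib. This is a minimal sketch; `resolve_dirpath` is a hypothetical helper for illustration, not part of CAREamics.

```python
from pathlib import Path
from typing import Union


def resolve_dirpath(dirpath: Union[Path, str]) -> Path:
    """Return an absolute directory, treating relative paths as CWD-relative."""
    path = Path(dirpath)
    if not path.is_absolute():
        # Anchor relative paths at the current working directory.
        path = Path.cwd() / path
    return path


print(resolve_dirpath("predictions"))       # <cwd>/predictions
print(resolve_dirpath(Path.cwd() / "out"))  # absolute paths are returned unchanged
```

Under this sketch, an empty string resolves to the working directory itself, because pathlib drops "." components.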
Attributes:
- is_enabled (bool) – Whether writing predictions is enabled.
- dirpath (pathlib.Path, default: "") – The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.
- write_strategy (WriteStrategy or None) – A strategy for writing predictions.
__init__(dirpath='', write_strategy=None, enable_writing=True)
Constructor.
A WriteStrategy must be provided at instantiation or later via set_writing_strategy.
Parameters:
- dirpath (Path or str, default: "predictions") – The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.
- write_strategy (WriteStrategy or None, default: None) – A strategy for writing predictions.
- enable_writing (bool, default: True) – Whether writing predictions should be enabled by default.
enable_writing(enable_writing)
Enable or disable writing.
Parameters:
- enable_writing (bool) – Whether writing predictions should be enabled.
set_writing_strategy(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None)
Set the writing strategy.
Must be called before writing predictions.
Parameters:
- write_type (SupportedWriteType) – The type of writing to perform.
- tiled (bool) – Whether to write in tiled format.
- write_func (WriteFunc or None, default: None) – A custom writing function.
- write_extension (str or None, default: None) – The file extension to use when writing files.
- write_func_kwargs (dict of str to Any, default: None) – Additional keyword arguments to pass to write_func.
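How these arguments might be validated can be sketched as below. The sketch assumes the supported write types are "tiff" and "custom", and that a custom type must come with both a write_func and a write_extension, since no built-in writer exists for it. `StrategySpec`, `build_strategy_spec`, and the ".tiff" default extension are hypothetical; the real method builds a WriteStrategy object rather than a record like this.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, Optional

# Assumption: these mirror the values of SupportedWriteType.
WriteType = Literal["tiff", "custom"]


@dataclass
class StrategySpec:
    """Hypothetical record of the options a write strategy is built from."""

    write_type: str
    tiled: bool
    write_func: Optional[Callable[..., None]]
    write_extension: Optional[str]
    write_func_kwargs: dict[str, Any] = field(default_factory=dict)


def build_strategy_spec(
    write_type: WriteType,
    tiled: bool,
    write_func: Optional[Callable[..., None]] = None,
    write_extension: Optional[str] = None,
    write_func_kwargs: Optional[dict[str, Any]] = None,
) -> StrategySpec:
    if write_type == "custom":
        # No built-in writer exists for a custom type, so both must be given.
        if write_func is None or write_extension is None:
            raise ValueError(
                "write_type 'custom' requires write_func and write_extension."
            )
    elif write_type == "tiff":
        # A built-in writer is assumed; only fill in a default extension.
        write_extension = write_extension or ".tiff"
    else:
        raise ValueError(f"Unsupported write_type: {write_type!r}")
    return StrategySpec(write_type, tiled, write_func, write_extension,
                        write_func_kwargs or {})


spec = build_strategy_spec("tiff", tiled=True)
print(spec.write_extension)  # .tiff
try:
    build_strategy_spec("custom", tiled=False)
except ValueError as err:
    print(err)
```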
setup(trainer, pl_module, stage)
Create the prediction output directory when predict begins.
Called when fit, validate, test, predict, or tune begins.
Parameters:
- trainer (Trainer) – PyTorch Lightning trainer.
- pl_module (LightningModule) – PyTorch Lightning module.
- stage (str) – Current stage, e.g. 'predict', 'fit', or 'validate'.
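A minimal sketch of the stage check, assuming the directory is created with `Path.mkdir`; the `OutputDirSketch` class is hypothetical, and `None` is passed for the trainer and module because the sketch does not use them.

```python
import tempfile
from pathlib import Path
from typing import Any


class OutputDirSketch:
    """Sketch: create the prediction directory only when predict begins."""

    def __init__(self, dirpath: Path) -> None:
        self.dirpath = dirpath

    def setup(self, trainer: Any, pl_module: Any, stage: str) -> None:
        if stage == "predict":
            # parents=True creates missing intermediate folders and
            # exist_ok=True makes repeated predict runs safe.
            self.dirpath.mkdir(parents=True, exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "predictions" / "run1"
    callback = OutputDirSketch(out)
    callback.setup(None, None, stage="fit")
    exists_after_fit = out.exists()
    callback.setup(None, None, stage="predict")
    callback.setup(None, None, stage="predict")  # second call does not raise
    exists_after_predict = out.is_dir()

print(exists_after_fit, exists_after_predict)  # False True
```

Deferring creation to the predict stage avoids leaving empty output folders behind after runs that only train or validate.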
write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)
Write predictions at the end of a batch.
The writing method is determined by the write_strategy attribute.
Parameters:
- trainer (Trainer) – PyTorch Lightning trainer.
- pl_module (LightningModule) – PyTorch Lightning module.
- prediction (ImageRegionData) – Prediction outputs of the batch.
- batch_indices (sequence of Any) – Batch indices.
- batch (ImageRegionData) – Input batch.
- batch_idx (int) – Batch index.
- dataloader_idx (int) – Dataloader index.
ProgressBarCallback
Bases: TQDMProgressBar
Progress bar for training and validation steps.
get_metrics(trainer, pl_module)
Override this to customize the metrics displayed in the progress bar.
Parameters:
- trainer (Trainer) – The trainer object.
- pl_module (LightningModule) – The LightningModule object (unused).
Returns:
- dict – A dictionary with the metrics to display in the progress bar.
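Lightning's default `TQDMProgressBar.get_metrics` returns the metrics logged with `prog_bar=True` together with a "v_num" entry holding the logger version, and a common override simply drops that entry. The sketch below shows that pattern with a simplified, hypothetical stand-in base class (`BaseBarStub`) so it runs without Lightning installed; it illustrates the override point and is not necessarily what ProgressBarCallback does.

```python
from types import SimpleNamespace
from typing import Any, Union

MetricValue = Union[int, float, str]


class BaseBarStub:
    """Simplified stand-in for Lightning's TQDMProgressBar.get_metrics."""

    def get_metrics(self, trainer: Any, pl_module: Any) -> dict[str, MetricValue]:
        # The default combines the logger version ("v_num") with the
        # metrics logged using prog_bar=True.
        return {"v_num": trainer.logger_version, **trainer.progress_bar_metrics}


class ProgressBarSketch(BaseBarStub):
    def get_metrics(self, trainer: Any, pl_module: Any) -> dict[str, MetricValue]:
        items = super().get_metrics(trainer, pl_module)
        items.pop("v_num", None)  # hide the logger version from the bar
        return items


trainer = SimpleNamespace(logger_version=3, progress_bar_metrics={"train_loss": 0.12})
shown = ProgressBarSketch().get_metrics(trainer, None)
print(shown)  # {'train_loss': 0.12}
```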
init_test_tqdm()
Override this to customize the tqdm bar for testing.
Returns:
- tqdm – A tqdm bar.
init_train_tqdm()
Override this to customize the tqdm bar for training.
Returns:
- tqdm – A tqdm bar.
init_validation_tqdm()
Override this to customize the tqdm bar for validation.
Returns:
- tqdm – A tqdm bar.
StopPredictionCallback
Bases: Callback
PyTorch Lightning callback to stop prediction based on an external condition.
This callback monitors a user-provided stop condition at the start of each prediction batch. When the condition is met, the callback stops the trainer and raises PredictionStoppedException to interrupt the prediction loop.
Parameters:
- stop_condition (Callable[[], bool]) – A callable that returns True when prediction should stop. The callable is invoked at the start of each prediction batch.
__init__(stop_condition)
Initialize the callback.
on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Check stop condition at the start of each prediction batch.
Parameters:
- trainer (Trainer) – PyTorch Lightning trainer instance.
- pl_module (LightningModule) – Lightning module being used for prediction.
- batch (Any) – Current batch of data.
- batch_idx (int) – Index of the current batch.
- dataloader_idx (int, default: 0) – Index of the dataloader.
Raises:
- PredictionStoppedException – If stop_condition() returns True.
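The stop mechanism can be sketched with a `threading.Event` as the external signal (for example, a "Cancel" button in a GUI running prediction in a worker thread). Everything here is a hypothetical stand-in so the sketch runs without Lightning: the exception is redefined locally, `run_predict_loop` imitates the predict loop, and setting `trainer.should_stop` is an assumption about how "stops the trainer" is implemented.

```python
import threading
from types import SimpleNamespace
from typing import Any, Callable


class PredictionStoppedException(Exception):
    """Local copy of the documented exception so the sketch runs standalone."""


class StopPredictionSketch:
    def __init__(self, stop_condition: Callable[[], bool]) -> None:
        self.stop_condition = stop_condition

    def on_predict_batch_start(self, trainer: Any, pl_module: Any, batch: Any,
                               batch_idx: int, dataloader_idx: int = 0) -> None:
        if self.stop_condition():
            trainer.should_stop = True  # assumption: flag the trainer first
            raise PredictionStoppedException(f"Prediction stopped at batch {batch_idx}.")


def run_predict_loop(batches: list[int], callback: StopPredictionSketch,
                     stop_event: threading.Event) -> list[int]:
    """Hypothetical stand-in for the Lightning predict loop."""
    trainer = SimpleNamespace(should_stop=False)
    outputs: list[int] = []
    try:
        for idx, batch in enumerate(batches):
            callback.on_predict_batch_start(trainer, None, batch, idx)
            outputs.append(batch * 2)  # placeholder "prediction"
            if idx == 1:
                stop_event.set()  # simulate the external signal arriving
    except PredictionStoppedException as err:
        print(err)
    return outputs


stop_event = threading.Event()
callback = StopPredictionSketch(stop_condition=stop_event.is_set)
result = run_predict_loop([1, 2, 3, 4], callback, stop_event)
print(result)  # [2, 4]
```

Passing the bound method `stop_event.is_set` matches the documented `Callable[[], bool]` signature, and `threading.Event` is safe to set from another thread. The caller catches PredictionStoppedException and keeps the predictions made before the stop.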