Callbacks
Callbacks module.
CareamicsCheckpointInfo
Bases: Callback
Callback to save CAREamics configuration in Lightning checkpoints.
This callback automatically stores CAREamics version, experiment name, and training configuration in the checkpoint file for reproducibility.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| careamics_version | str | Version of CAREamics used for training. | required |
| experiment_name | str | Name of the experiment. | required |
| training_config | TrainingConfig | Training configuration to store in checkpoint. | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| careamics_version | str | Version of CAREamics used for training. |
| experiment_name | str | Name of the experiment. |
| training_config | TrainingConfig | Training configuration to store in checkpoint. |
on_save_checkpoint(trainer, pl_module, checkpoint)
Lightning hook called when saving a checkpoint.
Adds CAREamics configuration to the checkpoint dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | Lightning trainer instance. | required |
| pl_module | LightningModule | Lightning module being trained. | required |
| checkpoint | dict[str, Any] | Checkpoint dictionary to modify. | required |
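As a rough illustration of what such a hook does, the sketch below adds metadata to the checkpoint dictionary inside `on_save_checkpoint`. It is not the CAREamics implementation: the `Callback` stand-in class, the `CheckpointInfoSketch` name, the `"careamics_info"` key, and the use of a plain dictionary as training configuration are all assumptions, chosen so that the example runs without Lightning installed.

```python
from typing import Any


class Callback:
    """Stand-in for Lightning's Callback base class (assumption)."""


class CheckpointInfoSketch(Callback):
    """Hypothetical callback that stores run metadata in a checkpoint."""

    def __init__(
        self,
        careamics_version: str,
        experiment_name: str,
        training_config: dict[str, Any],
    ) -> None:
        self.careamics_version = careamics_version
        self.experiment_name = experiment_name
        self.training_config = training_config

    def on_save_checkpoint(
        self, trainer: Any, pl_module: Any, checkpoint: dict[str, Any]
    ) -> None:
        # The key name "careamics_info" is made up for this sketch.
        checkpoint["careamics_info"] = {
            "version": self.careamics_version,
            "experiment_name": self.experiment_name,
            "training_config": dict(self.training_config),
        }


# Simulate Lightning calling the hook with a checkpoint dictionary.
checkpoint: dict[str, Any] = {"state_dict": {}}
callback = CheckpointInfoSketch("0.0.0", "demo", {"num_epochs": 10})
callback.on_save_checkpoint(trainer=None, pl_module=None, checkpoint=checkpoint)
```

The hook mutates the dictionary in place and returns nothing, so existing entries such as `state_dict` are left untouched.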
DataStatsCallback
Bases: Callback
Callback to update model's data statistics from datamodule.
This callback ensures that the model has access to the data statistics (mean, std) calculated by the datamodule before training starts.
setup(trainer, module, stage)
Called when trainer is setting up.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer. | required |
| module | LightningModule | Lightning module. | required |
| stage | str | Current stage (fit, validate, test, or predict). | required |
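The idea of copying statistics from the datamodule to the model during `setup` can be sketched as follows. This is only an illustration under stated assumptions: the `mean` and `std` attributes on the datamodule and the `data_mean` and `data_std` attributes on the module are hypothetical names, and `SimpleNamespace` objects stand in for the real trainer and module.

```python
from types import SimpleNamespace
from typing import Any


class DataStatsSketch:
    """Hypothetical callback copying data statistics onto the model."""

    def setup(self, trainer: Any, module: Any, stage: str) -> None:
        # Assumption: the datamodule exposes `mean` and `std` attributes.
        datamodule = trainer.datamodule
        # Assumption: the model reads them from `data_mean` and `data_std`.
        module.data_mean = datamodule.mean
        module.data_std = datamodule.std


# Stand-ins for a trainer with an attached datamodule and for a model.
trainer = SimpleNamespace(datamodule=SimpleNamespace(mean=[0.5], std=[0.2]))
module = SimpleNamespace()
DataStatsSketch().setup(trainer, module, stage="fit")
```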
HyperParametersCallback
Bases: Callback
Callback that saves the CAREamics configuration as hyperparameters in the model.
The configuration is stored as a dictionary in the checkpoints, so that it can later be loaded into a CAREamist instance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | Configuration | CAREamics configuration to be saved as hyperparameter in the model. | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| config | Configuration | CAREamics configuration to be saved as hyperparameter in the model. |
on_train_start(trainer, pl_module)
Update the hyperparameters of the model with the configuration on train start.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer, unused. | required |
| pl_module | LightningModule | PyTorch Lightning module. | required |
PredictionStoppedException
Bases: Exception
Exception raised when prediction is stopped by external signal.
PredictionWriterCallback
Bases: BasePredictionWriter
A PyTorch Lightning callback to save predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| write_strategy | WriteStrategy | A strategy for writing predictions. | required |
| dirpath | Path or str | The path to the directory where prediction outputs will be saved. If | "predictions" |
Attributes:
| Name | Type | Description |
|---|---|---|
| write_strategy | WriteStrategy | A strategy for writing predictions. |
| dirpath | pathlib.Path, default="predictions" | The path to the directory where prediction outputs will be saved. If |
| writing_predictions | bool | If writing predictions is turned on or off. |
from_write_func_params(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None, dirpath='predictions')
classmethod
Initialize a PredictionWriterCallback from write function parameters.
This will automatically create a WriteStrategy to be passed to the initialization of PredictionWriterCallback.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| write_type | ('tiff', 'custom') | The data type to save as, includes custom. | "tiff" |
| tiled | bool | Whether the prediction will be tiled or not. | required |
| write_func | WriteFunc | If a known | None |
| write_extension | str | If a known | None |
| write_func_kwargs | dict of {str: any} | Additional keyword arguments to be passed to the save function. | None |
| dirpath | Path or str | The path to the directory where prediction outputs will be saved. If | "predictions" |
Returns:
| Type | Description |
|---|---|
| PredictionWriterCallback | Callback for writing predictions. |
setup(trainer, pl_module, stage)
Create the prediction output directory when predict begins.
Called when fit, validate, test, predict, or tune begins.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer. | required |
| pl_module | LightningModule | PyTorch Lightning module. | required |
| stage | str | Stage of training e.g. 'predict', 'fit', 'validate'. | required |
write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)
Write predictions at the end of a batch.
How predictions are written is determined by the attribute write_strategy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer. | required |
| pl_module | LightningModule | PyTorch Lightning module. | required |
| prediction | Any | Prediction outputs of | required |
| batch_indices | sequence of Any | Batch indices. | required |
| batch | Any | Input batch. | required |
| batch_idx | int | Batch index. | required |
| dataloader_idx | int | Dataloader index. | required |
ProgressBarCallback
Bases: TQDMProgressBar
Progress bar for training and validation steps.
get_metrics(trainer, pl_module)
Override this to customize the metrics displayed in the progress bar.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | The trainer object. | required |
| pl_module | LightningModule | The LightningModule object, unused. | required |
Returns:
| Type | Description |
|---|---|
| dict | A dictionary with the metrics to display in the progress bar. |
init_test_tqdm()
Override this to customize the tqdm bar for testing.
Returns:
| Type | Description |
|---|---|
| tqdm | A tqdm bar. |
init_train_tqdm()
Override this to customize the tqdm bar for training.
Returns:
| Type | Description |
|---|---|
| tqdm | A tqdm bar. |
init_validation_tqdm()
Override this to customize the tqdm bar for validation.
Returns:
| Type | Description |
|---|---|
| tqdm | A tqdm bar. |
StopPredictionCallback
Bases: Callback
PyTorch Lightning callback to stop prediction based on external condition.
This callback monitors a user-provided stop condition at the start of each prediction batch. When the condition is met, the callback stops the trainer and raises PredictionStoppedException to interrupt the prediction loop.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stop_condition | Callable[[], bool] | A callable that returns True when prediction should stop. The callable is invoked at the start of each prediction batch. | required |
on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Check stop condition at the start of each prediction batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer instance. | required |
| pl_module | LightningModule | Lightning module being used for prediction. | required |
| batch | Any | Current batch of data. | required |
| batch_idx | int | Index of the current batch. | required |
| dataloader_idx | int | Index of the dataloader, by default 0. | 0 |
Raises:
| Type | Description |
|---|---|
| PredictionStoppedException | If stop_condition() returns True. |
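One convenient way to build a stop condition of type `Callable[[], bool]` is the `is_set` method of a `threading.Event`, which another thread (for example a GUI cancel button) can set at any time. The sketch below simulates the documented behaviour with a plain loop rather than a real trainer: `PredictionStopped`, `run_batches`, and `on_batch_done` are hypothetical stand-ins and not part of CAREamics.

```python
import threading
from typing import Callable


class PredictionStopped(Exception):
    """Stand-in for PredictionStoppedException (assumption)."""


def run_batches(
    n_batches: int,
    stop_condition: Callable[[], bool],
    on_batch_done: Callable[[int], None],
) -> None:
    """Simulated prediction loop that checks the condition before each batch."""
    for batch_idx in range(n_batches):
        # Mirror the documented behaviour: check at the start of each batch.
        if stop_condition():
            raise PredictionStopped(f"stopped before batch {batch_idx}")
        on_batch_done(batch_idx)


stop_event = threading.Event()
done: list[int] = []


def on_batch_done(batch_idx: int) -> None:
    """Record the batch and request a stop once three batches are done."""
    done.append(batch_idx)
    if len(done) == 3:
        stop_event.set()


try:
    run_batches(10, stop_condition=stop_event.is_set, on_batch_done=on_batch_done)
    stopped = False
except PredictionStopped:
    stopped = True
```

Because the condition is only polled at batch boundaries, a batch that has already started always runs to completion before the loop is interrupted. With the real callback, the same `stop_event.is_set` would presumably be passed as the `stop_condition` argument.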
create_write_strategy(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None)
Create a write strategy from convenient parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| write_type | ('tiff', 'custom') | The data type to save as, includes custom. | "tiff" |
| tiled | bool | Whether the prediction will be tiled or not. | required |
| write_func | WriteFunc | If a known | None |
| write_extension | str | If a known | None |
| write_func_kwargs | dict of {str: any} | Additional keyword arguments to be passed to the save function. | None |
Returns:
| Type | Description |
|---|---|
| WriteStrategy | A strategy for writing predictions. |