stop_prediction_callback
Callback for stopping prediction based on external condition.
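PredictionStoppedException
Exception raised by StopPredictionCallback to interrupt the prediction loop when the stop condition is met.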
StopPredictionCallback
Bases: Callback
PyTorch Lightning callback that stops prediction based on an external condition.
This callback evaluates a user-provided stop condition at the start of each prediction batch. When the condition returns True, the callback stops the trainer and raises PredictionStoppedException to interrupt the prediction loop.
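The check-and-raise mechanism described above can be illustrated without Lightning. The following is a minimal, framework-free sketch, not the CAREamics implementation: `FakeTrainer`, `SketchStopCallback`, and `run_prediction` are hypothetical stand-ins, and "stops the trainer" is assumed to mean setting a `should_stop` flag (the name mirrors Lightning's `Trainer.should_stop` attribute). Only the hook signature is taken from this page.

```python
from typing import Any, Callable, Iterable, List


class PredictionStoppedException(Exception):
    """Stand-in for the exception raised when prediction is stopped."""


class FakeTrainer:
    """Hypothetical stand-in for a Lightning Trainer; it only tracks a stop flag."""

    def __init__(self) -> None:
        self.should_stop = False


class SketchStopCallback:
    """Hypothetical sketch of the check-and-raise pattern (not the real callback)."""

    def __init__(self, stop_condition: Callable[[], bool]) -> None:
        self.stop_condition = stop_condition

    def on_predict_batch_start(
        self,
        trainer: Any,
        pl_module: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # Poll the external condition once per batch, before the batch is processed.
        if self.stop_condition():
            # Assumption: "stops the trainer" is modelled as setting a should_stop flag.
            trainer.should_stop = True
            raise PredictionStoppedException(
                f"Prediction stopped before batch {batch_idx}."
            )


def run_prediction(
    batches: Iterable[Any],
    callback: SketchStopCallback,
    trainer: FakeTrainer,
    predict_fn: Callable[[Any], Any],
) -> List[Any]:
    """Hypothetical prediction loop that calls the hook before every batch."""
    outputs = []
    for batch_idx, batch in enumerate(batches):
        callback.on_predict_batch_start(trainer, None, batch, batch_idx)
        outputs.append(predict_fn(batch))
    return outputs


# Usage: stop once three batches have been processed.
processed: List[int] = []
trainer = FakeTrainer()
callback = SketchStopCallback(stop_condition=lambda: len(processed) >= 3)
try:
    run_prediction(range(10), callback, trainer, lambda b: processed.append(b * 2))
except PredictionStoppedException as exc:
    print(exc)  # Prediction stopped before batch 3.
print(processed, trainer.should_stop)  # [0, 2, 4] True
```

Because the condition is checked before a batch runs, the batch that triggers the stop is never processed.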
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stop_condition | Callable[[], bool] | A callable that returns True when prediction should stop. The callable is invoked at the start of each prediction batch. | required |
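Any zero-argument callable returning a bool fits the `Callable[[], bool]` type in the table above. The helpers below are a hypothetical sketch, not part of CAREamics; they show three common ways to build such a callable: a `threading.Event` set from another thread, a wall-clock deadline, and a batch budget.

```python
import threading
import time
from typing import Callable


def event_condition(event: threading.Event) -> Callable[[], bool]:
    # Event.is_set already has the Callable[[], bool] shape; another thread
    # (for example a GUI "Stop" handler) calls event.set() to request a stop.
    return event.is_set


def deadline_condition(seconds: float) -> Callable[[], bool]:
    # Stop once a time budget, measured from when this condition is created, elapses.
    deadline = time.monotonic() + seconds
    return lambda: time.monotonic() >= deadline


def max_batches_condition(max_batches: int) -> Callable[[], bool]:
    # The condition is polled once per batch start, so the (max_batches + 1)-th
    # call returns True and exactly max_batches batches are allowed to run.
    calls = 0

    def condition() -> bool:
        nonlocal calls
        calls += 1
        return calls > max_batches

    return condition


# Usage: the event-based condition flips once another thread sets the event.
stop_event = threading.Event()
cond = event_condition(stop_event)
print(cond())  # False
stop_event.set()
print(cond())  # True
```

Any of these could then be passed as `StopPredictionCallback(stop_condition=...)`; the helper names themselves are illustrative only.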
Source code in src/careamics/lightning/callbacks/stop_prediction_callback.py
__init__(stop_condition)
Initialize the callback with a stop condition.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stop_condition | Callable[[], bool] | Function that returns True when prediction should stop. | required |
on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Check stop condition at the start of each prediction batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer instance. | required |
| pl_module | LightningModule | Lightning module being used for prediction. | required |
| batch | Any | Current batch of data. | required |
| batch_idx | int | Index of the current batch. | required |
| dataloader_idx | int | Index of the dataloader. | 0 |
Raises:
| Type | Description |
|---|---|
| PredictionStoppedException | If stop_condition() returns True. |
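Because the stop is signalled by an exception, the caller decides whether it counts as an error. The sketch below assumes PredictionStoppedException propagates out of the prediction call (with Lightning, this would mean wrapping the predict call in `try`/`except`); it uses a hypothetical pure-Python loop, `predict_with_stop`, in place of the trainer. A raised exception discards the loop's return value, so partial results are recorded per batch as they are produced.

```python
from typing import Any, Callable, Iterable, List, Tuple


class PredictionStoppedException(Exception):
    """Stand-in for the exception raised by the stop callback."""


def predict_with_stop(
    batches: Iterable[Any],
    predict_fn: Callable[[Any], Any],
    stop_condition: Callable[[], bool],
) -> List[Any]:
    """Hypothetical loop that checks the condition at each batch start, like the hook."""
    outputs = []
    for batch_idx, batch in enumerate(batches):
        if stop_condition():
            raise PredictionStoppedException(f"Stopped before batch {batch_idx}.")
        outputs.append(predict_fn(batch))
    return outputs  # not reached when the exception is raised


def predict_until_stopped(
    batches: Iterable[Any],
    predict_fn: Callable[[Any], Any],
    stop_condition: Callable[[], bool],
) -> Tuple[List[Any], bool]:
    """Treat a requested stop as a normal outcome; return (results, was_stopped)."""
    partial: List[Any] = []

    def recording_predict(batch: Any) -> Any:
        # Keep each result as soon as it exists, so it survives an interruption.
        result = predict_fn(batch)
        partial.append(result)
        return result

    try:
        return predict_with_stop(batches, recording_predict, stop_condition), False
    except PredictionStoppedException:
        return partial, True


# Usage: the condition fires at the third batch start, so two results survive.
answers = iter([False, False, True])
results, stopped = predict_until_stopped(range(5), lambda b: b * b, lambda: next(answers))
print(results, stopped)  # [0, 1] True
```

Catching only PredictionStoppedException keeps genuine failures (for example, errors raised by the model) propagating as usual.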