
Stop Prediction Callback


Callback for stopping prediction based on an external condition.

PredictionStoppedException

Bases: Exception

Exception raised when prediction is stopped by an external signal.

StopPredictionCallback

Bases: Callback

PyTorch Lightning callback that stops prediction based on an external condition.

This callback evaluates a user-provided stop condition at the start of each prediction batch. When the condition returns True, the callback stops the trainer and raises PredictionStoppedException to interrupt the prediction loop.
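The behavior described above can be sketched in plain Python. This is a hypothetical stand-in, not the library's source. `StopPredictionCallbackSketch` is an illustrative name, and "stops the trainer" is assumed here to mean setting `trainer.should_stop = True`, a flag the Lightning `Trainer` exposes. The sketch deliberately does not subclass Lightning's `Callback`, so it runs without Lightning installed. A `SimpleNamespace` stands in for the trainer.

```python
from types import SimpleNamespace
from typing import Any, Callable


class PredictionStoppedException(Exception):
    """Raised when prediction is stopped by an external signal."""


class StopPredictionCallbackSketch:
    """Hypothetical stand-in for StopPredictionCallback (no Lightning import)."""

    def __init__(self, stop_condition: Callable[[], bool]) -> None:
        self.stop_condition = stop_condition

    def on_predict_batch_start(
        self,
        trainer: Any,
        pl_module: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # Poll the user-supplied condition once, before the batch is processed.
        if self.stop_condition():
            # Assumption: "stopping the trainer" is modeled as setting should_stop.
            trainer.should_stop = True
            # Raising is what actually interrupts the prediction loop.
            raise PredictionStoppedException(
                f"stopped before batch {batch_idx} (dataloader {dataloader_idx})"
            )


# Usage with a fake trainer object standing in for a Lightning Trainer.
trainer = SimpleNamespace(should_stop=False)
callback = StopPredictionCallbackSketch(stop_condition=lambda: True)
try:
    callback.on_predict_batch_start(trainer, None, batch=[1, 2], batch_idx=3)
except PredictionStoppedException as exc:
    print(exc)                  # stopped before batch 3 (dataloader 0)
    print(trainer.should_stop)  # True
```

Raising an exception, rather than only setting a flag, is what lets the callback interrupt prediction immediately instead of waiting for the loop to notice.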

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| stop_condition | Callable[[], bool] | A callable that returns True when prediction should stop. It is invoked at the start of each prediction batch. | required |
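Any zero-argument callable that returns a bool satisfies `Callable[[], bool]`. The sketch below shows two common ways to build one, assuming the stop signal comes from another thread or from a time budget. The first is the bound method `threading.Event.is_set`. The second is a closure over `time.monotonic()`. `make_deadline_condition` is a hypothetical helper, not part of the library. The callable runs at the start of every batch, so it should be cheap and must not block.

```python
import threading
import time
from typing import Callable


def make_deadline_condition(seconds: float) -> Callable[[], bool]:
    """Hypothetical helper: returns a predicate that becomes True after `seconds`."""
    deadline = time.monotonic() + seconds

    def stop_condition() -> bool:
        # Cheap, non-blocking check suitable for calling once per batch.
        return time.monotonic() >= deadline

    return stop_condition


# Option 1: an Event set from another thread, e.g. a "Cancel" button handler
# or threading.Timer(5.0, cancel_event.set).start().
cancel_event = threading.Event()
event_condition: Callable[[], bool] = cancel_event.is_set

# Option 2: a wall-clock budget of 30 seconds.
deadline_condition = make_deadline_condition(30.0)

# Either callable could then be passed as stop_condition=... to the callback.
```

`threading.Event` is thread-safe, which makes it a natural fit when the code that decides to stop runs on a different thread from the prediction loop.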

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Check stop condition at the start of each prediction batch.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| trainer | Trainer | PyTorch Lightning trainer instance. | required |
| pl_module | LightningModule | Lightning module being used for prediction. | required |
| batch | Any | Current batch of data. | required |
| batch_idx | int | Index of the current batch. | required |
| dataloader_idx | int | Index of the current dataloader. | 0 |

Raises:

| Type | Description |
|------|-------------|
| PredictionStoppedException | If stop_condition() returns True. |
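Because the hook raises instead of returning, calling code would typically treat PredictionStoppedException as a normal cancellation and keep whatever was produced before the stop. The minimal sketch below uses a plain loop in place of `trainer.predict`. `predict_with_cancel` is a hypothetical wrapper, and the exception class is a local stand-in. The stop check mirrors `on_predict_batch_start`, and results are appended as each batch finishes.

```python
from typing import Any, Callable, Iterable, List, Tuple


class PredictionStoppedException(Exception):
    """Local stand-in for the documented exception."""


def predict_with_cancel(
    predict_step: Callable[[Any], Any],
    batches: Iterable[Any],
    stop_condition: Callable[[], bool],
) -> Tuple[List[Any], bool]:
    """Hypothetical wrapper returning (partial_outputs, was_stopped)."""
    collected: List[Any] = []
    try:
        for batch in batches:
            # Mirrors on_predict_batch_start: check before processing the batch.
            if stop_condition():
                raise PredictionStoppedException("stop requested")
            # Collect each result as soon as its batch finishes.
            collected.append(predict_step(batch))
    except PredictionStoppedException:
        # Cancellation is expected; return what was computed so far.
        return collected, True
    return collected, False


outputs, stopped = predict_with_cancel(lambda b: b * 2, range(5), lambda: False)
print(outputs, stopped)  # [0, 2, 4, 6, 8] False
```

With a real Trainer, the same pattern would wrap `trainer.predict(...)` in a try/except. This assumes the exception propagates out of the predict call, as uncaught callback exceptions generally do. In that case the call does not return a list, so partial outputs would need to be gathered as batches finish, for example in an `on_predict_batch_end` hook or a `BasePredictionWriter` callback.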