Stop Prediction Callback

Callback for stopping prediction based on an external condition.

PredictionStoppedException

Bases: Exception

Exception raised when prediction is stopped by an external signal.

StopPredictionCallback

Bases: Callback

PyTorch Lightning callback that stops prediction based on an external condition.

This callback monitors a user-provided stop condition at the start of each prediction batch. When the condition is met, the callback stops the trainer and raises PredictionStoppedException to interrupt the prediction loop.

Parameters:

  • stop_condition (Callable[[], bool]) –

    A callable that returns True when prediction should stop. The callable is invoked at the start of each prediction batch.
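Any zero-argument callable that returns a bool can serve as stop_condition. The following is a minimal sketch, assuming the stop request comes either from another thread or from a time budget. The names stop_event, make_deadline_condition, and any_of are illustrative and not part of this library.

```python
import threading
import time
from typing import Callable

# Illustrative flag that another thread (e.g. a "cancel" handler) can set.
stop_event = threading.Event()

# The bound method already has the required Callable[[], bool] shape.
event_condition: Callable[[], bool] = stop_event.is_set


def make_deadline_condition(seconds: float) -> Callable[[], bool]:
    """Return a condition that becomes True once `seconds` have elapsed."""
    deadline = time.monotonic() + seconds
    return lambda: time.monotonic() >= deadline


def any_of(*conditions: Callable[[], bool]) -> Callable[[], bool]:
    """Combine conditions: report True as soon as any one of them is True."""
    return lambda: any(condition() for condition in conditions)


# Stop on an explicit request or after one hour, whichever comes first.
stop_condition = any_of(event_condition, make_deadline_condition(3600.0))
```

The resulting stop_condition could then be passed as StopPredictionCallback(stop_condition). Because the callable is invoked once per batch, it is best kept cheap and free of blocking calls.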

__init__(stop_condition)

Initialize the callback with a stop condition.

Parameters:

  • stop_condition (Callable[[], bool]) –

    Function that returns True when prediction should stop.
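How the stored condition is consulted can be sketched without Lightning installed. The classes below are stand-ins written for illustration, not this library's implementation: SketchStopCallback does not subclass Callback, the trainer is a plain namespace, and it is an assumption that "stops the trainer" means setting a should_stop flag before raising.

```python
from types import SimpleNamespace
from typing import Any, Callable, List


class PredictionStopped(Exception):
    """Stand-in for PredictionStoppedException."""


class SketchStopCallback:
    """Stand-in for StopPredictionCallback (hypothetical, no Lightning dependency)."""

    def __init__(self, stop_condition: Callable[[], bool]) -> None:
        self.stop_condition = stop_condition

    def on_predict_batch_start(self, trainer: Any, pl_module: Any, batch: Any,
                               batch_idx: int, dataloader_idx: int = 0) -> None:
        if self.stop_condition():
            # Assumption: stopping the trainer means setting its should_stop flag.
            trainer.should_stop = True
            raise PredictionStopped(f"stopped before batch {batch_idx}")


def run_prediction(callback: SketchStopCallback, batches: List[int],
                   outputs: List[int]) -> bool:
    """Toy prediction loop: call the hook, then 'predict' by doubling the batch."""
    trainer = SimpleNamespace(should_stop=False)
    try:
        for batch_idx, batch in enumerate(batches):
            callback.on_predict_batch_start(trainer, None, batch, batch_idx)
            outputs.append(batch * 2)
    except PredictionStopped:
        pass  # outputs produced before the stop are kept
    return trainer.should_stop


outputs: List[int] = []
# Request a stop once two batches have been processed.
callback = SketchStopCallback(lambda: len(outputs) >= 2)
stopped = run_prediction(callback, [1, 2, 3, 4], outputs)
```

In this toy run the hook interrupts the loop before the third batch, so only the first two results are kept. With the real callback, the caller would presumably wrap the trainer's predict call in a try/except for PredictionStoppedException in the same way.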

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Check stop condition at the start of each prediction batch.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning trainer instance.

  • pl_module (LightningModule) –

    Lightning module being used for prediction.

  • batch (Any) –

    Current batch of data.

  • batch_idx (int) –

    Index of the current batch.

  • dataloader_idx (int, default: 0) –

    Index of the dataloader, by default 0.

Raises: