
stop_prediction_callback

Callback for stopping prediction based on an external condition.

PredictionStoppedException

Bases: Exception

Exception raised when prediction is stopped by an external signal.

Source code in src/careamics/lightning/callbacks/stop_prediction_callback.py
```python
class PredictionStoppedException(Exception):
    """Exception raised when prediction is stopped by external signal."""

    pass
```
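A minimal sketch of how calling code might tell a deliberate stop apart from a genuine failure. `run_prediction` and `predict_or_none` are hypothetical helpers standing in for whatever code drives `trainer.predict(...)`. The exception class is redefined locally so the snippet runs without CAREamics installed; real code would import it instead (presumably from `careamics.lightning.callbacks.stop_prediction_callback`, a module path assumed from the source file location above).

```python
from typing import List, Optional


class PredictionStoppedException(Exception):
    """Local stand-in for the CAREamics exception of the same name."""


def run_prediction(batches: List[int], stop_after: Optional[int]) -> List[int]:
    """Hypothetical driver: 'predicts' each batch, stopping when asked to."""
    results = []
    for idx, batch in enumerate(batches):
        if stop_after is not None and idx == stop_after:
            # Plays the role of the callback raising at batch start.
            raise PredictionStoppedException("Prediction stopped by user")
        results.append(batch * 2)  # placeholder for real inference
    return results


def predict_or_none(
    batches: List[int], stop_after: Optional[int]
) -> Optional[List[int]]:
    """Return predictions, or None if the run was cancelled on purpose."""
    try:
        return run_prediction(batches, stop_after)
    except PredictionStoppedException:
        # A deliberate stop is not an error, so report "no result".
        return None
```

Catching only `PredictionStoppedException` (rather than a bare `except Exception`) lets real failures still propagate to the caller.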

StopPredictionCallback

Bases: Callback

PyTorch Lightning callback to stop prediction based on an external condition.

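This callback evaluates a user-provided stop condition at the start of each prediction batch. When the condition returns True, the callback sets trainer.should_stop to True and raises PredictionStoppedException to interrupt the prediction loop. The callback does not handle the exception itself, so code that launches prediction with this callback attached should be prepared to catch it.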

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| stop_condition | Callable[[], bool] | A callable that returns True when prediction should stop. The callable is invoked at the start of each prediction batch. | required |
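Because stop_condition is any zero-argument callable returning a bool, one natural pattern is to back it with a `threading.Event` that another thread (for example a GUI "cancel" handler) sets while prediction runs. A minimal sketch, assuming prediction runs in a worker thread; `make_event_condition` is a hypothetical helper, not part of CAREamics.

```python
import threading
from typing import Callable, Tuple


def make_event_condition() -> Tuple[threading.Event, Callable[[], bool]]:
    """Return an event and a stop_condition that reports whether it is set."""
    event = threading.Event()
    # Event.is_set is already a zero-argument callable returning bool,
    # so it can be used directly as stop_condition.
    return event, event.is_set


stop_event, stop_condition = make_event_condition()
# Before anyone sets the event, stop_condition() is False: prediction continues.

# Another thread requests cancellation.
canceller = threading.Thread(target=stop_event.set)
canceller.start()
canceller.join()
# From now on stop_condition() is True, so the next batch start triggers the stop.
```

The resulting callable would then be passed as `StopPredictionCallback(stop_condition=stop_condition)` and included in the trainer's callbacks. `threading.Event` is safe to set from one thread and read from another, which is why it suits this role better than a plain shared boolean.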
Source code in src/careamics/lightning/callbacks/stop_prediction_callback.py
```python
class StopPredictionCallback(Callback):
    """PyTorch Lightning callback to stop prediction based on external condition.

    This callback monitors a user-provided stop condition at the start of each
    prediction batch. When the condition is met, the callback stops the trainer
    and raises PredictionStoppedException to interrupt the prediction loop.

    Parameters
    ----------
    stop_condition : Callable[[], bool]
        A callable that returns True when prediction should stop. The callable
        is invoked at the start of each prediction batch.
    """

    def __init__(self, stop_condition: Callable[[], bool]) -> None:
        """Initialize the callback with a stop condition.

        Parameters
        ----------
        stop_condition : Callable[[], bool]
            Function that returns True when prediction should stop.
        """
        super().__init__()
        self.stop_condition = stop_condition

    def on_predict_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Check stop condition at the start of each prediction batch.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer instance.
        pl_module : LightningModule
            Lightning module being used for prediction.
        batch : Any
            Current batch of data.
        batch_idx : int
            Index of the current batch.
        dataloader_idx : int, optional
            Index of the dataloader, by default 0.

        Raises
        ------
        PredictionStoppedException
            If stop_condition() returns True.
        """
        if self.stop_condition():
            trainer.should_stop = True
            raise PredictionStoppedException("Prediction stopped by user")
```

__init__(stop_condition)

Initialize the callback with a stop condition.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| stop_condition | Callable[[], bool] | Function that returns True when prediction should stop. | required |
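The constructor stores stop_condition as-is, so any callable works, including closures. A hedged sketch of two hypothetical helpers (neither is part of CAREamics): a wall-clock time limit, and a combinator that stops as soon as any of several conditions holds.

```python
import time
from typing import Callable


def deadline_condition(seconds: float) -> Callable[[], bool]:
    """Return a stop_condition that becomes True once `seconds` have elapsed."""
    # time.monotonic is unaffected by system clock changes, unlike time.time.
    deadline = time.monotonic() + seconds
    return lambda: time.monotonic() >= deadline


def any_of(*conditions: Callable[[], bool]) -> Callable[[], bool]:
    """Combine stop conditions: stop as soon as one of them returns True."""
    return lambda: any(cond() for cond in conditions)
```

Because the condition is evaluated once per prediction batch, it should be cheap; both helpers only read a clock or call other callables. A combined condition such as `any_of(deadline_condition(600.0), stop_event.is_set)` could then be passed to the constructor as a single stop_condition.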

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Check stop condition at the start of each prediction batch.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer instance. | required |
| pl_module | LightningModule | Lightning module being used for prediction. | required |
| batch | Any | Current batch of data. | required |
| batch_idx | int | Index of the current batch. | required |
| dataloader_idx | int | Index of the dataloader. | 0 |

Raises:

| Type | Description |
|---|---|
| PredictionStoppedException | If stop_condition() returns True. |
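To illustrate the order of operations, here is a framework-free sketch that replays the hook's logic against a stand-in trainer (a `types.SimpleNamespace` with a `should_stop` attribute) inside a hypothetical `simulate` loop. It is not the Lightning predict loop, only a model of it. It shows that the batch during which the condition first returns True is never predicted, because the check runs at batch start, before any work on that batch.

```python
from types import SimpleNamespace
from typing import Any, Callable, List


class PredictionStoppedException(Exception):
    """Local stand-in for the CAREamics exception."""


def on_predict_batch_start(trainer: Any, stop_condition: Callable[[], bool]) -> None:
    """Mirror of the hook body: flag the trainer, then raise."""
    if stop_condition():
        trainer.should_stop = True
        raise PredictionStoppedException("Prediction stopped by user")


def simulate(num_batches: int, stop_at: int) -> SimpleNamespace:
    """Run a fake predict loop whose condition turns True at batch `stop_at`."""
    trainer = SimpleNamespace(should_stop=False)
    processed: List[int] = []
    try:
        for batch_idx in range(num_batches):
            # The hook runs before the batch is predicted.
            on_predict_batch_start(trainer, lambda: batch_idx >= stop_at)
            processed.append(batch_idx)  # placeholder for real inference
    except PredictionStoppedException:
        pass  # the caller decides how to treat a deliberate stop
    return SimpleNamespace(trainer=trainer, processed=processed)
```

For example, `simulate(5, 2)` processes batches 0 and 1 only, and leaves `should_stop` set to True on the stand-in trainer.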
