Skip to content

callbacks

PyTorch Lightning callbacks used to stop prediction on request and to update the GUI with progress.

PredictionStoppedException #

Bases: Exception

Exception raised when prediction is stopped by user.

Source code in src/careamics_napari/careamics_utils/callbacks.py
class PredictionStoppedException(Exception):
    """Exception raised when the user stops a running prediction."""

StopPredictionCallback #

Bases: Callback

PyTorch Lightning callback to stop prediction when signaled.

This callback monitors a PredictionStatus object and stops the trainer when the state is set to STOPPED, allowing for graceful interruption of prediction processes.

Parameters:

Name Type Description Default
pred_status PredictionStatus

Prediction status object that when set to STOPPED, signals the prediction to stop.

required
Source code in src/careamics_napari/careamics_utils/callbacks.py
class StopPredictionCallback(Callback):
    """PyTorch Lightning callback that interrupts prediction on request.

    Before every prediction batch the callback inspects the state of a
    PredictionStatus object. Once that state equals
    ``PredictionState.STOPPED``, the trainer is flagged to stop and a
    ``PredictionStoppedException`` is raised.

    Parameters
    ----------
    pred_status : PredictionStatus
        Prediction status object whose STOPPED state signals that the
        prediction must stop.
    """

    def __init__(self, pred_status: PredictionStatus) -> None:
        """Store the status object that is checked for a stop request.

        Parameters
        ----------
        pred_status : PredictionStatus
            Prediction status object whose STOPPED state signals that the
            prediction must stop.
        """
        super().__init__()
        self.pred_status = pred_status

    def on_predict_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Abort the prediction if a stop has been requested.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer; its ``should_stop`` flag is set on stop.
        pl_module : LightningModule
            Lightning module being used (unused).
        batch : Any
            Current batch of data (unused).
        batch_idx : int
            Index of the current batch (unused).
        dataloader_idx : int, default=0
            Index of the current dataloader (unused).

        Raises
        ------
        PredictionStoppedException
            If the prediction status state is ``PredictionState.STOPPED``.
        """
        stop_requested = self.pred_status.state == PredictionState.STOPPED
        if not stop_requested:
            return

        print("Stop signal received, stopping prediction...")
        trainer.should_stop = True
        # During prediction the flag alone is not sufficient: an exception has
        # to be raised to actually stop.
        raise PredictionStoppedException("Prediction stopped by user")

__init__(pred_status) #

Initialize the callback.

Parameters:

Name Type Description Default
pred_status PredictionStatus

Prediction status object that when set to STOPPED, signals the prediction to stop.

required
Source code in src/careamics_napari/careamics_utils/callbacks.py
def __init__(self, pred_status: PredictionStatus) -> None:
    """Store the status object that is checked for a stop request.

    Parameters
    ----------
    pred_status : PredictionStatus
        Prediction status object whose STOPPED state signals that the
        prediction must stop.
    """
    super().__init__()
    self.pred_status = pred_status

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0) #

Check for stop signal at the start of each prediction batch.

Parameters:

Name Type Description Default
trainer Trainer

The PyTorch Lightning trainer.

required
pl_module LightningModule

The Lightning module being used.

required
batch Any

The current batch of data.

required
batch_idx int

Index of the current batch.

required
dataloader_idx int

Index of the current dataloader, by default 0.

0
Source code in src/careamics_napari/careamics_utils/callbacks.py
def on_predict_batch_start(
    self,
    trainer: Trainer,
    pl_module: LightningModule,
    batch: Any,
    batch_idx: int,
    dataloader_idx: int = 0,
) -> None:
    """Abort the prediction if a stop has been requested.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning trainer; its ``should_stop`` flag is set on stop.
    pl_module : LightningModule
        Lightning module being used (unused).
    batch : Any
        Current batch of data (unused).
    batch_idx : int
        Index of the current batch (unused).
    dataloader_idx : int, default=0
        Index of the current dataloader (unused).

    Raises
    ------
    PredictionStoppedException
        If the prediction status state is ``PredictionState.STOPPED``.
    """
    stop_requested = self.pred_status.state == PredictionState.STOPPED
    if not stop_requested:
        return

    print("Stop signal received, stopping prediction...")
    trainer.should_stop = True
    # During prediction the flag alone is not sufficient: an exception has to
    # be raised to actually stop.
    raise PredictionStoppedException("Prediction stopped by user")

UpdaterCallBack #

Bases: Callback

PyTorch Lightning callback for updating training and prediction UI states.

Parameters:

Name Type Description Default
training_queue Queue

Training queue used to pass updates between threads.

required
prediction_queue Queue

Prediction queue used to pass updates between threads.

required

Attributes:

Name Type Description
training_queue Queue

Training queue used to pass updates between threads.

prediction_queue Queue

Prediction queue used to pass updates between threads.

Source code in src/careamics_napari/careamics_utils/callbacks.py
class UpdaterCallBack(Callback):
    """PyTorch Lightning callback for updating training and prediction UI states.

    Progress is reported by putting ``TrainUpdate`` objects on the training
    queue and ``PredictionUpdate`` objects on the prediction queue.

    Parameters
    ----------
    training_queue : Queue
        Training queue used to pass updates between threads.
    prediction_queue : Queue
        Prediction queue used to pass updates between threads.

    Attributes
    ----------
    training_queue : Queue
        Training queue used to pass updates between threads.
    prediction_queue : Queue
        Prediction queue used to pass updates between threads.
    """

    def __init__(self, training_queue: Queue, prediction_queue: Queue) -> None:
        """Initialize the callback.

        Parameters
        ----------
        training_queue : Queue
            Training queue used to pass updates between threads.
        prediction_queue : Queue
            Prediction queue used to pass updates between threads.
        """
        # Initialize the base class, as StopPredictionCallback does.
        super().__init__()
        self.training_queue = training_queue
        self.prediction_queue = prediction_queue

    def get_train_queue(self) -> Queue:
        """Return the training queue.

        Returns
        -------
        Queue
            Training queue.
        """
        return self.training_queue

    def get_predict_queue(self) -> Queue:
        """Return the prediction queue.

        Returns
        -------
        Queue
            Prediction queue.
        """
        return self.prediction_queue

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Publish the maximum batch and epoch counts when training starts.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        """
        # compute the number of batches
        len_dataloader = len(trainer.train_dataloader)  # type: ignore

        # Floor division keeps the computation in integers instead of going
        # through a float and truncating it.
        self.training_queue.put(
            TrainUpdate(
                TrainUpdateType.MAX_BATCH,
                len_dataloader // trainer.accumulate_grad_batches,
            )
        )

        # register number of epochs
        self.training_queue.put(
            TrainUpdate(TrainUpdateType.MAX_EPOCH, trainer.max_epochs)
        )

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Publish the index of the epoch that is starting.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        """
        self.training_queue.put(TrainUpdate(TrainUpdateType.EPOCH, trainer.current_epoch))

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Publish the epoch losses found in the progress bar metrics.

        Only the metrics that are present are published: ``train_loss_epoch``
        as a LOSS update, then ``val_loss`` as a VAL_LOSS update.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        """
        metrics = trainer.progress_bar_metrics

        if "train_loss_epoch" in metrics:
            self.training_queue.put(
                TrainUpdate(TrainUpdateType.LOSS, metrics["train_loss_epoch"])
            )

        if "val_loss" in metrics:
            self.training_queue.put(
                TrainUpdate(TrainUpdateType.VAL_LOSS, metrics["val_loss"])
            )

    def on_train_batch_start(
        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
    ) -> None:
        """Publish the index of the training batch that is starting.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        batch : Any
            Batch.
        batch_idx : int
            Index of the batch.
        """
        self.training_queue.put(TrainUpdate(TrainUpdateType.BATCH, batch_idx))

    def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Publish the number of prediction batches of the first dataloader.

        The string ``"?"`` is published instead when that number is infinite.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        """
        # lightning returns a number of batches per dataloader
        # if data is loading from disk, the IterableDataset length is not defined.
        n_batches = trainer.num_predict_batches[0]
        if n_batches == np.inf:
            n_batches = "?"
        else:
            n_batches = int(n_batches)

        self.prediction_queue.put(
            PredictionUpdate(
                PredictionUpdateType.MAX_SAMPLES,
                n_batches,
            )
        )

    def on_predict_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Publish the index of the prediction batch that is starting.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        batch : Any
            Batch.
        batch_idx : int
            Index of the batch.
        dataloader_idx : int, default=0
            Index of the dataloader.
        """
        self.prediction_queue.put(
            PredictionUpdate(PredictionUpdateType.SAMPLE_IDX, batch_idx)
        )

__init__(training_queue, prediction_queue) #

Initialize the callback.

Parameters:

Name Type Description Default
training_queue Queue

Training queue used to pass updates between threads.

required
prediction_queue Queue

Prediction queue used to pass updates between threads.

required
Source code in src/careamics_napari/careamics_utils/callbacks.py
def __init__(self, training_queue: Queue, prediction_queue: Queue) -> None:
    """Initialize the callback.

    Parameters
    ----------
    training_queue : Queue
        Training queue used to pass updates between threads.
    prediction_queue : Queue
        Prediction queue used to pass updates between threads.
    """
    # Initialize the base class, as StopPredictionCallback does.
    super().__init__()
    self.training_queue = training_queue
    self.prediction_queue = prediction_queue

get_predict_queue() #

Return the prediction queue.

Returns:

Type Description
Queue

Prediction queue.

Source code in src/careamics_napari/careamics_utils/callbacks.py
def get_predict_queue(self) -> Queue:
    """Give access to the queue carrying prediction updates.

    Returns
    -------
    Queue
        The ``prediction_queue`` attribute of this callback.
    """
    return self.prediction_queue

get_train_queue() #

Return the training queue.

Returns:

Type Description
Queue

Training queue.

Source code in src/careamics_napari/careamics_utils/callbacks.py
def get_train_queue(self) -> Queue:
    """Give access to the queue carrying training updates.

    Returns
    -------
    Queue
        The ``training_queue`` attribute of this callback.
    """
    return self.training_queue

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0) #

Method called at the beginning of each prediction batch.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer.

required
pl_module LightningModule

PyTorch Lightning module.

required
batch Any

Batch.

required
batch_idx int

Index of the batch.

required
dataloader_idx int

Index of the dataloader.

0
Source code in src/careamics_napari/careamics_utils/callbacks.py
def on_predict_batch_start(
    self,
    trainer: Trainer,
    pl_module: LightningModule,
    batch: Any,
    batch_idx: int,
    dataloader_idx: int = 0,
) -> None:
    """Publish the index of the prediction batch that is starting.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning trainer.
    pl_module : LightningModule
        PyTorch Lightning module.
    batch : Any
        Batch.
    batch_idx : int
        Index of the batch, sent as a SAMPLE_IDX update.
    dataloader_idx : int, default=0
        Index of the dataloader.
    """
    sample_update = PredictionUpdate(PredictionUpdateType.SAMPLE_IDX, batch_idx)
    self.prediction_queue.put(sample_update)

on_predict_start(trainer, pl_module) #

Method called at the beginning of the prediction.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer.

required
pl_module LightningModule

PyTorch Lightning module.

required
Source code in src/careamics_napari/careamics_utils/callbacks.py
def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Publish the number of prediction batches of the first dataloader.

    The string ``"?"`` is published instead when that number is infinite.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning trainer.
    pl_module : LightningModule
        PyTorch Lightning module.
    """
    # Lightning reports one batch count per dataloader; only the first one is
    # used here. When data is loaded from disk, the IterableDataset length is
    # not defined.
    first_loader_batches = trainer.num_predict_batches[0]
    if first_loader_batches != np.inf:
        max_samples = int(first_loader_batches)
    else:
        max_samples = "?"

    max_update = PredictionUpdate(PredictionUpdateType.MAX_SAMPLES, max_samples)
    self.prediction_queue.put(max_update)

on_train_batch_start(trainer, pl_module, batch, batch_idx) #

Method called at the beginning of each batch.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer.

required
pl_module LightningModule

PyTorch Lightning module.

required
batch Any

Batch.

required
batch_idx int

Index of the batch.

required
Source code in src/careamics_napari/careamics_utils/callbacks.py
def on_train_batch_start(
    self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
) -> None:
    """Publish the index of the training batch that is starting.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning trainer.
    pl_module : LightningModule
        PyTorch Lightning module.
    batch : Any
        Batch.
    batch_idx : int
        Index of the batch, sent as a BATCH update.
    """
    batch_update = TrainUpdate(TrainUpdateType.BATCH, batch_idx)
    self.training_queue.put(batch_update)

on_train_epoch_end(trainer, pl_module) #

Method called at the end of each epoch.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer.

required
pl_module LightningModule

PyTorch Lightning module.

required
Source code in src/careamics_napari/careamics_utils/callbacks.py
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Publish the epoch losses found in the progress bar metrics.

    Only the metrics that are present are published: ``train_loss_epoch`` as
    a LOSS update, then ``val_loss`` as a VAL_LOSS update.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning trainer.
    pl_module : LightningModule
        PyTorch Lightning module.
    """
    reported = trainer.progress_bar_metrics

    # Metric name -> update type, in the order the updates are queued.
    loss_updates = (
        ("train_loss_epoch", TrainUpdateType.LOSS),
        ("val_loss", TrainUpdateType.VAL_LOSS),
    )
    for metric_name, update_type in loss_updates:
        if metric_name in reported:
            self.training_queue.put(TrainUpdate(update_type, reported[metric_name]))

on_train_epoch_start(trainer, pl_module) #

Method called at the beginning of each epoch.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer.

required
pl_module LightningModule

PyTorch Lightning module.

required
Source code in src/careamics_napari/careamics_utils/callbacks.py
def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Publish the index of the epoch that is starting.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning trainer; its ``current_epoch`` is sent as an EPOCH
        update.
    pl_module : LightningModule
        PyTorch Lightning module.
    """
    epoch_update = TrainUpdate(TrainUpdateType.EPOCH, trainer.current_epoch)
    self.training_queue.put(epoch_update)

on_train_start(trainer, pl_module) #

Method called at the beginning of the training.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer.

required
pl_module LightningModule

PyTorch Lightning module.

required
Source code in src/careamics_napari/careamics_utils/callbacks.py
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Publish the maximum batch and epoch counts when training starts.

    Two updates are queued: MAX_BATCH, the length of the training dataloader
    divided by ``trainer.accumulate_grad_batches`` (rounded down), followed by
    MAX_EPOCH, the value of ``trainer.max_epochs``.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning trainer.
    pl_module : LightningModule
        PyTorch Lightning module.
    """
    # compute the number of batches
    len_dataloader = len(trainer.train_dataloader)  # type: ignore

    # Floor division keeps the computation in integers instead of going
    # through a float and truncating it.
    self.training_queue.put(
        TrainUpdate(
            TrainUpdateType.MAX_BATCH,
            len_dataloader // trainer.accumulate_grad_batches,
        )
    )

    # register number of epochs
    self.training_queue.put(
        TrainUpdate(TrainUpdateType.MAX_EPOCH, trainer.max_epochs)
    )