Skip to content

callback

PyTorch Lightning callback used to update the GUI with training and prediction progress.

UpdaterCallBack #

Bases: Callback

PyTorch Lightning callback for updating training and prediction UI states.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `training_queue` | `Queue` | Training queue used to pass updates between threads. | required |
| `prediction_queue` | `Queue` | Prediction queue used to pass updates between threads. | required |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `training_queue` | `Queue` | Training queue used to pass updates between threads. |
| `prediction_queue` | `Queue` | Prediction queue used to pass updates between threads. |

Source code in src/careamics_napari/careamics_utils/callback.py
class UpdaterCallBack(Callback):
    """PyTorch Lightning callback for updating training and prediction UI states.

    Progress events are pushed onto thread-safe queues so that the GUI thread
    can consume them while training or prediction runs in another thread.

    Parameters
    ----------
    training_queue : Queue
        Training queue used to pass updates between threads.
    prediction_queue : Queue
        Prediction queue used to pass updates between threads.

    Attributes
    ----------
    training_queue : Queue
        Training queue used to pass updates between threads.
    prediction_queue : Queue
        Prediction queue used to pass updates between threads.
    """

    def __init__(self: Self, training_queue: Queue, prediction_queue: Queue) -> None:
        """Initialize the callback.

        Parameters
        ----------
        training_queue : Queue
            Training queue used to pass updates between threads.
        prediction_queue : Queue
            Prediction queue used to pass updates between threads.
        """
        super().__init__()
        # TODO: the training queue should be optional in case of prediction only
        self.training_queue = training_queue
        self.prediction_queue = prediction_queue

    def get_train_queue(self) -> Queue:
        """Return the training queue.

        Returns
        -------
        Queue
            Training queue.
        """
        return self.training_queue

    def get_predict_queue(self) -> Queue:
        """Return the prediction queue.

        Returns
        -------
        Queue
            Prediction queue.
        """
        return self.prediction_queue

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Report the number of batches per epoch and the number of epochs.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        """
        # The maximum must be the raw dataloader length. `on_train_batch_start`
        # reports Lightning's raw `batch_idx`, which counts dataloader batches
        # and is not divided by `accumulate_grad_batches`. Dividing here would
        # make the reported index exceed the maximum whenever gradient
        # accumulation is enabled.
        # NOTE(review): assumes the train dataloader defines `__len__`.
        len_dataloader = len(trainer.train_dataloader)  # type: ignore

        self.training_queue.put(
            TrainUpdate(TrainUpdateType.MAX_BATCH, len_dataloader)
        )

        # register number of epochs
        self.training_queue.put(
            TrainUpdate(TrainUpdateType.MAX_EPOCH, trainer.max_epochs)
        )

    def on_train_epoch_start(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        """Report the index of the epoch that is starting.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        """
        self.training_queue.put(
            TrainUpdate(TrainUpdateType.EPOCH, trainer.current_epoch)
        )

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Report the epoch training and validation losses, when they were logged.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        """
        metrics = trainer.progress_bar_metrics

        if "train_loss_epoch" in metrics:
            self.training_queue.put(
                TrainUpdate(TrainUpdateType.LOSS, metrics["train_loss_epoch"])
            )

        if "val_loss" in metrics:
            self.training_queue.put(
                TrainUpdate(TrainUpdateType.VAL_LOSS, metrics["val_loss"])
            )

    def on_train_batch_start(
        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
    ) -> None:
        """Report the index of the training batch about to be processed.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        batch : Any
            Batch.
        batch_idx : int
            Index of the batch.
        """
        self.training_queue.put(TrainUpdate(TrainUpdateType.BATCH, batch_idx))

    def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Report the total number of prediction batches.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        """
        self.prediction_queue.put(
            PredictionUpdate(
                PredictionUpdateType.MAX_SAMPLES,
                # lightning returns a number of batches per dataloader
                trainer.num_predict_batches[0],
            )
        )

    def on_predict_batch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Report the index of the prediction batch about to be processed.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer.
        pl_module : LightningModule
            PyTorch Lightning module.
        batch : Any
            Batch.
        batch_idx : int
            Index of the batch.
        dataloader_idx : int, default=0
            Index of the dataloader.
        """
        self.prediction_queue.put(
            PredictionUpdate(PredictionUpdateType.SAMPLE_IDX, batch_idx)
        )

__init__(training_queue, prediction_queue) #

Initialize the callback.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `training_queue` | `Queue` | Training queue used to pass updates between threads. | required |
| `prediction_queue` | `Queue` | Prediction queue used to pass updates between threads. | required |
Source code in src/careamics_napari/careamics_utils/callback.py
def __init__(self: Self, training_queue: Queue, prediction_queue: Queue) -> None:
    """Initialize the callback.

    Parameters
    ----------
    training_queue : Queue
        Training queue used to pass updates between threads.
    prediction_queue : Queue
        Prediction queue used to pass updates between threads.
    """
    # TODO: the training queue should be optional in case of prediction only
    self.training_queue = training_queue
    self.prediction_queue = prediction_queue

get_predict_queue() #

Return the prediction queue.

Returns:

| Type | Description |
| --- | --- |
| `Queue` | Prediction queue. |

Source code in src/careamics_napari/careamics_utils/callback.py
def get_predict_queue(self) -> Queue:
    """Give access to the queue carrying prediction progress updates.

    Returns
    -------
    Queue
        The object stored in ``self.prediction_queue``.
    """
    predict_queue = self.prediction_queue
    return predict_queue

get_train_queue() #

Return the training queue.

Returns:

| Type | Description |
| --- | --- |
| `Queue` | Training queue. |

Source code in src/careamics_napari/careamics_utils/callback.py
def get_train_queue(self) -> Queue:
    """Give access to the queue carrying training progress updates.

    Returns
    -------
    Queue
        The object stored in ``self.training_queue``.
    """
    train_queue = self.training_queue
    return train_queue

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0) #

Method called at the beginning of each prediction batch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trainer` | `Trainer` | PyTorch Lightning trainer. | required |
| `pl_module` | `LightningModule` | PyTorch Lightning module. | required |
| `batch` | `Any` | Batch. | required |
| `batch_idx` | `int` | Index of the batch. | required |
| `dataloader_idx` | `int` | Index of the dataloader. | `0` |
Source code in src/careamics_napari/careamics_utils/callback.py
def on_predict_batch_start(
    self,
    trainer: Trainer,
    pl_module: LightningModule,
    batch: Any,
    batch_idx: int,
    dataloader_idx: int = 0,
) -> None:
    """Report the index of the prediction batch about to be processed.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning trainer (not used).
    pl_module : LightningModule
        PyTorch Lightning module (not used).
    batch : Any
        Current batch (not used).
    batch_idx : int
        Index of the batch, forwarded to the prediction queue.
    dataloader_idx : int, default=0
        Index of the dataloader (not used).
    """
    sample_update = PredictionUpdate(PredictionUpdateType.SAMPLE_IDX, batch_idx)
    self.prediction_queue.put(sample_update)

on_predict_start(trainer, pl_module) #

Method called at the beginning of the prediction.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trainer` | `Trainer` | PyTorch Lightning trainer. | required |
| `pl_module` | `LightningModule` | PyTorch Lightning module. | required |
Source code in src/careamics_napari/careamics_utils/callback.py
def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Report the total number of prediction batches before prediction begins.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning trainer, queried for its batch count.
    pl_module : LightningModule
        PyTorch Lightning module (not used).
    """
    # Lightning exposes one batch count per predict dataloader; only the
    # first dataloader's count is reported.
    n_batches = trainer.num_predict_batches[0]
    max_update = PredictionUpdate(PredictionUpdateType.MAX_SAMPLES, n_batches)
    self.prediction_queue.put(max_update)

on_train_batch_start(trainer, pl_module, batch, batch_idx) #

Method called at the beginning of each batch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trainer` | `Trainer` | PyTorch Lightning trainer. | required |
| `pl_module` | `LightningModule` | PyTorch Lightning module. | required |
| `batch` | `Any` | Batch. | required |
| `batch_idx` | `int` | Index of the batch. | required |
Source code in src/careamics_napari/careamics_utils/callback.py
def on_train_batch_start(
    self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
) -> None:
    """Report the index of the training batch about to be processed.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning trainer (not used).
    pl_module : LightningModule
        PyTorch Lightning module (not used).
    batch : Any
        Current batch (not used).
    batch_idx : int
        Index of the batch, forwarded to the training queue.
    """
    batch_update = TrainUpdate(TrainUpdateType.BATCH, batch_idx)
    self.training_queue.put(batch_update)

on_train_epoch_end(trainer, pl_module) #

Method called at the end of each epoch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trainer` | `Trainer` | PyTorch Lightning trainer. | required |
| `pl_module` | `LightningModule` | PyTorch Lightning module. | required |
Source code in src/careamics_napari/careamics_utils/callback.py
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Forward the epoch's training and validation losses, when they were logged.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning trainer, whose progress-bar metrics are read.
    pl_module : LightningModule
        PyTorch Lightning module (not used).
    """
    logged = trainer.progress_bar_metrics

    # (metric key, update type) pairs; each present metric is sent in this order.
    for metric_name, update_type in (
        ("train_loss_epoch", TrainUpdateType.LOSS),
        ("val_loss", TrainUpdateType.VAL_LOSS),
    ):
        if metric_name not in logged:
            continue
        self.training_queue.put(TrainUpdate(update_type, logged[metric_name]))

on_train_epoch_start(trainer, pl_module) #

Method called at the beginning of each epoch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trainer` | `Trainer` | PyTorch Lightning trainer. | required |
| `pl_module` | `LightningModule` | PyTorch Lightning module. | required |
Source code in src/careamics_napari/careamics_utils/callback.py
def on_train_epoch_start(
    self, trainer: Trainer, pl_module: LightningModule
) -> None:
    """Report the index of the epoch that is starting.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning trainer, queried for its current epoch.
    pl_module : LightningModule
        PyTorch Lightning module (not used).
    """
    epoch_update = TrainUpdate(TrainUpdateType.EPOCH, trainer.current_epoch)
    self.training_queue.put(epoch_update)

on_train_start(trainer, pl_module) #

Method called at the beginning of the training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trainer` | `Trainer` | PyTorch Lightning trainer. | required |
| `pl_module` | `LightningModule` | PyTorch Lightning module. | required |
Source code in src/careamics_napari/careamics_utils/callback.py
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Report the number of batches per epoch and the number of epochs.

    Parameters
    ----------
    trainer : Trainer
        PyTorch Lightning trainer.
    pl_module : LightningModule
        PyTorch Lightning module.
    """
    # The maximum must be the raw dataloader length. `on_train_batch_start`
    # reports Lightning's raw `batch_idx`, which counts dataloader batches and
    # is not divided by `accumulate_grad_batches`. Dividing here would make the
    # reported index exceed the maximum whenever gradient accumulation is
    # enabled.
    # NOTE(review): assumes the train dataloader defines `__len__`.
    len_dataloader = len(trainer.train_dataloader)  # type: ignore

    self.training_queue.put(
        TrainUpdate(TrainUpdateType.MAX_BATCH, len_dataloader)
    )

    # register number of epochs
    self.training_queue.put(
        TrainUpdate(TrainUpdateType.MAX_EPOCH, trainer.max_epochs)
    )