Skip to content

train_progress_widget

A widget displaying the training progress using two progress bars.

TrainProgressWidget #

Bases: QGroupBox

A widget displaying the training progress using two progress bars.

Parameters:

Name Type Description Default
train_status TrainingStatus or None

Signal representing the training status.

None
train_config TrainingSignal or None

Signal representing the training parameters.

None
Source code in src/careamics_napari/widgets/train_progress_widget.py
class TrainProgressWidget(QGroupBox):
    """A widget displaying the training progress using two progress bars.

    Parameters
    ----------
    train_status : TrainingStatus or None, default=None
        Signal representing the training status.
    train_config : TrainingSignal or None, default=None
        Signal representing the training parameters.
    """

    def __init__(
        self: Self,
        train_status: Optional[TrainingStatus] = None,
        train_config: Optional[TrainingSignal] = None,
    ) -> None:
        """Initialize the widget.

        Parameters
        ----------
        train_status : TrainingStatus or None, default=None
            Signal representing the training status.
        train_config : TrainingSignal or None, default=None
            Signal representing the training parameters.
        """
        super().__init__()

        self.train_status = (
            train_status
            if train_status is not None  # for typing purposes
            else TrainingStatus()  # type: ignore
        )

        self.setTitle("Training progress")
        self.setLayout(QVBoxLayout())

        # progress bars
        self.layout().setContentsMargins(20, 20, 20, 0)

        self.pb_epochs = create_progressbar(
            max_value=self.train_status.max_epochs,
            text_format=f"Epoch ?/{self.train_status.max_epochs}",
            value=0,
        )

        self.pb_batch = create_progressbar(
            max_value=self.train_status.max_batches,
            text_format=f"Batch ?/{self.train_status.max_batches}",
            value=0,
        )

        self.layout().addWidget(self.pb_epochs)
        self.layout().addWidget(self.pb_batch)

        # plot widget
        self.plot = TBPlotWidget(
            max_width=300, max_height=300, min_height=250, train_signal=train_config
        )
        self.layout().addWidget(self.plot.native)

        # actions
        if self.train_status is not None:
            self.train_status.events.state.connect(self._update_training_state)

            self.train_status.events.epoch_idx.connect(self._update_epoch)
            self.train_status.events.max_epochs.connect(self._update_max_epoch)
            self.train_status.events.batch_idx.connect(self._update_batch)
            self.train_status.events.max_batches.connect(self._update_max_batch)
            self.train_status.events.val_loss.connect(self._update_loss)

    def _update_training_state(self: Self, state: TrainingState) -> None:
        """Update the widget according to the training state.

        Parameters
        ----------
        state : TrainingState
            Training state.
        """
        if state == TrainingState.IDLE or state == TrainingState.TRAINING:
            self.plot.clear_plot()

    def _update_max_epoch(self: Self, max_epoch: int):
        """Update the maximum number of epochs in the progress bar.

        Parameters
        ----------
        max_epoch : int
            Maximum number of epochs.
        """
        self.pb_epochs.setMaximum(max_epoch)

    def _update_epoch(self: Self, epoch: int) -> None:
        """Update the epoch progress bar.

        Parameters
        ----------
        epoch : int
            Current epoch.
        """
        self.pb_epochs.setValue(epoch + 1)
        self.pb_epochs.setFormat(f"Epoch {epoch+1}/{self.train_status.max_epochs}")

    def _update_max_batch(self: Self, max_batches: int) -> None:
        """Update the maximum number of batches in the progress bar.

        Parameters
        ----------
        max_batches : int
            Maximum number of batches.
        """
        self.pb_batch.setMaximum(max_batches)

    def _update_batch(self: Self) -> None:
        """Update the batch progress bar."""
        self.pb_batch.setValue(self.train_status.batch_idx + 1)
        self.pb_batch.setFormat(
            f"Batch {self.train_status.batch_idx+1}/{self.train_status.max_batches}"
        )

    def _update_loss(self: Self) -> None:
        """Update the loss plot."""
        self.plot.update_plot(
            epoch=self.train_status.epoch_idx,
            train_loss=self.train_status.loss,
            val_loss=self.train_status.val_loss,
        )

__init__(train_status=None, train_config=None) #

Initialize the widget.

Parameters:

Name Type Description Default
train_status TrainingStatus or None

Signal representing the training status.

None
train_config TrainingSignal or None

Signal representing the training parameters.

None
Source code in src/careamics_napari/widgets/train_progress_widget.py
def __init__(
    self: Self,
    train_status: Optional[TrainingStatus] = None,
    train_config: Optional[TrainingSignal] = None,
) -> None:
    """Initialize the widget.

    Parameters
    ----------
    train_status : TrainingStatus or None, default=None
        Signal representing the training status.
    train_config : TrainingSignal or None, default=None
        Signal representing the training parameters.
    """
    super().__init__()

    self.train_status = (
        train_status
        if train_status is not None  # for typing purposes
        else TrainingStatus()  # type: ignore
    )

    self.setTitle("Training progress")
    self.setLayout(QVBoxLayout())

    # progress bars
    self.layout().setContentsMargins(20, 20, 20, 0)

    self.pb_epochs = create_progressbar(
        max_value=self.train_status.max_epochs,
        text_format=f"Epoch ?/{self.train_status.max_epochs}",
        value=0,
    )

    self.pb_batch = create_progressbar(
        max_value=self.train_status.max_batches,
        text_format=f"Batch ?/{self.train_status.max_batches}",
        value=0,
    )

    self.layout().addWidget(self.pb_epochs)
    self.layout().addWidget(self.pb_batch)

    # plot widget
    self.plot = TBPlotWidget(
        max_width=300, max_height=300, min_height=250, train_signal=train_config
    )
    self.layout().addWidget(self.plot.native)

    # actions
    if self.train_status is not None:
        self.train_status.events.state.connect(self._update_training_state)

        self.train_status.events.epoch_idx.connect(self._update_epoch)
        self.train_status.events.max_epochs.connect(self._update_max_epoch)
        self.train_status.events.batch_idx.connect(self._update_batch)
        self.train_status.events.max_batches.connect(self._update_max_batch)
        self.train_status.events.val_loss.connect(self._update_loss)