Skip to content

training_widget

Training widget.

TrainingWidget #

Bases: QGroupBox

Training widget.

Parameters:

Name Type Description Default
train_status TrainingStatus or None

Training status.

None
Source code in src/careamics_napari/widgets/training_widget.py
class TrainingWidget(QGroupBox):
    """Training widget.

    Parameters
    ----------
    train_status : TrainingStatus or None, default=None
        Training status.
    """

    def __init__(self, train_status: TrainingStatus | None = None) -> None:
        """Initialize the widget.

        Parameters
        ----------
        train_status : TrainingStatus or None, default=None
            Training status.
        """
        super().__init__()

        self.train_status = (
            TrainingStatus() if train_status is None else train_status  # type: ignore
        )

        # TODO add val percentage and val minimum ?

        # train button
        self.train_button = QPushButton("Train", self)
        self.train_button.setMinimumWidth(120)
        self.train_button.clicked.connect(self._train_stop_clicked)

        # reset button
        self.reset_model_button = QPushButton("Reset", self)
        self.reset_model_button.setMinimumWidth(120)
        self.reset_model_button.setEnabled(False)
        self.reset_model_button.setToolTip(
            "Reset the weights of the model (forget the training)"
        )
        self.reset_model_button.clicked.connect(self._reset_clicked)

        hbox = QHBoxLayout()
        hbox.addWidget(self.train_button, alignment=Qt.AlignLeft)  # type: ignore
        hbox.addWidget(self.reset_model_button, alignment=Qt.AlignLeft)  # type: ignore

        vbox = QVBoxLayout()
        vbox.addLayout(hbox)
        self.setLayout(vbox)
        self.setTitle("Train")

        # listening to the training status events
        self.train_status.events.state.connect(self._update_button)

    def _train_stop_clicked(self) -> None:
        """Update the UI and training status when the train button is clicked."""
        if self.train_status is not None:
            if (
                self.train_status.state == TrainingState.IDLE
                or self.train_status.state == TrainingState.DONE
            ):
                # important to do it before state change
                self.train_button.setText("Stop")
                self.reset_model_button.setEnabled(False)
                self.train_status.state = TrainingState.TRAINING

            elif self.train_status.state == TrainingState.TRAINING:
                self.train_button.setText("Train")
                self.reset_model_button.setEnabled(True)
                self.train_status.state = TrainingState.STOPPED

            elif self.train_status.state == TrainingState.STOPPED:
                self.train_button.setText("Stop")
                self.train_status.state = TrainingState.TRAINING

    def _reset_clicked(self) -> None:
        """Update the UI and training status when the reset button is clicked."""
        if self.train_status is not None:
            if self.train_status.state != TrainingState.TRAINING:
                self.train_status.state = TrainingState.IDLE
                self.train_button.setText("Train")
                self.reset_model_button.setEnabled(False)

    def _update_button(self, new_state: TrainingState) -> None:
        """Update the button text based on the training state.

        Parameters
        ----------
        new_state : TrainingState
            New training state.
        """
        if new_state == TrainingState.DONE or new_state == TrainingState.STOPPED:
            self.train_button.setText("Train")
            self.reset_model_button.setEnabled(True)
        elif new_state == TrainingState.CRASHED:
            self._reset_clicked()

__init__(train_status=None) #

Initialize the widget.

Parameters:

Name Type Description Default
train_status TrainingStatus or None

Training status.

None
Source code in src/careamics_napari/widgets/training_widget.py
def __init__(self, train_status: TrainingStatus | None = None) -> None:
    """Initialize the widget.

    Parameters
    ----------
    train_status : TrainingStatus or None, default=None
        Training status.
    """
    super().__init__()

    self.train_status = (
        TrainingStatus() if train_status is None else train_status  # type: ignore
    )

    # TODO add val percentage and val minimum ?

    # train button
    self.train_button = QPushButton("Train", self)
    self.train_button.setMinimumWidth(120)
    self.train_button.clicked.connect(self._train_stop_clicked)

    # reset button
    self.reset_model_button = QPushButton("Reset", self)
    self.reset_model_button.setMinimumWidth(120)
    self.reset_model_button.setEnabled(False)
    self.reset_model_button.setToolTip(
        "Reset the weights of the model (forget the training)"
    )
    self.reset_model_button.clicked.connect(self._reset_clicked)

    hbox = QHBoxLayout()
    hbox.addWidget(self.train_button, alignment=Qt.AlignLeft)  # type: ignore
    hbox.addWidget(self.reset_model_button, alignment=Qt.AlignLeft)  # type: ignore

    vbox = QVBoxLayout()
    vbox.addLayout(hbox)
    self.setLayout(vbox)
    self.setTitle("Train")

    # listening to the training status events
    self.train_status.events.state.connect(self._update_button)