Skip to content

training_widget

Training widget.

TrainingWidget

Bases: QGroupBox

Training widget.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `train_status` | `TrainingStatus` or `None` | Training status. | `None` |
Source code in src/careamics_napari/widgets/training_widget.py
class TrainingWidget(QGroupBox):
    """Training widget.

    Shows a "Train" button, which starts training (or stops it while it is
    running), and a "Reset" button, which returns the training status to idle.
    The button labels and enabled states follow `TrainingStatus.state`.

    Parameters
    ----------
    train_status : TrainingStatus or None, default=None
        Training status. If `None`, a new `TrainingStatus` is created.
    """

    def __init__(self: Self, train_status: Optional[TrainingStatus] = None) -> None:
        """Initialize the widget.

        Parameters
        ----------
        train_status : TrainingStatus or None, default=None
            Training status. If `None`, a new `TrainingStatus` is created.
        """
        # required to initialize the underlying Qt object (no parent is passed)
        super().__init__()

        self.train_status = (
            TrainingStatus() if train_status is None else train_status  # type: ignore
        )

        self.setTitle("Train")
        self.setLayout(QVBoxLayout())

        # train buttons
        train_buttons = QWidget()
        train_buttons.setLayout(QHBoxLayout())

        self.train_button = QPushButton("Train", self)
        self.train_button.setMinimumWidth(120)

        self.reset_model_button = QPushButton("Reset", self)
        self.reset_model_button.setMinimumWidth(120)
        self.reset_model_button.setEnabled(False)
        self.reset_model_button.setToolTip(
            "Reset the weights of the model (forget the training)"
        )

        train_buttons.layout().addWidget(self.train_button, alignment=Qt.AlignLeft)
        train_buttons.layout().addWidget(self.reset_model_button, alignment=Qt.AlignLeft)
        self.layout().addWidget(train_buttons)

        # actions: `self.train_status` is always set at this point, so the
        # connections are made unconditionally
        self.train_button.clicked.connect(self._train_stop_clicked)
        self.reset_model_button.clicked.connect(self._reset_clicked)

        # keep the buttons in sync with state changes, including those made
        # outside of this widget
        self.train_status.events.state.connect(self._update_button)

        # a status passed in may not be idle; reflect its current state
        self._apply_state(self.train_status.state)

    def _apply_state(self, state: TrainingState) -> None:
        """Set the button labels and enabled states to match `state`.

        States other than TRAINING, DONE, STOPPED and IDLE leave the buttons
        unchanged.

        Parameters
        ----------
        state : TrainingState
            Training state to display.
        """
        if state == TrainingState.TRAINING:
            self.train_button.setText("Stop")
            # the model must not be reset while it is being trained
            self.reset_model_button.setEnabled(False)
            self.reset_model_button.setText("")
        elif state in (TrainingState.DONE, TrainingState.STOPPED):
            self.train_button.setText("Train")
            self.reset_model_button.setEnabled(True)
            self.reset_model_button.setText("Reset")
        elif state == TrainingState.IDLE:
            self.train_button.setText("Train")
            # nothing to reset yet, but keep the label visible
            self.reset_model_button.setEnabled(False)
            self.reset_model_button.setText("Reset")

    def _train_stop_clicked(self) -> None:
        """Start or stop training when the train button is clicked.

        From IDLE, DONE or STOPPED, training is (re)started. While TRAINING,
        training is stopped. Any other state is left untouched.
        """
        if self.train_status is None:
            return

        current = self.train_status.state
        if current in (TrainingState.IDLE, TrainingState.DONE, TrainingState.STOPPED):
            new_state = TrainingState.TRAINING
        elif current == TrainingState.TRAINING:
            new_state = TrainingState.STOPPED
        else:
            return

        self.train_status.state = new_state
        self._apply_state(new_state)

    def _reset_clicked(self) -> None:
        """Return the training status to IDLE when the reset button is clicked.

        Ignored while training is in progress.
        """
        if self.train_status is None:
            return

        if self.train_status.state != TrainingState.TRAINING:
            self.train_status.state = TrainingState.IDLE
            # also restores the "Reset" label, which is blanked during training
            self._apply_state(TrainingState.IDLE)

    def _update_button(self, new_state: TrainingState) -> None:
        """Update the buttons based on the training state.

        A crash is handled like a reset, returning the status to IDLE.

        Parameters
        ----------
        new_state : TrainingState
            New training state.
        """
        if new_state == TrainingState.CRASHED:
            self._reset_clicked()
        else:
            self._apply_state(new_state)

__init__(train_status=None)

Initialize the widget.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `train_status` | `TrainingStatus` or `None` | Training status. | `None` |
Source code in src/careamics_napari/widgets/training_widget.py
def __init__(self: Self, train_status: Optional[TrainingStatus] = None) -> None:
    """Initialize the widget.

    Parameters
    ----------
    train_status : TrainingStatus or None, default=None
        Training status. If `None`, a new `TrainingStatus` is created.
    """
    # required to initialize the underlying Qt object (no parent is passed)
    super().__init__()

    self.train_status = (
        TrainingStatus() if train_status is None else train_status  # type: ignore
    )

    self.setTitle("Train")
    self.setLayout(QVBoxLayout())

    # train buttons
    train_buttons = QWidget()
    train_buttons.setLayout(QHBoxLayout())

    self.train_button = QPushButton("Train", self)
    self.train_button.setMinimumWidth(120)

    self.reset_model_button = QPushButton("Reset", self)
    self.reset_model_button.setMinimumWidth(120)
    self.reset_model_button.setEnabled(False)
    self.reset_model_button.setToolTip(
        "Reset the weights of the model (forget the training)"
    )

    train_buttons.layout().addWidget(self.train_button, alignment=Qt.AlignLeft)
    train_buttons.layout().addWidget(self.reset_model_button, alignment=Qt.AlignLeft)
    self.layout().addWidget(train_buttons)

    # actions: `self.train_status` is always set at this point (a default is
    # created above), so the connections are made unconditionally
    self.train_button.clicked.connect(self._train_stop_clicked)
    self.reset_model_button.clicked.connect(self._reset_clicked)

    # listening to the signal
    self.train_status.events.state.connect(self._update_button)