Skip to content

train_data_widget

A widget allowing users to select the data source for training.

TrainDataWidget #

Bases: QTabWidget

A widget offering to select layers from napari or paths from disk.

Parameters:

Name Type Description Default
signal TrainingSignal or None

Signal representing the training parameters.

None
use_target bool

Whether to add the target selection fields (training and validation targets).

False
Source code in src/careamics_napari/widgets/train_data_widget.py
class TrainDataWidget(QTabWidget):
    """A widget offering to select layers from napari or paths from disk.

    The widget holds up to two tabs: "From layers" (only kept when napari is
    available and a viewer is currently open) and "From disk". Every selection
    made in either tab is written to the corresponding attribute of the
    training signal, when one was provided.

    Parameters
    ----------
    signal : TrainingSignal or None, default=None
        Signal representing the training parameters.
    use_target : bool, default=False
        Whether to add the target selection fields (training and validation
        targets) to both tabs.
    """

    def __init__(
        self: Self,
        signal: Optional[TrainingSignal] = None,
        use_target: bool = False,
    ) -> None:
        """Initialize the widget.

        Parameters
        ----------
        signal : TrainingSignal or None, default=None
            Signal representing the training parameters.
        use_target : bool, default=False
            Whether to add the target selection fields (training and
            validation targets) to both tabs.
        """
        super().__init__()
        self.config_signal = signal
        self.use_target = use_target

        # QTabs
        layer_tab = QWidget()
        layer_tab.setLayout(QVBoxLayout())
        disk_tab = QWidget()
        disk_tab.setLayout(QVBoxLayout())

        # add tabs
        self.addTab(layer_tab, "From layers")
        self.addTab(disk_tab, "From disk")
        self.setTabToolTip(0, "Use images from napari layers")
        self.setTabToolTip(1, "Use patches saved on the disk")

        # set tabs (the layer tab removes itself if napari is unavailable)
        self._set_layer_tab(layer_tab)
        self._set_disk_tab(disk_tab)

        # set actions
        if self.config_signal is not None:
            self.currentChanged.connect(self._set_data_source)
            # push the initial tab selection to the signal
            self._set_data_source(self.currentIndex())

    def _set_layer_tab(
        self: Self,
        layer_tab: QWidget,
    ) -> None:
        """Set up the layer tab.

        If napari is not available or no viewer is currently open, the tab at
        index 0 (the layer tab) is removed instead of being populated.

        Parameters
        ----------
        layer_tab : QWidget
            Layer tab widget.
        """
        if not (_has_napari and napari.current_viewer() is not None):
            # simply remove the tab
            self.removeTab(0)
            return

        form = QFormLayout()
        form.setFormAlignment(Qt.AlignLeft | Qt.AlignTop)
        form.setFieldGrowthPolicy(QFormLayout.AllNonFixedFieldsGrow)
        form.setContentsMargins(12, 12, 0, 0)

        self.img_train = layer_choice()
        self.img_train.native.setToolTip("Select a training layer.")

        self.img_val = layer_choice()
        # the tooltip belongs to the validation selector, not the training one
        self.img_val.native.setToolTip("Select a validation layer.")

        # connection actions for images
        self.img_train.changed.connect(self._update_train_layer)
        self.img_val.changed.connect(self._update_val_layer)

        # to cover the case when image was loaded before the plugin
        if self.img_train.value is not None:
            self._update_train_layer(self.img_train.value)
        if self.img_val.value is not None:
            self._update_val_layer(self.img_val.value)

        form.addRow("Train", self.img_train.native)
        form.addRow("Val", self.img_val.native)

        if self.use_target:
            # get the target layers
            self.target_train = layer_choice()
            self.target_val = layer_choice()

            # tool tips
            self.target_train.native.setToolTip("Select a training target layer.")
            self.target_val.native.setToolTip("Select a validation target layer.")

            # connection actions for targets
            self.target_train.changed.connect(self._update_train_target_layer)
            self.target_val.changed.connect(self._update_val_target_layer)

            # to cover the case when image was loaded before the plugin
            if self.target_train.value is not None:
                self._update_train_target_layer(self.target_train.value)
            if self.target_val.value is not None:
                self._update_val_target_layer(self.target_val.value)

            form.addRow("Train target", self.target_train.native)
            form.addRow("Val target", self.target_val.native)

        layer_tab.layout().addLayout(form)

    def _set_disk_tab(self: Self, disk_tab: QWidget) -> None:
        """Set up the disk tab.

        Parameters
        ----------
        disk_tab : QWidget
            Disk tab widget.
        """
        buttons = QWidget()
        form = QFormLayout()
        form.setFormAlignment(Qt.AlignLeft | Qt.AlignTop)
        form.setFieldGrowthPolicy(QFormLayout.AllNonFixedFieldsGrow)

        self.train_images_folder = FolderWidget("Choose")
        self.val_images_folder = FolderWidget("Choose")
        form.addRow("Train", self.train_images_folder)
        form.addRow("Val", self.val_images_folder)

        # the training images tooltip is the same with or without targets
        self.train_images_folder.setToolTip(
            "Select a folder containing the training\nimages."
        )

        if self.use_target:
            self.train_target_folder = FolderWidget("Choose")
            self.val_target_folder = FolderWidget("Choose")

            form.addRow("Train target", self.train_target_folder)
            form.addRow("Val target", self.val_target_folder)

            self.train_target_folder.setToolTip(
                "Select a folder containing the training\ntarget."
            )
            self.val_target_folder.setToolTip(
                "Select a folder containing the validation\ntarget."
            )
            self.val_images_folder.setToolTip(
                "Select a folder containing the validation\nimages."
            )

            # add actions to target
            self.train_target_folder.get_text_widget().textChanged.connect(
                self._update_train_target_folder
            )
            self.val_target_folder.get_text_widget().textChanged.connect(
                self._update_val_target_folder
            )

        else:
            self.val_images_folder.setToolTip(
                "Select a folder containing the validation\n"
                "images, if you select the same folder as\n"
                "for training, the validation patches will\n"
                "be extracted from the training data."
            )

        # add actions
        self.train_images_folder.get_text_widget().textChanged.connect(
            self._update_train_folder
        )
        self.val_images_folder.get_text_widget().textChanged.connect(
            self._update_val_folder
        )

        buttons.setLayout(form)
        disk_tab.layout().addWidget(buttons)

    def _set_data_source(self: Self, index: int) -> None:
        """Set the signal data source to the selected tab.

        Parameters
        ----------
        index : int
            Index of the selected tab.
        """
        if self.config_signal is not None:
            # the disk tab is always the last one, whether or not the layer
            # tab was removed
            self.config_signal.load_from_disk = index == self.count() - 1

    def _update_train_layer(self: Self, layer: Image) -> None:
        """Update the training layer.

        Parameters
        ----------
        layer : Image
            Training layer.
        """
        if self.config_signal is not None:
            self.config_signal.layer_train = layer

    def _update_val_layer(self: Self, layer: Image) -> None:
        """Update the validation layer.

        Parameters
        ----------
        layer : Image
            Validation layer.
        """
        if self.config_signal is not None:
            self.config_signal.layer_val = layer

    def _update_train_target_layer(self: Self, layer: Image) -> None:
        """Update the training target layer.

        Parameters
        ----------
        layer : Image
            Training target layer.
        """
        if self.config_signal is not None:
            self.config_signal.layer_train_target = layer

    def _update_val_target_layer(self: Self, layer: Image) -> None:
        """Update the validation target layer.

        Parameters
        ----------
        layer : Image
            Validation target layer.
        """
        if self.config_signal is not None:
            self.config_signal.layer_val_target = layer

    def _update_train_folder(self: Self, folder: str) -> None:
        """Update the training folder.

        Parameters
        ----------
        folder : str
            Training folder.
        """
        if self.config_signal is not None:
            self.config_signal.path_train = folder

    def _update_val_folder(self: Self, folder: str) -> None:
        """Update the validation folder.

        Parameters
        ----------
        folder : str
            Validation folder.
        """
        if self.config_signal is not None:
            self.config_signal.path_val = folder

    def _update_train_target_folder(self: Self, folder: str) -> None:
        """Update the training target folder.

        Parameters
        ----------
        folder : str
            Training target folder.
        """
        if self.config_signal is not None:
            self.config_signal.path_train_target = folder

    def _update_val_target_folder(self: Self, folder: str) -> None:
        """Update the validation target folder.

        Parameters
        ----------
        folder : str
            Validation target folder.
        """
        if self.config_signal is not None:
            self.config_signal.path_val_target = folder

__init__(signal=None, use_target=False) #

Initialize the widget.

Parameters:

Name Type Description Default
signal TrainingSignal or None

Signal representing the training parameters.

None
use_target bool

Whether to add the target selection fields (training and validation targets).

False
Source code in src/careamics_napari/widgets/train_data_widget.py
def __init__(
    self: Self,
    signal: Optional[TrainingSignal] = None,
    use_target: bool = False,
) -> None:
    """Build the tab widget and wire it to the training signal.

    Parameters
    ----------
    signal : TrainingSignal or None, default=None
        Signal representing the training parameters.
    use_target : bool, default=False
        Whether to add the target selection fields.
    """
    super().__init__()
    self.config_signal = signal
    self.use_target = use_target

    # one page per data source, each with its own vertical layout
    napari_page = QWidget()
    folder_page = QWidget()
    for page in (napari_page, folder_page):
        page.setLayout(QVBoxLayout())

    # register the pages, then attach their tooltips by position
    self.addTab(napari_page, "From layers")
    self.addTab(folder_page, "From disk")
    tab_tips = ("Use images from napari layers", "Use patches saved on the disk")
    for position, tip in enumerate(tab_tips):
        self.setTabToolTip(position, tip)

    # populate the pages
    self._set_layer_tab(napari_page)
    self._set_disk_tab(folder_page)

    # nothing to keep in sync without a signal
    if self.config_signal is None:
        return

    # follow tab changes and publish the initial selection
    self.currentChanged.connect(self._set_data_source)
    self._set_data_source(self.currentIndex())