Skip to content

base_plugin

BasePlugin #

Bases: QWidget

CAREamics Base plugin.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| napari_viewer | Viewer or None | Napari viewer. | None |
Source code in src/careamics_napari/base_plugin.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
class BasePlugin(QWidget):
    """CAREamics Base plugin.

    Sub-classes build their UI by calling ``add_careamics_banner`` and the
    ``add_*_ui`` methods after ``super().__init__()``. Widgets belonging to
    sections that are never added stay ``None``.

    Parameters
    ----------
    napari_viewer : napari.Viewer or None, default=None
        Napari viewer.
    """

    def __init__(
        self,
        napari_viewer: napari.Viewer | None = None,
    ) -> None:
        """Initialize the plugin.

        Parameters
        ----------
        napari_viewer : napari.Viewer or None, default=None
            Napari viewer.
        """
        super().__init__()
        self.viewer = napari_viewer
        self.careamist: CAREamist | None = None  # to hold trained careamist
        self.careamist_loaded: CAREamist | None = None  # to hold loaded careamist

        # sub-widgets, created on demand by the `add_*_ui` methods. They start as
        # None so that the `is not None` guards in this class work for plugins
        # that do not add every section.
        self.input_data_widget: TrainDataWidget | None = None
        self.config_widget: ConfigurationWidget | None = None
        self.train_widget: TrainingWidget | None = None
        self.progress_widget: TrainProgressWidget | None = None
        self.prediction_widget: PredictionWidget | None = None
        self.saving_widget: SavingWidget | None = None

        # create statuses, used to keep track of the threads statuses
        self.train_status = TrainingStatus()  # type: ignore
        self.pred_status = PredictionStatus()  # type: ignore

        # create a careamics config (n2v by default)
        self.careamics_config = get_default_n2v_config()

        # create queues, used to communicate between the threads and the UI
        self._training_queue: Queue = Queue(10)
        self._prediction_queue: Queue = Queue(10)

        # changes from the training and prediction
        self.train_status.events.state.connect(self._training_state_changed)
        self.pred_status.events.state.connect(self._prediction_state_changed)

        # main layout
        self.base_layout = QVBoxLayout()
        # scrolling content
        scroll_content = QWidget()
        scroll_content.setLayout(self.base_layout)
        scroll = ScrollWidgetWrapper(scroll_content)
        vbox = QVBoxLayout()
        vbox.addWidget(scroll)
        self.setLayout(vbox)
        self.setMinimumWidth(200)

        # the add_*_ui methods are called by the sub-classes, so that each plugin
        # can choose its sections while the shared code lives here.

    def add_careamics_banner(self, desc: str = "") -> None:
        """Add the CAREamics banner and GPU label to the plugin.

        Parameters
        ----------
        desc : str, default=""
            Short description shown in the banner; an empty string selects a
            generic description.
        """
        if len(desc) == 0:
            desc = "CAREamics UI for training denoising models."
        self.base_layout.addWidget(
            CAREamicsBanner(
                title="CAREamics",
                short_desc=(desc),
            )
        )
        # GPU label
        gpu_button = create_gpu_label()
        gpu_button.setAlignment(Qt.AlignmentFlag.AlignRight)
        gpu_button.setContentsMargins(0, 5, 0, 0)  # top margin
        self.base_layout.addWidget(gpu_button)

    def add_train_input_ui(self, use_target: bool = False) -> None:
        """Add the train input data selection UI to the plugin.

        Parameters
        ----------
        use_target : bool, default=False
            Forwarded to `TrainDataWidget` (presumably enables target data
            selection for supervised algorithms).
        """
        self.input_data_widget = TrainDataWidget(
            careamics_config=self.careamics_config, use_target=use_target
        )
        self.base_layout.addWidget(self.input_data_widget)

    def add_config_ui(self) -> None:
        """Add the training configuration UI to the plugin."""
        self.config_widget = ConfigurationWidget(self.careamics_config)
        self.config_widget.enable_3d_chkbox.clicked.connect(self._set_pred_3d)
        self.config_widget.show_advanced_config.connect(self.show_advanced_config)
        self.base_layout.addWidget(self.config_widget)

    def add_train_button_ui(self) -> None:
        """Add the training button UI and the training progress UI to the plugin."""
        self.train_widget = TrainingWidget(self.train_status)
        self.progress_widget = TrainProgressWidget(
            self.careamics_config, self.train_status
        )
        self.base_layout.addWidget(self.train_widget)
        self.base_layout.addWidget(self.progress_widget)

    def add_prediction_ui(self) -> None:
        """Add the prediction UI to the plugin."""
        self.prediction_widget = PredictionWidget(
            self.careamics_config,
            self.train_status,
            self.pred_status,
            self._prediction_queue,
        )
        self.base_layout.addWidget(self.prediction_widget)
        # to get loaded careamist
        self.prediction_widget.careamist_loaded.connect(self._on_careamist_loaded)
        self.prediction_widget.model_from_disk.connect(self._model_selection_changed)

    def add_model_export_ui(self) -> None:
        """Add the model saving UI to the plugin."""
        self.saving_widget = SavingWidget(
            self.careamics_config,
            self.train_status,
        )
        self.saving_widget.export_model.connect(self.export_model)
        self.base_layout.addWidget(self.saving_widget)

    def update_config(self) -> None:
        """Update the configuration from the UI.

        Calls ``update_config()`` on the configuration and prediction widgets,
        when they have been added, then prints ``self.careamics_config``.
        """
        if self.config_widget is not None:
            self.config_widget.update_config()

        if self.prediction_widget is not None:
            self.prediction_widget.update_config()

        print(f"update_config:\n{self.careamics_config}")

    def export_model(self, destination: Path, export_type: str) -> None:
        """Export the trained model.

        Parameters
        ----------
        destination : Path
            Directory in which the exported file is written.
        export_type : str
            ``ExportType.BMZ.value`` selects BioImage Model Zoo export; any other
            value saves a ``.ckpt`` checkpoint via ``careamist.trainer``.
        """
        if self.careamist is None:
            if _has_napari:
                ntf.show_info("No trained model is available for exporting.")
            return

        dims = "3D" if self.careamics_config.is_3D else "2D"
        algo_name = self.careamics_config.algorithm_config.get_algorithm_friendly_name()
        name = f"{algo_name}_{dims}_{self.careamics_config.experiment_name}"

        try:
            if export_type == ExportType.BMZ.value:
                self._prepare_export_to_bmz(destination, name)
            else:
                name = name + ".ckpt"
                self.careamist.trainer.save_checkpoint(
                    destination.joinpath(name),
                )
                print(f"Model exported at {destination}")
                if _has_napari:
                    ntf.show_info(f"Model exported at {destination}")

        except Exception as e:
            # best-effort export: report the failure without crashing the UI
            traceback.print_exc()
            if _has_napari:
                ntf.show_error(str(e))

    def show_advanced_config(self):
        """Show advanced configuration options.

        Raises
        ------
        NotImplementedError
            Always; sub-classes are expected to override this.
        """
        raise NotImplementedError("Advanced configuration options are not implemented.")

    def _set_pred_3d(self, state: bool) -> None:
        """Set the 3D mode flag in the prediction widget.

        Parameters
        ----------
        state : bool
            3D mode.
        """
        if self.prediction_widget is not None:
            self.prediction_widget.set_3d(state)

    def _training_state_changed(self, state: TrainingState) -> None:
        """Handle training state changes.

        This includes starting and stopping training.

        Parameters
        ----------
        state : TrainingState
            New state.
        """
        if state == TrainingState.TRAINING:
            # get data sources (none if no data selection UI was added)
            data_sources = (
                self.input_data_widget.get_data_sources()
                if self.input_data_widget is not None
                else None
            )
            if data_sources is None:
                if _has_napari:
                    ntf.show_info("Please set the training data first.")
                self.train_status.state = TrainingState.IDLE
                if self.train_widget is not None:
                    self.train_widget.train_button.setText("Train")
                return

            # update configuration from ui (also prints the configuration)
            self.update_config()

            # start the training thread
            self.train_worker = train_worker(
                self.careamics_config,
                data_sources,
                self._training_queue,
                self._prediction_queue,
                self.careamist,
                self.pred_status,
            )
            self.train_worker.yielded.connect(self._update_from_training)
            self.train_worker.start()

        elif state == TrainingState.STOPPED:
            if self.careamist is not None:
                self.careamist.stop_training()

        elif state in (TrainingState.CRASHED, TrainingState.IDLE):
            # drop the reference to the trained model
            self.careamist = None

        # update prediction widget
        if self.prediction_widget is not None:
            self.prediction_widget.update_button_from_train(state)

    def _prediction_state_changed(self, state: PredictionState) -> None:
        """Handle prediction state changes.

        A prediction request needs a model (trained or loaded from disk,
        depending on the prediction widget's selection) and prediction data;
        if either is missing, the request is cancelled.

        Parameters
        ----------
        state : PredictionState
            New state.
        """
        if state == PredictionState.PREDICTING:
            # only a prediction request needs a model; checking on every state
            # change repeated the "no model" warning on the follow-up transition
            careamist = self._which_careamist()
            if careamist is None:
                self.pred_status.state = PredictionState.IDLE
                self.prediction_widget.predict_button.setText("Predict")
                return

            # get the prediction data
            data_source = self.prediction_widget.get_data_source()
            if data_source is None:
                if _has_napari:
                    ntf.show_info("Please set the prediction data first.")
                self.pred_status.state = PredictionState.IDLE
                self.prediction_widget.predict_button.setText("Predict")
                return

            # update configuration from ui
            self.update_config()

            # start the prediction thread
            self.pred_worker = predict_worker(
                careamist,
                data_source,
                self.careamics_config,
                self._prediction_queue,
            )
            self.pred_worker.yielded.connect(self._update_from_prediction)
            self.pred_worker.start()

        elif state == PredictionState.STOPPED:
            # prediction stopped: reset the progress bar
            self._prediction_queue.put(
                PredictionUpdate(PredictionUpdateType.SAMPLE_IDX, -1)
            )

    def _on_careamist_loaded(self, careamist: CAREamist) -> None:
        """Event handler called when a CAREamics instance has been loaded.

        Parameters
        ----------
        careamist : CAREamist
            The loaded instance, kept as ``self.careamist_loaded``.
        """
        self.careamist_loaded = careamist
        print(
            f"CAREamics instance loaded: "
            f"{self.careamist_loaded.cfg.get_algorithm_friendly_name()}"
        )
        if _has_napari:
            ntf.show_info("CAREamics model loaded successfully!")

    def _model_selection_changed(self, from_disk: bool) -> None:
        """Event handler called when user changed the model selection.

        Parameters
        ----------
        from_disk : bool
            Whether the model loaded from disk is selected instead of the
            trained one.
        """
        # update the prediction and stop buttons
        if not from_disk:
            self.prediction_widget.update_button_from_train(self.train_status.state)
        elif self.careamist_loaded is not None:
            self.prediction_widget.predict_button.setEnabled(True)
            self.prediction_widget.stop_button.setEnabled(False)

    def _which_careamist(self) -> CAREamist | None:
        """Return the careamist to predict with: the trained or the loaded one.

        A warning is shown when the selected model is not available.
        """
        # if load from disk option is selected
        if self.prediction_widget.load_from_disk:
            careamist = self.careamist_loaded
            if careamist is None and _has_napari:
                ntf.show_warning("No model was loaded from disk!")
        else:
            careamist = self.careamist
            if careamist is None and _has_napari:
                ntf.show_warning("No trained model is available.")

        return careamist

    def _update_from_training(self, update: TrainUpdate) -> None:
        """Update the training status from the training worker.

        This method receives the updates from the training worker.

        Parameters
        ----------
        update : TrainUpdate
            Update.
        """
        if update.type == TrainUpdateType.CAREAMIST:
            if isinstance(update.value, CAREamist):
                self.careamist = update.value
        elif update.type == TrainUpdateType.DEBUG:
            print(update.value)
        elif update.type == TrainUpdateType.EXCEPTION:
            self.train_status.state = TrainingState.CRASHED

            # re-raise the worker's exception in the UI thread
            if isinstance(update.value, Exception):
                raise update.value
        else:
            self.train_status.update(update)

    def _update_from_prediction(self, update: PredictionUpdate) -> None:
        """Update the signal from the prediction worker.

        This method receives the updates from the prediction worker.

        Parameters
        ----------
        update : PredictionUpdate
            Update.
        """
        if update.type == PredictionUpdateType.DEBUG:
            print(update.value)
        elif update.type == PredictionUpdateType.EXCEPTION:
            self.pred_status.state = PredictionState.CRASHED
            # print exception without raising it
            print(f"Error: {update.value}")
            if _has_napari:
                ntf.show_error(
                    f"An error occurred during prediction: \n {update.value} \n"
                    f"Note: if you get an error due to the sizes of "
                    f"Tensors, try using tiling."
                )
        else:
            if update.type == PredictionUpdateType.SAMPLE:
                # add image to napari
                # TODO keep scaling?
                if self.viewer is not None:
                    # value is either a numpy array or
                    # a list of numpy arrays with each sample/time-point as an element
                    if isinstance(update.value, list):
                        # combine all samples
                        samples = np.concatenate(update.value, axis=0)
                    else:
                        samples = update.value

                    # reshape the prediction to match the input axes
                    samples = reshape_prediction(
                        samples,  # type: ignore
                        self.careamics_config.data_config.axes,  # type: ignore
                        self.careamics_config.is_3D,
                    )
                    self.viewer.add_image(samples, name="Prediction")
            else:
                self.pred_status.update(update)

    def _show_bmz_dialog(
        self, bmz_path: Path, cover: Path, sample_input: np.ndarray
    ) -> None:
        """Show the BMZ export dialog window.

        Accepting the dialog triggers the actual export.
        """
        # ask user for bmz model specs
        bmz_window = BMZExportWidget(self, cover_image_path=cover)
        bmz_window.accepted.connect(
            lambda: self._export_to_bmz(bmz_window, bmz_path, sample_input)
        )
        bmz_window.show()

    def _prepare_export_to_bmz(self, destination: Path, name: str) -> None:
        """Prepare the BMZ export: sample input, cover image and dialog.

        Parameters
        ----------
        destination : Path
            Directory in which the ``<name>.zip`` archive will be written.
        name : str
            Archive base name.

        Raises
        ------
        NotImplementedError
            If the training data is not an in-memory numpy array.
        """
        if self.careamist is None:
            if _has_napari:
                ntf.show_info("No trained model is available for exporting.")
            return

        bmz_path = destination.joinpath(name + ".zip")

        data_sources = (
            self.input_data_widget.get_data_sources()
            if self.input_data_widget is not None
            else None
        )
        if data_sources is None:
            # a sample input is required for the BMZ archive and the cover image
            if _has_napari:
                ntf.show_info("Please set the training data first.")
            return

        train_data = data_sources["train"][0]
        if not isinstance(train_data, np.ndarray):
            raise NotImplementedError(
                "BMZ export from tiff data source is not implemented yet."
            )
        if train_data.ndim == 2:
            sample_input = train_data[:256, :256]
        else:
            sample_input = train_data[0, :256, :256]

        # make a default cover image
        output_patches = self.careamist.predict(
            sample_input,
            data_type="array",
            tta_transforms=False,
        )
        sample_output = np.concatenate(output_patches, axis=0)
        cover_path = create_cover(
            directory=self.careamics_config.work_dir,
            array_in=sample_input[np.newaxis, np.newaxis, ...],
            array_out=sample_output,
        )

        # show the bmz export dialog
        self._show_bmz_dialog(bmz_path, cover_path, sample_input)

    def _export_to_bmz(
        self, bmz_window: BMZExportWidget, bmz_path: Path, sample_input: np.ndarray
    ) -> None:
        """Export the trained model to a BMZ archive using the dialog's metadata.

        Parameters
        ----------
        bmz_window : BMZExportWidget
            Accepted dialog holding the user-provided model metadata.
        bmz_path : Path
            Path of the archive to write.
        sample_input : np.ndarray
            Sample input stored in the archive.
        """
        if self.careamist is None:
            # the trained model may have been discarded (training crashed or was
            # reset) while the export dialog was open
            if _has_napari:
                ntf.show_info("No trained model is available for exporting.")
            return

        bmz_data = {
            "model_name": bmz_window.model_name,
            "description": bmz_window.general_description,
            "data_description": bmz_window.data_description,
            "authors": bmz_window.authors,
            "cover": bmz_window.cover_image,
        }

        try:
            self.careamist.export_to_bmz(
                path_to_archive=bmz_path,
                input_array=sample_input,
                friendly_model_name=bmz_data["model_name"],
                general_description=bmz_data["description"],
                data_description=bmz_data["data_description"],
                authors=bmz_data["authors"],
                covers=[bmz_data["cover"]],
            )
            print(f"Model exported at {bmz_path}")
            if _has_napari:
                ntf.show_info(f"Model exported at {bmz_path}")

        except Exception as e:
            # best-effort export: report the failure without crashing the UI
            traceback.print_exc()
            if _has_napari:
                ntf.show_error(str(e))

    def closeEvent(self, event) -> None:
        """Close the plugin.

        Parameters
        ----------
        event : QCloseEvent
            Close event.
        """
        super().closeEvent(event)

__init__(napari_viewer=None) #

Initialize the plugin.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| napari_viewer | Viewer or None | Napari viewer. | None |
Source code in src/careamics_napari/base_plugin.py
def __init__(
    self,
    napari_viewer: napari.Viewer | None = None,
) -> None:
    """Initialize the plugin.

    Parameters
    ----------
    napari_viewer : napari.Viewer or None, default=None
        Napari viewer.
    """
    super().__init__()
    self.viewer = napari_viewer
    self.careamist: CAREamist | None = None  # to hold trained careamist
    self.careamist_loaded: CAREamist | None = None  # to hold loaded careamist

    # sub-widgets, created on demand by the `add_*_ui` methods. They start as
    # None so that `is not None` guards (e.g. in `update_config`) work for
    # plugins that do not add every section.
    self.input_data_widget = None
    self.config_widget = None
    self.train_widget = None
    self.progress_widget = None
    self.prediction_widget = None
    self.saving_widget = None

    # create statuses, used to keep track of the threads statuses
    self.train_status = TrainingStatus()  # type: ignore
    self.pred_status = PredictionStatus()  # type: ignore

    # create a careamics config (n2v by default)
    self.careamics_config = get_default_n2v_config()

    # create queues, used to communicate between the threads and the UI
    self._training_queue: Queue = Queue(10)
    self._prediction_queue: Queue = Queue(10)

    # changes from the training and prediction
    self.train_status.events.state.connect(self._training_state_changed)
    self.pred_status.events.state.connect(self._prediction_state_changed)

    # main layout
    self.base_layout = QVBoxLayout()
    # scrolling content
    scroll_content = QWidget()
    scroll_content.setLayout(self.base_layout)
    scroll = ScrollWidgetWrapper(scroll_content)
    vbox = QVBoxLayout()
    vbox.addWidget(scroll)
    self.setLayout(vbox)
    self.setMinimumWidth(200)

add_careamics_banner(desc='') #

Add the CAREamics banner and GPU label to the plugin.

Source code in src/careamics_napari/base_plugin.py
def add_careamics_banner(self, desc: str = "") -> None:
    """Append the CAREamics banner and the GPU indicator to the layout.

    Parameters
    ----------
    desc : str, default=""
        Short description shown in the banner; an empty string selects a
        generic description.
    """
    default_desc = "CAREamics UI for training denoising models."
    short_desc = desc if len(desc) > 0 else default_desc
    banner = CAREamicsBanner(title="CAREamics", short_desc=short_desc)
    self.base_layout.addWidget(banner)

    # GPU indicator: right-aligned, with a 5 px top margin
    gpu_label = create_gpu_label()
    gpu_label.setAlignment(Qt.AlignmentFlag.AlignRight)
    gpu_label.setContentsMargins(0, 5, 0, 0)
    self.base_layout.addWidget(gpu_label)

add_config_ui() #

Add the training configuration UI to the plugin.

Source code in src/careamics_napari/base_plugin.py
def add_config_ui(self) -> None:
    """Create the training configuration widget and append it to the layout.

    The widget's 3D checkbox is wired to `_set_pred_3d` and its advanced
    configuration signal to `show_advanced_config`.
    """
    widget = ConfigurationWidget(self.careamics_config)
    self.config_widget = widget
    widget.enable_3d_chkbox.clicked.connect(self._set_pred_3d)
    widget.show_advanced_config.connect(self.show_advanced_config)
    self.base_layout.addWidget(widget)

add_model_export_ui() #

Add the model saving UI to the plugin.

Source code in src/careamics_napari/base_plugin.py
def add_model_export_ui(self) -> None:
    """Create the model export widget and append it to the layout.

    The widget's `export_model` signal is connected to `self.export_model`.
    """
    saving = SavingWidget(self.careamics_config, self.train_status)
    self.saving_widget = saving
    saving.export_model.connect(self.export_model)
    self.base_layout.addWidget(saving)

add_prediction_ui() #

Add the prediction UI to the plugin.

Source code in src/careamics_napari/base_plugin.py
def add_prediction_ui(self) -> None:
    """Create the prediction widget and append it to the layout.

    The widget receives the configuration, both status objects and the
    prediction queue. Its `careamist_loaded` and `model_from_disk` signals are
    routed to the plugin's handlers.
    """
    pred_widget = PredictionWidget(
        self.careamics_config,
        self.train_status,
        self.pred_status,
        self._prediction_queue,
    )
    self.prediction_widget = pred_widget
    self.base_layout.addWidget(pred_widget)

    # model loaded from disk / model source selection changed
    pred_widget.careamist_loaded.connect(self._on_careamist_loaded)
    pred_widget.model_from_disk.connect(self._model_selection_changed)

add_train_button_ui() #

Add the training button UI to the plugin.

Source code in src/careamics_napari/base_plugin.py
def add_train_button_ui(self) -> None:
    """Append the training controls followed by the training progress widget."""
    self.train_widget = TrainingWidget(self.train_status)
    self.progress_widget = TrainProgressWidget(
        self.careamics_config,
        self.train_status,
    )
    for section in (self.train_widget, self.progress_widget):
        self.base_layout.addWidget(section)

add_train_input_ui(use_target=False) #

Add the train input data selection UI to the plugin.

Source code in src/careamics_napari/base_plugin.py
def add_train_input_ui(self, use_target: bool = False) -> None:
    """Append the training data selection widget to the layout.

    Parameters
    ----------
    use_target : bool, default=False
        Forwarded to `TrainDataWidget` (presumably enables target data
        selection for supervised algorithms).
    """
    data_widget = TrainDataWidget(
        careamics_config=self.careamics_config,
        use_target=use_target,
    )
    self.input_data_widget = data_widget
    self.base_layout.addWidget(data_widget)

closeEvent(event) #

Close the plugin.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| event | QCloseEvent | Close event. | required |
Source code in src/careamics_napari/base_plugin.py
def closeEvent(self, event) -> None:
    """Handle the plugin widget being closed.

    No plugin-specific cleanup is performed; the event is forwarded unchanged
    to the parent class implementation.

    Parameters
    ----------
    event : QCloseEvent
        Close event.
    """
    super().closeEvent(event)

export_model(destination, export_type) #

Export the trained model.

Source code in src/careamics_napari/base_plugin.py
def export_model(self, destination: Path, export_type: str) -> None:
    """Export the trained model as a BMZ archive or as a checkpoint.

    Parameters
    ----------
    destination : Path
        Directory in which the exported file is written.
    export_type : str
        ``ExportType.BMZ.value`` selects BioImage Model Zoo export; any other
        value saves a ``.ckpt`` checkpoint via ``careamist.trainer``.
    """
    if self.careamist is None:
        if _has_napari:
            ntf.show_info("No trained model is available for exporting.")
        return

    config = self.careamics_config
    dimensionality = "3D" if config.is_3D else "2D"
    friendly_name = config.algorithm_config.get_algorithm_friendly_name()
    base_name = f"{friendly_name}_{dimensionality}_{config.experiment_name}"

    try:
        if export_type != ExportType.BMZ.value:
            checkpoint_file = destination.joinpath(base_name + ".ckpt")
            self.careamist.trainer.save_checkpoint(checkpoint_file)
            message = f"Model exported at {destination}"
            print(message)
            if _has_napari:
                ntf.show_info(message)
        else:
            self._prepare_export_to_bmz(destination, base_name)
    except Exception as err:
        # best-effort export: report the failure without crashing the UI
        traceback.print_exc()
        if _has_napari:
            ntf.show_error(str(err))

show_advanced_config() #

Show advanced configuration options.

Source code in src/careamics_napari/base_plugin.py
def show_advanced_config(self):
    """Show advanced configuration options (not available in the base plugin).

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = "Advanced configuration options are not implemented."
    raise NotImplementedError(message)

update_config() #

Update the configuration from the UI.

Source code in src/careamics_napari/base_plugin.py
def update_config(self) -> None:
    """Update the configuration from the UI.

    Calls ``update_config()`` on the configuration widget and the prediction
    widget, each only if present, then prints ``self.careamics_config``.
    """
    # these widgets exist only if the matching `add_*_ui` method was called, so
    # a plugin without them must not fail with AttributeError here
    config_widget = getattr(self, "config_widget", None)
    if config_widget is not None:
        config_widget.update_config()

    prediction_widget = getattr(self, "prediction_widget", None)
    if prediction_widget is not None:
        prediction_widget.update_config()

    print(f"update_config:\n{self.careamics_config}")