class TrainPlugin(QWidget):
    """CAREamics training plugin.

    The widget stacks a banner, an algorithm selector, the training data
    selection and the configuration, training, progress, prediction and saving
    widgets. It listens to the training, prediction and saving statuses and
    starts the corresponding worker when one of them enters its running state.

    Parameters
    ----------
    napari_viewer : napari.Viewer or None, default=None
        Napari viewer.
    """

    def __init__(
        self: Self,
        napari_viewer: Optional[napari.Viewer] = None,
    ) -> None:
        """Initialize the plugin.

        Parameters
        ----------
        napari_viewer : napari.Viewer or None, default=None
            Napari viewer.
        """
        super().__init__()
        self.viewer = napari_viewer

        # set by `_update_from_training` once the training worker yields it
        self.careamist: Optional[CAREamist] = None

        # create statuses, used to keep track of the threads statuses
        self.train_status = TrainingStatus()  # type: ignore
        self.pred_status = PredictionStatus()  # type: ignore
        self.save_status = SavingStatus()  # type: ignore

        # create signals, used to hold the various parameters modified by the UI
        self.train_config_signal = TrainingSignal()  # type: ignore
        self.pred_config_signal = PredictionSignal()
        self.save_config_signal = SavingSignal()

        # keep the prediction 3D flag in sync with the training configuration
        self.train_config_signal.events.is_3d.connect(self._set_pred_3d)

        # create queues, used to communicate between the threads and the UI
        self._training_queue: Queue = Queue(10)
        self._prediction_queue: Queue = Queue(10)

        # set workdir
        self.train_config_signal.work_dir = Path.cwd()

        self._init_ui()

    def _init_ui(self) -> None:
        """Assemble the widgets and connect the signals."""
        # layout
        main_layout = QVBoxLayout()
        self.setLayout(main_layout)
        self.setMinimumWidth(200)

        # add banner
        main_layout.addWidget(
            CAREamicsBanner(
                title="CAREamics",
                short_desc=("CAREamics UI for training denoising models."),
            )
        )

        # add GPU label and algorithm selection
        algo_panel = QWidget()
        algo_layout = QHBoxLayout()
        algo_panel.setLayout(algo_layout)

        gpu_button = create_gpu_label()
        # NOTE(review): the label used to be set to AlignRight and then
        # immediately overridden with AlignLeft; only the effective (left)
        # alignment is kept here. Confirm which alignment was intended.
        gpu_button.setAlignment(Qt.AlignmentFlag.AlignLeft)
        gpu_button.setContentsMargins(0, 5, 0, 0)  # top margin

        algo_choice = AlgorithmSelectionWidget(training_signal=self.train_config_signal)

        algo_layout.addWidget(algo_choice)
        algo_layout.addWidget(gpu_button)
        main_layout.addWidget(algo_panel)

        # add data tabs: index 0 without target, index 1 with target
        self.data_stck = QStackedWidget()
        self.data_layers = [
            TrainDataWidget(signal=self.train_config_signal),
            TrainDataWidget(signal=self.train_config_signal, use_target=True),
        ]
        for layer in self.data_layers:
            self.data_stck.addWidget(layer)
        self.data_stck.setCurrentIndex(0)
        main_layout.addWidget(self.data_stck)

        # add configuration widget
        self.config_widget = ConfigurationWidget(self.train_config_signal)
        main_layout.addWidget(self.config_widget)

        # add train widget
        self.train_widget = TrainingWidget(self.train_status)
        main_layout.addWidget(self.train_widget)

        # add progress widget
        self.progress_widget = TrainProgressWidget(
            self.train_status, self.train_config_signal
        )
        main_layout.addWidget(self.progress_widget)

        # add prediction
        self.prediction_widget = PredictionWidget(
            self.train_status,
            self.pred_status,
            self.train_config_signal,
            self.pred_config_signal,
        )
        main_layout.addWidget(self.prediction_widget)

        # add saving
        self.saving_widget = SavingWidget(
            train_status=self.train_status,
            save_status=self.save_status,
            save_signal=self.save_config_signal,
        )
        main_layout.addWidget(self.saving_widget)

        # connect signals
        # changes from the selected algorithm
        self.train_config_signal.events.algorithm.connect(self._set_data_from_algorithm)
        # force update so the data stack matches the current algorithm
        self._set_data_from_algorithm(self.train_config_signal.algorithm)

        # changes from the training, prediction or saving state
        self.train_status.events.state.connect(self._training_state_changed)
        self.pred_status.events.state.connect(self._prediction_state_changed)
        self.save_status.events.state.connect(self._saving_state_changed)

    def _set_pred_3d(self, is_3d: bool) -> None:
        """Set the 3D mode flag in the prediction signal.

        Parameters
        ----------
        is_3d : bool
            3D mode.
        """
        self.pred_config_signal.is_3d = is_3d

    def _training_state_changed(self, state: TrainingState) -> None:
        """Handle training state changes.

        This includes starting and stopping training, and dropping the current
        CAREamist when training crashed or went back to idle.

        Parameters
        ----------
        state : TrainingState
            New state.
        """
        if state == TrainingState.TRAINING:
            self.train_worker = train_worker(
                self.train_config_signal,
                self._training_queue,
                self._prediction_queue,
                self.careamist,
            )
            self.train_worker.yielded.connect(self._update_from_training)
            self.train_worker.start()
        elif state == TrainingState.STOPPED:
            # nothing to stop if no CAREamist has been received yet
            if self.careamist is not None:
                self.careamist.stop_training()
        elif state in (TrainingState.CRASHED, TrainingState.IDLE):
            # rebinding is enough to release the previous CAREamist
            self.careamist = None

    def _prediction_state_changed(self, state: PredictionState) -> None:
        """Handle prediction state changes.

        Parameters
        ----------
        state : PredictionState
            New state.
        """
        if state == PredictionState.PREDICTING:
            self.pred_worker = predict_worker(
                self.careamist, self.pred_config_signal, self._prediction_queue
            )
            self.pred_worker.yielded.connect(self._update_from_prediction)
            self.pred_worker.start()
        elif state == PredictionState.STOPPED:
            # self.careamist.stop_prediction()
            # TODO not existing yet
            pass

    def _saving_state_changed(self, state: SavingState) -> None:
        """Handle saving state changes.

        Parameters
        ----------
        state : SavingState
            New state.
        """
        if state == SavingState.SAVING:
            self.save_worker = save_worker(
                self.careamist, self.train_config_signal, self.save_config_signal
            )
            self.save_worker.yielded.connect(self._update_from_saving)
            self.save_worker.start()

    def _update_from_training(self, update: TrainUpdate) -> None:
        """Update the training status from the training worker.

        This method receives the updates from the training worker.

        Parameters
        ----------
        update : TrainUpdate
            Update.

        Raises
        ------
        Exception
            The exception carried by an update of type
            `TrainUpdateType.EXCEPTION`, after the training state has been set
            to crashed.
        """
        if update.type == TrainUpdateType.CAREAMIST:
            # keep a handle on the CAREamist yielded by the worker
            if isinstance(update.value, CAREamist):
                self.careamist = update.value
        elif update.type == TrainUpdateType.DEBUG:
            print(update.value)
        elif update.type == TrainUpdateType.EXCEPTION:
            self.train_status.state = TrainingState.CRASHED

            if isinstance(update.value, Exception):
                raise update.value
        else:
            self.train_status.update(update)

    def _update_from_prediction(self, update: PredictionUpdate) -> None:
        """Update the signal from the prediction worker.

        This method receives the updates from the prediction worker.

        Parameters
        ----------
        update : PredictionUpdate
            Update.
        """
        if update.type == PredictionUpdateType.DEBUG:
            print(update.value)
        elif update.type == PredictionUpdateType.EXCEPTION:
            self.pred_status.state = PredictionState.CRASHED

            # print exception without raising it
            print(f"Error: {update.value}")

            if _has_napari:
                ntf.show_error(
                    f"An error occurred during prediction: \n {update.value} \n"
                    f"Note: if you get an error due to the sizes of "
                    f"Tensors, try using tiling."
                )
        elif update.type == PredictionUpdateType.SAMPLE:
            # add image to napari
            # TODO keep scaling?
            if self.viewer is not None:
                # value is either a numpy array or a list of numpy arrays with
                # each sample/timepoint as an element
                if isinstance(update.value, list):
                    # combine all samples
                    samples = np.concatenate(update.value, axis=0)
                else:
                    samples = update.value

                # reshape the prediction to match the input axes
                samples = reshape_prediction(
                    samples,
                    self.train_config_signal.axes,
                    self.pred_config_signal.is_3d,
                )
                self.viewer.add_image(samples, name="Prediction")
        else:
            self.pred_status.update(update)

    def _update_from_saving(self, update: SavingUpdate) -> None:
        """Update the signal from the saving worker.

        This method receives the updates from the saving worker. Only debug and
        exception updates are handled; any other update is ignored.

        Parameters
        ----------
        update : SavingUpdate
            Update.
        """
        if update.type == SavingUpdateType.DEBUG:
            print(update.value)
        elif update.type == SavingUpdateType.EXCEPTION:
            self.save_status.state = SavingState.CRASHED

            # print exception without raising it, so that the error is still
            # reported when napari notifications are not available
            print(f"Error: {update.value}")

            if _has_napari:
                ntf.show_error(f"An error occurred during saving: \n {update.value}")

    def _set_data_from_algorithm(self, name: str) -> None:
        """Update the data selection widget based on the algorithm.

        CARE and N2N show the data widget created with `use_target=True`
        (index 1); any other algorithm shows the one without target (index 0).

        Parameters
        ----------
        name : str
            Algorithm name.
        """
        with_target = name in (
            SupportedAlgorithm.CARE.value,
            SupportedAlgorithm.N2N.value,
        )
        self.data_stck.setCurrentIndex(1 if with_target else 0)

    def closeEvent(self, event) -> None:
        """Close the plugin.

        Parameters
        ----------
        event : QCloseEvent
            Close event.
        """
        super().closeEvent(event)