Skip to content

training_worker

A thread worker function running CAREamics training.

train_worker(train_config_signal, training_queue, predict_queue, careamist=None)

Model training worker.

Parameters:

Name Type Description Default
train_config_signal TrainingSignal

Training signal.

required
training_queue Queue

Training update queue.

required
predict_queue Queue

Prediction update queue.

required
careamist CAREamist or None

CAREamist instance.

None

Yields:

Type Description
TrainUpdate

Training updates read from the training queue. The function itself returns a Generator[TrainUpdate, None, None].

Source code in src/careamics_napari/workers/training_worker.py
@thread_worker
def train_worker(
    train_config_signal: TrainingSignal,
    training_queue: Queue,
    predict_queue: Queue,
    careamist: Optional[CAREamist] = None,
) -> Generator[TrainUpdate, None, None]:
    """Run training in a background thread and relay its updates.

    The actual training is delegated to `_train`, which is started in its own
    thread. This generator then forwards every update found in
    `training_queue` until a terminal update is seen, and finally waits for
    the training thread to exit.

    Parameters
    ----------
    train_config_signal : TrainingSignal
        Training signal, passed through to `_train`.
    training_queue : Queue
        Queue from which training updates are read.
    predict_queue : Queue
        Prediction update queue, passed through to `_train`.
    careamist : CAREamist or None, default=None
        CAREamist instance, passed through to `_train`.

    Yields
    ------
    TrainUpdate
        Each update read from `training_queue`, including the terminal one.
    """
    # hand the heavy lifting over to a dedicated thread
    worker_thread = Thread(
        target=_train,
        args=(train_config_signal, training_queue, predict_queue, careamist),
    )
    worker_thread.start()

    # relay updates until a terminal one has been forwarded
    finished = False
    while not finished:
        # blocks until the training thread publishes something
        message: TrainUpdate = training_queue.get(block=True)

        # the terminal update is forwarded too, before the loop stops
        yield message

        training_done = (
            message.type == TrainUpdateType.STATE
            and message.value == TrainingState.DONE
        )
        finished = training_done or message.type == TrainUpdateType.EXCEPTION

    # make sure the training thread has exited before returning
    worker_thread.join()