Skip to content

training_worker

A thread worker function running CAREamics training.

train_worker(configuration, data_sources, training_queue, predict_queue, careamist=None, pred_status=None)

Model training worker.

Parameters:

Name Type Description Default
configuration BaseConfig

CAREamics configuration.

required
data_sources dict[str, list]

Train and validation data sources.

required
training_queue Queue

Training update queue.

required
predict_queue Queue

Prediction update queue.

required
careamist CAREamist or None

CAREamist instance.

None
pred_status PredictionStatus or None

Prediction status for stop callback.

None

Yields:

Type Description
Generator[TrainUpdate, None, None]

Training updates, yielded one at a time as they are read from the training queue.

Source code in src/careamics_napari/workers/training_worker.py
@thread_worker
def train_worker(
    configuration: BaseConfig,
    data_sources: dict[str, list],
    training_queue: Queue,
    predict_queue: Queue,
    careamist: CAREamist | None = None,
    pred_status: PredictionStatus | None = None,
) -> Generator[TrainUpdate, None, None]:
    """Run training in a separate thread and relay its updates.

    The training itself is delegated to `_train`, started in its own thread.
    This generator then forwards every update read from `training_queue` until
    it has yielded either a `DONE` state update or an exception update, after
    which it waits for the training thread to terminate.

    Parameters
    ----------
    configuration : BaseConfig
        CAREamics configuration.
    data_sources : dict[str, list]
        Train and validation data sources.
    training_queue : Queue
        Queue from which training updates are read.
    predict_queue : Queue
        Prediction update queue, handed over to the training thread.
    careamist : CAREamist or None, default=None
        CAREamist instance, handed over to the training thread.
    pred_status : PredictionStatus or None, default=None
        Prediction status for stop callback.

    Yields
    ------
    TrainUpdate
        Each update read from `training_queue`, in order of arrival.
    """
    # run the training in its own thread so that this generator stays free to
    # forward updates while training is in progress
    train_args = (
        configuration,
        data_sources,
        training_queue,
        predict_queue,
        careamist,
        pred_status,
    )
    background = Thread(target=_train, args=train_args)
    background.start()

    finished = False
    while not finished:
        # blocks until the next update is available
        message: TrainUpdate = training_queue.get(block=True)

        yield message

        # the terminal update is forwarded first, then the loop stops
        is_done = (
            message.type == TrainUpdateType.STATE
            and message.value == TrainingState.DONE
        )
        finished = is_done or message.type == TrainUpdateType.EXCEPTION

    # make sure the training thread has fully terminated before returning
    background.join()