
Training Factory


Convenience functions to create training configurations.

create_training_configuration(algorithm, trainer_params, logger, checkpoint_params=None, early_stopping_params=None, monitor_metric='val_loss')

Create a training configuration (TrainingConfig) with the parameters of the training model.

Parameters:

  • algorithm ((care, n2n, n2v), default: "care" ) –

    Algorithm type, used to select the default checkpointing preset.

  • trainer_params (dict) –

    Parameters for the Lightning Trainer class; see the PyTorch Lightning documentation for the list of available parameters.

  • logger ((wandb, tensorboard, none), default: "wandb" ) –

    Logger to use.

  • checkpoint_params (dict, default: None ) –

    Parameters for the checkpoint callback; see the PyTorch Lightning documentation (ModelCheckpoint) for the list of available parameters. If None, default parameters are applied.

  • early_stopping_params (dict, default: None ) –

    Parameters for the early stopping callback; see the PyTorch Lightning documentation (EarlyStopping) for the list of available parameters. If None, default parameters are applied.

  • monitor_metric (str, default: "val_loss" ) –

    Metric to monitor for early stopping.

Returns:

  • TrainingConfig

    Training configuration with the specified parameters.
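To make the shape of these arguments concrete, here is a minimal sketch in plain Python. The dictionary keys (max_epochs, accelerator, monitor, patience, mode, save_top_k) are standard PyTorch Lightning keyword arguments for Trainer, EarlyStopping and ModelCheckpoint. The helpers check_choices and resolve_callback_params, and the placeholder defaults, are hypothetical: they only mimic the documented rule that None falls back to defaults, and they assume a user-supplied dict replaces the defaults wholesale. The actual presets chosen per algorithm are not reproduced here.

```python
from typing import Any, Dict, Optional

# Allowed values, as listed in the parameter documentation.
ALGORITHMS = {"care", "n2n", "n2v"}
LOGGERS = {"wandb", "tensorboard", "none"}


def check_choices(algorithm: str, logger: str) -> None:
    """Hypothetical validation of the literal-typed arguments."""
    if algorithm not in ALGORITHMS:
        raise ValueError(f"algorithm must be one of {sorted(ALGORITHMS)}, got {algorithm!r}")
    if logger not in LOGGERS:
        raise ValueError(f"logger must be one of {sorted(LOGGERS)}, got {logger!r}")


def resolve_callback_params(
    user_params: Optional[Dict[str, Any]],
    default_params: Dict[str, Any],
) -> Dict[str, Any]:
    """Hypothetical helper: fall back to defaults when no parameters are given.

    Assumption: a user-supplied dict replaces the defaults entirely;
    the real function may merge them instead.
    """
    if user_params is None:
        return dict(default_params)  # copy so callers cannot mutate the defaults
    return dict(user_params)


# Illustrative arguments using standard Lightning keyword arguments.
trainer_params = {"max_epochs": 100, "accelerator": "auto"}
early_stopping_params = {"monitor": "val_loss", "patience": 10, "mode": "min"}

# Placeholder defaults, NOT the library's actual checkpointing preset.
placeholder_checkpoint_defaults = {"monitor": "val_loss", "save_top_k": 3, "mode": "min"}

check_choices("n2v", "tensorboard")
checkpoint_params = resolve_callback_params(None, placeholder_checkpoint_defaults)
early_stopping_params = resolve_callback_params(early_stopping_params, {})

# With the library installed, the documented call would look like this
# (the import path is not shown because it is not given in this page):
# config = create_training_configuration(
#     algorithm="n2v",
#     trainer_params=trainer_params,
#     logger="tensorboard",
#     checkpoint_params=checkpoint_params,
#     early_stopping_params=early_stopping_params,
#     monitor_metric="val_loss",
# )
```

Keeping the monitor key in the callback dictionaries consistent with monitor_metric is a reasonable precaution, since both callbacks track a logged metric by name.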

update_trainer_params(trainer_params=None, num_epochs=None, num_steps=None)

Update trainer parameters with the number of epochs (num_epochs) and the number of steps per epoch (num_steps).

Parameters:

  • trainer_params (dict, default: None ) –

    Parameters for the Lightning Trainer class.

  • num_epochs (int, default: None ) –

    Number of epochs to train for. If provided, it is added to trainer_params as max_epochs.

  • num_steps (int, default: None ) –

    Number of batches per epoch. If provided, it is added to trainer_params as limit_train_batches.

Returns:

  • dict

    Updated trainer parameters dictionary.
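The documented mapping (num_epochs to max_epochs, num_steps to limit_train_batches) can be illustrated with a short re-implementation sketch. The name update_trainer_params_sketch is hypothetical, and two behaviours are assumptions rather than confirmed library behaviour: the input dict is copied instead of modified in place, and a provided value overwrites any existing key of the same name.

```python
from typing import Any, Dict, Optional


def update_trainer_params_sketch(
    trainer_params: Optional[Dict[str, Any]] = None,
    num_epochs: Optional[int] = None,
    num_steps: Optional[int] = None,
) -> Dict[str, Any]:
    """Sketch of the documented behaviour (hypothetical name).

    - num_epochs, if given, is stored as "max_epochs"
    - num_steps, if given, is stored as "limit_train_batches"
    """
    # Assumption: work on a copy so the caller's dict is left untouched.
    params = dict(trainer_params) if trainer_params is not None else {}
    if num_epochs is not None:
        params["max_epochs"] = num_epochs  # overwrites an existing value
    if num_steps is not None:
        params["limit_train_batches"] = num_steps
    return params


base = {"accelerator": "auto", "max_epochs": 10}
updated = update_trainer_params_sketch(base, num_epochs=50, num_steps=200)
print(updated)
# {'accelerator': 'auto', 'max_epochs': 50, 'limit_train_batches': 200}
```

In PyTorch Lightning, an integer limit_train_batches is read as an absolute number of batches per epoch, while a float is read as a fraction of the training set, so passing num_steps as an int matches the "number of batches in one epoch" meaning above.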