Training Factory
Convenience functions to create training configurations.
create_training_configuration(algorithm, trainer_params, logger, checkpoint_params=None, early_stopping_params=None, monitor_metric='val_loss')
Create a training configuration (TrainingConfig) from the given parameters.
Parameters:
- algorithm ({"care", "n2n", "n2v"}, default: "care") – Algorithm type, used to select the default checkpointing preset.
- trainer_params (dict) – Parameters for the Lightning Trainer class; see the PyTorch Lightning documentation.
- logger ({"wandb", "tensorboard", "none"}, default: "wandb") – Logger to use.
- checkpoint_params (dict, default: None) – Parameters for the checkpoint callback; see the PyTorch Lightning documentation (ModelCheckpoint) for the list of available parameters. If None, default parameters are applied.
- early_stopping_params (dict, default: None) – Parameters for the early stopping callback; see the PyTorch Lightning documentation (EarlyStopping) for the list of available parameters. If None, default parameters are applied.
- monitor_metric (str, default: "val_loss") – Metric to monitor for early stopping.
Returns:
- TrainingConfig – Training configuration with the specified parameters.
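A minimal sketch of how the optional arguments might be resolved is shown below. It assumes plain dictionaries stand in for the TrainingConfig model; the function sketch_training_configuration and the values in HYPOTHETICAL_CHECKPOINT_PRESETS are hypothetical placeholders, not the presets or validation CAREamics actually uses.

```python
from typing import Any, Dict, Literal, Optional

Algorithm = Literal["care", "n2n", "n2v"]
Logger = Literal["wandb", "tensorboard", "none"]

# Hypothetical per-algorithm checkpoint presets. These values are
# placeholders for illustration, NOT the presets shipped by CAREamics.
HYPOTHETICAL_CHECKPOINT_PRESETS: Dict[str, Dict[str, Any]] = {
    "care": {"save_top_k": 3},
    "n2n": {"save_top_k": 3},
    "n2v": {"save_top_k": 1},
}


def sketch_training_configuration(
    algorithm: Algorithm,
    trainer_params: Dict[str, Any],
    logger: Logger,
    checkpoint_params: Optional[Dict[str, Any]] = None,
    early_stopping_params: Optional[Dict[str, Any]] = None,
    monitor_metric: str = "val_loss",
) -> Dict[str, Any]:
    """Assemble a plain-dict stand-in for a training configuration (hypothetical)."""
    if algorithm not in HYPOTHETICAL_CHECKPOINT_PRESETS:
        raise ValueError(f"Unknown algorithm: {algorithm!r}")
    # No checkpoint parameters given: fall back to a copy of the algorithm's preset.
    if checkpoint_params is None:
        checkpoint_params = dict(HYPOTHETICAL_CHECKPOINT_PRESETS[algorithm])
    # No early-stopping parameters given: start from no overrides.
    if early_stopping_params is None:
        early_stopping_params = {}
    # Early stopping watches monitor_metric unless the caller sets "monitor" explicitly.
    early_stopping = {"monitor": monitor_metric, **early_stopping_params}
    return {
        "trainer_params": dict(trainer_params),
        "logger": logger,
        "checkpoint_params": dict(checkpoint_params),
        "early_stopping_params": early_stopping,
    }


config = sketch_training_configuration("n2v", {"max_epochs": 100}, logger="tensorboard")
print(config["early_stopping_params"])  # {'monitor': 'val_loss'}
```

In this sketch, passing checkpoint_params replaces the preset entirely rather than merging with it; the documentation above does not say which behaviour the real function uses.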
update_trainer_params(trainer_params=None, num_epochs=None, num_steps=None)
Update trainer parameters with num_epochs and num_steps.
Parameters:
- trainer_params (dict, default: None) – Parameters for the Lightning Trainer class.
- num_epochs (int, default: None) – Number of epochs to train for. If provided, it is added to trainer_params as max_epochs.
- num_steps (int, default: None) – Number of batches per epoch. If provided, it is added to trainer_params as limit_train_batches.
Returns:
- dict – Updated trainer parameters dictionary.
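The behaviour described above can be sketched as a standalone function. update_trainer_params_sketch is a hypothetical re-implementation for illustration only; it copies the input dictionary rather than modifying it in place, which is a choice of this sketch and may differ from the library.

```python
from typing import Any, Dict, Optional


def update_trainer_params_sketch(
    trainer_params: Optional[Dict[str, Any]] = None,
    num_epochs: Optional[int] = None,
    num_steps: Optional[int] = None,
) -> Dict[str, Any]:
    """Return a copy of trainer_params with epoch and batch limits applied (hypothetical)."""
    # Start from a copy so the caller's dictionary is left untouched.
    params = dict(trainer_params) if trainer_params is not None else {}
    # num_epochs becomes the Lightning Trainer's max_epochs.
    if num_epochs is not None:
        params["max_epochs"] = num_epochs
    # num_steps caps how many training batches run in each epoch.
    if num_steps is not None:
        params["limit_train_batches"] = num_steps
    return params


base = {"accelerator": "gpu"}
updated = update_trainer_params_sketch(base, num_epochs=100, num_steps=50)
print(updated)  # {'accelerator': 'gpu', 'max_epochs': 100, 'limit_train_batches': 50}
```

Because num_steps maps to limit_train_batches, it limits batches per epoch rather than setting a total step budget; Lightning's separate max_steps Trainer argument controls the total.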