# training_factory

Convenience functions to create training configurations.
## `create_training_configuration(trainer_params, logger, checkpoint_params=None)`

Create a training configuration model from the specified parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `trainer_params` | `dict` | Parameters for the Lightning `Trainer` class; see the PyTorch Lightning documentation. | *required* |
| `logger` | `"wandb" \| "tensorboard" \| "none"` | Logger to use. | `"wandb"` |
| `checkpoint_params` | `dict` | Parameters for the checkpoint callback; see the PyTorch Lightning documentation. | `None` |
Returns:

| Type | Description |
|---|---|
| `TrainingConfig` | Training model with the specified parameters. |
Source code in src/careamics/config/ng_factories/training_factory.py
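A usage sketch is shown below. The parameter dictionaries use standard PyTorch Lightning `Trainer` and `ModelCheckpoint` keys; the import path is an assumption inferred from the source location above, and the specific key values are illustrative, not prescribed by this module.

```python
# Illustrative parameter dictionaries. Keys follow the PyTorch Lightning
# Trainer and ModelCheckpoint APIs; the values here are examples only.
trainer_params = {
    "max_epochs": 100,
    "accelerator": "auto",
}
checkpoint_params = {
    "monitor": "val_loss",
    "save_top_k": 3,
}

# Assumed import path, inferred from the source file listed above:
# from careamics.config.ng_factories.training_factory import (
#     create_training_configuration,
# )
# config = create_training_configuration(
#     trainer_params, logger="wandb", checkpoint_params=checkpoint_params
# )
```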
## `update_trainer_params(trainer_params=None, num_epochs=None, num_steps=None)`

Update trainer parameters with `num_epochs` and `num_steps`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `trainer_params` | `dict` | Parameters for the Lightning `Trainer` class. | `None` |
| `num_epochs` | `int` | Number of epochs to train for. If provided, added to `trainer_params` as `max_epochs`. | `None` |
| `num_steps` | `int` | Number of batches in one epoch. If provided, added to `trainer_params` as `limit_train_batches`. | `None` |
Returns:

| Type | Description |
|---|---|
| `dict` | Updated trainer parameters dictionary. |
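The update logic described above can be sketched as follows. This is a minimal reimplementation for illustration, assuming the documented behavior (it is not the library source):

```python
def update_trainer_params(trainer_params=None, num_epochs=None, num_steps=None):
    """Sketch: merge num_epochs and num_steps into a Trainer parameter dict."""
    # Copy the input so the caller's dict is not mutated.
    params = dict(trainer_params) if trainer_params is not None else {}
    if num_epochs is not None:
        # Epoch budget maps to Lightning's max_epochs.
        params["max_epochs"] = num_epochs
    if num_steps is not None:
        # Batches per epoch map to Lightning's limit_train_batches.
        params["limit_train_batches"] = num_steps
    return params


# Example: extend an existing parameter dict with an epoch budget.
updated = update_trainer_params({"accelerator": "auto"}, num_epochs=5)
# → {"accelerator": "auto", "max_epochs": 5}
```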