Lightning
Training- and Lightning-related Pydantic configurations.
CheckpointConfig
Bases: BaseModel
Checkpoint saving callback Pydantic model.
The parameters correspond to those of
pytorch_lightning.callbacks.ModelCheckpoint.
See: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
auto_insert_metric_name = Field(default=False)
class-attribute
instance-attribute
When True, the checkpoint filenames will contain the metric name. Note that
val_loss is already embedded in the default filename pattern, and enabling this
field will produce redundant metric names in the filename.
every_n_epochs = Field(default=None, ge=1, le=100)
class-attribute
instance-attribute
Number of epochs between checkpoints.
every_n_train_steps = Field(default=None, ge=1, le=1000)
class-attribute
instance-attribute
Number of training steps between checkpoints.
mode = Field(default='min')
class-attribute
instance-attribute
One of {min, max}. If save_top_k != 0, the decision to overwrite the current
save file is made based on either the maximization or the minimization of the
monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should
be 'min', etc.
monitor = Field(default='val_loss')
class-attribute
instance-attribute
Quantity to monitor; currently only val_loss is supported.
save_last = Field(default=True)
class-attribute
instance-attribute
When True, saves a {experiment_name}_last.ckpt copy whenever a checkpoint
file gets saved.
save_top_k = Field(default=3, ge=(-1), le=100)
class-attribute
instance-attribute
If save_top_k == k, the best k models according to the quantity monitored
will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1,
all models are saved.
save_weights_only = Field(default=False)
class-attribute
instance-attribute
When True, only the model's weights will be saved (model.save_weights).
train_time_interval = Field(default=None)
class-attribute
instance-attribute
Checkpoints are monitored at the specified time interval.
verbose = Field(default=False)
class-attribute
instance-attribute
Verbosity mode.
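To illustrate how these fields and their documented bounds fit together, here is a minimal plain-Python sketch. It is a hypothetical stand-in, not the library's actual Pydantic CheckpointConfig; only the field names and bounds come from the documentation above, the dataclass-based validation is an assumption for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CheckpointSketch:
    """Illustrative stand-in for CheckpointConfig (not the real class)."""

    monitor: str = "val_loss"
    mode: str = "min"
    save_top_k: int = 3
    every_n_epochs: Optional[int] = None
    save_last: bool = True

    def __post_init__(self) -> None:
        # Mirror the documented constraints: mode in {min, max},
        # save_top_k in [-1, 100], every_n_epochs in [1, 100] when set.
        if self.mode not in {"min", "max"}:
            raise ValueError("mode must be 'min' or 'max'")
        if not -1 <= self.save_top_k <= 100:
            raise ValueError("save_top_k must be in [-1, 100]")
        if self.every_n_epochs is not None and not 1 <= self.every_n_epochs <= 100:
            raise ValueError("every_n_epochs must be in [1, 100]")
```

For example, `CheckpointSketch(save_top_k=-1)` keeps all checkpoints, while an out-of-range value raises at construction time, just as the real Pydantic field constraints would.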
EarlyStoppingConfig
Bases: BaseModel
Early stopping callback Pydantic model.
The parameters correspond to those of
pytorch_lightning.callbacks.EarlyStopping.
See: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
check_finite = Field(default=True)
class-attribute
instance-attribute
When True, stops training when the monitored quantity becomes NaN or
inf.
check_on_train_epoch_end = Field(default=False)
class-attribute
instance-attribute
Whether to run early stopping at the end of the training epoch. If False,
the check runs at the end of validation instead.
divergence_threshold = Field(default=None)
class-attribute
instance-attribute
Stop training as soon as the monitored quantity becomes worse than this threshold.
log_rank_zero_only = Field(default=False)
class-attribute
instance-attribute
When set True, logs the status of the early stopping callback only for rank 0
process.
min_delta = Field(default=0.0, ge=0.0, le=1.0)
class-attribute
instance-attribute
Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than or equal to min_delta will count as no improvement.
mode = Field(default='min')
class-attribute
instance-attribute
One of {min, max, auto}.
monitor = Field(default='val_loss')
class-attribute
instance-attribute
Quantity to monitor.
patience = Field(default=3, ge=1, le=10)
class-attribute
instance-attribute
Number of checks with no improvement after which training will be stopped.
stopping_threshold = Field(default=None)
class-attribute
instance-attribute
Stop training immediately once the monitored quantity reaches this threshold.
verbose = Field(default=False)
class-attribute
instance-attribute
Verbosity mode.
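The interaction of monitor, min_delta, patience, and mode can be sketched in plain Python. This is an illustrative re-implementation of the documented semantics, not the Lightning EarlyStopping callback itself:

```python
def should_stop(history, patience=3, min_delta=0.0, mode="min"):
    """Early-stopping check over a list of monitored values.

    A check counts as an improvement only if it beats the best value
    so far by strictly more than min_delta; after `patience`
    consecutive non-improving checks, training stops.
    """
    # In "max" mode, negate values so smaller is always better.
    sign = 1 if mode == "min" else -1
    best = float("inf")
    wait = 0
    for value in history:
        if sign * value < best - min_delta:
            best = sign * value
            wait = 0  # improvement resets the patience counter
        else:
            wait += 1
            if wait >= patience:
                return True
    return False
```

With `patience=3`, a val_loss history of `[1.0, 0.9, 0.91, 0.92, 0.93]` triggers a stop (three non-improving checks after the best value), while a steadily decreasing history does not.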
LrSchedulerConfig
Bases: BaseModel
Torch learning rate scheduler Pydantic model.
Only parameters supported by the corresponding torch lr scheduler will be taken into account. For more details, check: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
Note that mandatory parameters (see the specific LrScheduler signature in the
link above) must be provided. For example, StepLR requires step_size.
Attributes:
| Name | Type | Description |
|---|---|---|
| name | {'ReduceLROnPlateau', 'StepLR'} | Name of the learning rate scheduler. |
| parameters | dict | Parameters of the learning rate scheduler (see torch documentation). |
name = Field(default='ReduceLROnPlateau')
class-attribute
instance-attribute
Name of the learning rate scheduler, supported schedulers are defined in SupportedScheduler.
parameters = Field(default={}, validate_default=True)
class-attribute
instance-attribute
Parameters of the learning rate scheduler, see PyTorch documentation for more details.
filter_parameters(user_params, values)
classmethod
Filter parameters based on the learning rate scheduler's signature.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| user_params | dict | User parameters. | required |
| values | ValidationInfo | Pydantic field validation info, used to get the scheduler name. | required |
Returns:
| Type | Description |
|---|---|
| dict | Filtered scheduler parameters. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the scheduler is StepLR and the step_size parameter is not specified. |
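Signature-based filtering of this kind can be sketched with the standard library's inspect module. The stand-in `step_lr_stub` below merely mimics the signature of torch.optim.lr_scheduler.StepLR so the example stays self-contained; the real validator operates on the actual torch scheduler classes:

```python
import inspect


def step_lr_stub(optimizer, step_size, gamma=0.1, last_epoch=-1):
    """Stand-in with the same signature shape as StepLR (illustrative)."""
    return {"step_size": step_size, "gamma": gamma, "last_epoch": last_epoch}


def filter_scheduler_parameters(user_params, scheduler_cls):
    """Keep only keys accepted by the scheduler's signature, and reject
    configs missing a mandatory (default-less) parameter such as
    StepLR's step_size."""
    sig = inspect.signature(scheduler_cls)
    accepted = {name for name in sig.parameters if name != "optimizer"}
    filtered = {k: v for k, v in user_params.items() if k in accepted}
    missing = [
        name
        for name, param in sig.parameters.items()
        if name != "optimizer"
        and param.default is inspect.Parameter.empty
        and name not in filtered
    ]
    if missing:
        raise ValueError(f"Missing mandatory scheduler parameters: {missing}")
    return filtered
```

Unknown keys such as a typo'd `"gama"` are silently dropped, while omitting `step_size` raises, matching the behavior documented above.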
OptimizerConfig
Bases: BaseModel
Torch optimizer Pydantic model.
Only parameters supported by the corresponding torch optimizer will be taken into account. For more details, check: https://pytorch.org/docs/stable/optim.html#algorithms
Note that mandatory parameters (see the specific Optimizer signature in the
link above) must be provided. For example, SGD requires lr.
Attributes:
| Name | Type | Description |
|---|---|---|
| name | {'Adam', 'SGD', 'Adamax', 'AdamW'} | Name of the optimizer. |
| parameters | dict | Parameters of the optimizer (see torch documentation). |
name = Field(default='Adam', validate_default=True)
class-attribute
instance-attribute
Name of the optimizer, supported optimizers are defined in SupportedOptimizer.
parameters = Field(default={}, validate_default=True)
class-attribute
instance-attribute
Parameters of the optimizer, see PyTorch documentation for more details.
filter_parameters(user_params, values)
classmethod
Validate optimizer parameters.
This method filters out unknown parameters, given the optimizer name.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| user_params | dict | Parameters passed on to the torch optimizer. | required |
| values | ValidationInfo | Pydantic field validation info, used to get the optimizer name. | required |
Returns:
| Type | Description |
|---|---|
| dict | Filtered optimizer parameters. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the optimizer name is not specified. |
sgd_lr_parameter()
Check that SGD optimizer has the mandatory lr parameter specified.
This is specific for PyTorch < 2.2.
Returns:
| Type | Description |
|---|---|
| Self | Validated optimizer. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the optimizer is SGD and the lr parameter is not specified. |
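The documented rule, that SGD must be given an explicit lr (mandatory in PyTorch < 2.2), amounts to a small model-level check. The function below is an illustrative sketch of that logic, not the library's validator:

```python
def check_sgd_lr(name, parameters):
    """Reject an SGD configuration that omits the mandatory 'lr'.

    Other optimizers (e.g. Adam) provide a default learning rate,
    so their parameters pass through unchanged.
    """
    if name == "SGD" and "lr" not in parameters:
        raise ValueError("SGD requires the mandatory 'lr' parameter")
    return parameters
```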
TrainingConfig
Bases: BaseModel
Parameters related to the training.
Mandatory parameters are:

- num_epochs: number of epochs, greater than 0.
- batch_size: batch size, greater than 0.
- augmentation: whether to use data augmentation or not (True or False).
Attributes:
| Name | Type | Description |
|---|---|---|
| num_epochs | int | Number of epochs, greater than 0. |
checkpoint_callback = CheckpointConfig()
class-attribute
instance-attribute
Checkpoint callback configuration, following the PyTorch Lightning ModelCheckpoint callback.
early_stopping_callback = Field(default=None, validate_default=True)
class-attribute
instance-attribute
Early stopping callback configuration, following the PyTorch Lightning EarlyStopping callback.
lightning_trainer_config = Field(default={})
class-attribute
instance-attribute
Configuration for the PyTorch Lightning Trainer, following the PyTorch Lightning Trainer class.
logger = None
class-attribute
instance-attribute
Logger to use during training. If None, no logger will be used. Available loggers are defined in SupportedLogger.
has_logger()
Check if the logger is defined.
Returns:
| Type | Description |
|---|---|
| bool | Whether the logger is defined or not. |
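A minimal sketch of this shape in plain Python, assuming a hypothetical stand-in class rather than the real Pydantic TrainingConfig; only num_epochs, batch_size, logger, and has_logger come from the documentation above:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainingSketch:
    """Illustrative stand-in for TrainingConfig (not the real class)."""

    num_epochs: int = 20
    batch_size: int = 16
    augmentation: bool = True
    logger: Optional[str] = None  # a logger name, or None for no logging

    def has_logger(self) -> bool:
        # Mirrors the documented check: True only when a logger is set.
        return self.logger is not None
```

Callers can then branch on `has_logger()` to decide whether to attach a logger to the Lightning Trainer.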