
Lightning


Training and lightning related Pydantic configurations.

CheckpointConfig

Bases: BaseModel

Checkpoint saving callback Pydantic model.

The parameters correspond to those of pytorch_lightning.callbacks.ModelCheckpoint.

See: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint

auto_insert_metric_name = Field(default=False) class-attribute instance-attribute

When True, checkpoint filenames will contain the metric name. Note that val_loss is already embedded in the default filename pattern, so enabling this field will produce redundant metric names in the filename.

every_n_epochs = Field(default=None, ge=1, le=100) class-attribute instance-attribute

Number of epochs between checkpoints.

every_n_train_steps = Field(default=None, ge=1, le=1000) class-attribute instance-attribute

Number of training steps between checkpoints.

mode = Field(default='min') class-attribute instance-attribute

One of {min, max}. If save_top_k != 0, the decision to overwrite the current save file is based on either the maximization or the minimization of the monitored quantity. For 'val_acc' this should be 'max'; for 'val_loss' it should be 'min'.

monitor = Field(default='val_loss') class-attribute instance-attribute

Quantity to monitor, currently only val_loss.

save_last = Field(default=True) class-attribute instance-attribute

When True, saves a {experiment_name}_last.ckpt copy whenever a checkpoint file gets saved.

save_top_k = Field(default=3, ge=(-1), le=100) class-attribute instance-attribute

If save_top_k == k, the best k models according to the monitored quantity are saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved.
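To make the interaction between save_top_k and mode concrete, the function below applies a simplified retention rule to a finished list of scores. retained_checkpoints is a hypothetical helper, not part of CAREamics or Lightning; Lightning updates the kept set incrementally as checkpoints are written, but the final set of retained files follows the same idea.

```python
from typing import List, Tuple


def retained_checkpoints(
    scores: List[Tuple[str, float]], save_top_k: int, mode: str = "min"
) -> List[str]:
    """Return the checkpoint names kept under a simplified save_top_k rule.

    Hypothetical helper: `scores` holds (checkpoint_name, monitored_value) pairs.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    if save_top_k == 0:
        return []  # no checkpoints are kept
    if save_top_k == -1:
        return [name for name, _ in scores]  # every checkpoint is kept
    # 'min' keeps the lowest values (e.g. val_loss), 'max' the highest (e.g. val_acc)
    ranked = sorted(scores, key=lambda item: item[1], reverse=(mode == "max"))
    return [name for name, _ in ranked[:save_top_k]]


history = [("epoch=0", 0.90), ("epoch=1", 0.55), ("epoch=2", 0.61), ("epoch=3", 0.48)]
print(retained_checkpoints(history, save_top_k=3, mode="min"))
# → ['epoch=3', 'epoch=1', 'epoch=2']
```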

save_weights_only = Field(default=False) class-attribute instance-attribute

When True, only the model's weights are saved; otherwise, the optimizer and learning rate scheduler states are also stored in the checkpoint.

train_time_interval = Field(default=None) class-attribute instance-attribute

Checkpoints are monitored at the specified time interval.

verbose = Field(default=False) class-attribute instance-attribute

Verbosity mode.
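As a rough sketch of how the fields above line up with Lightning's callback, the dataclass below mirrors the documented defaults and bounds using only the standard library. CheckpointSettings, _check_range and to_kwargs are hypothetical names, and train_time_interval is assumed to be a datetime.timedelta, as in Lightning's ModelCheckpoint; the real class is a Pydantic model whose validation may behave differently.

```python
from dataclasses import asdict, dataclass
from datetime import timedelta
from typing import Any, Dict, Optional


def _check_range(name: str, value: Optional[int], low: int, high: int) -> None:
    """Raise ValueError if an optional integer falls outside [low, high]."""
    if value is not None and not low <= value <= high:
        raise ValueError(f"{name}={value} is outside [{low}, {high}]")


@dataclass
class CheckpointSettings:
    """Hypothetical stdlib stand-in for CheckpointConfig, using the documented defaults."""

    monitor: str = "val_loss"
    verbose: bool = False
    save_weights_only: bool = False
    mode: str = "min"
    auto_insert_metric_name: bool = False
    save_last: bool = True
    save_top_k: int = 3
    every_n_train_steps: Optional[int] = None
    train_time_interval: Optional[timedelta] = None  # assumed type
    every_n_epochs: Optional[int] = None

    def __post_init__(self) -> None:
        # Re-implement the documented Field constraints as explicit checks.
        if self.mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        _check_range("save_top_k", self.save_top_k, -1, 100)
        _check_range("every_n_epochs", self.every_n_epochs, 1, 100)
        _check_range("every_n_train_steps", self.every_n_train_steps, 1, 1000)

    def to_kwargs(self) -> Dict[str, Any]:
        # Each field name is also a keyword argument of Lightning's ModelCheckpoint,
        # so this dict could be unpacked as ModelCheckpoint(dirpath=..., **kwargs).
        return asdict(self)


kwargs = CheckpointSettings(save_top_k=5, every_n_epochs=2).to_kwargs()
print(kwargs["save_top_k"], kwargs["every_n_epochs"], kwargs["monitor"])  # 5 2 val_loss
```

Keeping the field names identical to the callback's keyword arguments is what allows the configuration to be forwarded without a translation layer.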

EarlyStoppingConfig

Bases: BaseModel

Early stopping callback Pydantic model.

The parameters correspond to those of pytorch_lightning.callbacks.EarlyStopping.

See: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping

check_finite = Field(default=True) class-attribute instance-attribute

When True, stops training when the monitored quantity becomes NaN or inf.

check_on_train_epoch_end = Field(default=False) class-attribute instance-attribute

Whether to run early stopping at the end of the training epoch. If False, the check runs at the end of validation.

divergence_threshold = Field(default=None) class-attribute instance-attribute

Stop training as soon as the monitored quantity becomes worse than this threshold.

log_rank_zero_only = Field(default=False) class-attribute instance-attribute

When True, logs the status of the early stopping callback only for the rank 0 process.

min_delta = Field(default=0.0, ge=0.0, le=1.0) class-attribute instance-attribute

Minimum change in the monitored quantity to qualify as an improvement; an absolute change of less than or equal to min_delta counts as no improvement.

mode = Field(default='min') class-attribute instance-attribute

One of {min, max, auto}.

monitor = Field(default='val_loss') class-attribute instance-attribute

Quantity to monitor.

patience = Field(default=3, ge=1, le=10) class-attribute instance-attribute

Number of checks with no improvement after which training will be stopped.

stopping_threshold = Field(default=None) class-attribute instance-attribute

Stop training immediately once the monitored quantity reaches this threshold.

verbose = Field(default=False) class-attribute instance-attribute

Verbosity mode.
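The interaction of min_delta, patience and the threshold options can be sketched as a loop over successive values of the monitored quantity. This is a simplified, hypothetical model (simulate_early_stopping is not part of CAREamics or Lightning). It roughly follows the order in which Lightning's EarlyStopping evaluates its criteria (non-finite check, stopping_threshold, divergence_threshold, then improvement and patience), and it ignores check_on_train_epoch_end, log_rank_zero_only and verbose, which affect when and how the check runs rather than its outcome.

```python
import math
from typing import Iterable, Optional, Tuple


def simulate_early_stopping(
    values: Iterable[float],
    mode: str = "min",
    min_delta: float = 0.0,
    patience: int = 3,
    check_finite: bool = True,
    stopping_threshold: Optional[float] = None,
    divergence_threshold: Optional[float] = None,
) -> Tuple[Optional[int], str]:
    """Return (index of the check that triggers stopping, reason), or (None, "")."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    sign = 1.0 if mode == "min" else -1.0  # turn 'max' into a minimisation problem
    best = math.inf
    wait = 0  # number of consecutive checks without improvement
    for step, raw in enumerate(values):
        current = sign * raw
        if check_finite and not math.isfinite(current):
            return step, "non-finite value"
        if stopping_threshold is not None and current < sign * stopping_threshold:
            return step, "stopping_threshold reached"
        if divergence_threshold is not None and current > sign * divergence_threshold:
            return step, "divergence_threshold exceeded"
        if current < best - min_delta:  # strictly better by more than min_delta
            best, wait = current, 0
        else:
            wait += 1
            if wait >= patience:
                return step, "patience exhausted"
    return None, ""


print(simulate_early_stopping([1.0, 0.8, 0.81, 0.82, 0.83], patience=3))
# → (4, 'patience exhausted')
```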

LrSchedulerConfig

Bases: BaseModel

Torch learning rate scheduler Pydantic model.

Only parameters supported by the corresponding torch lr scheduler will be taken into account. For more details, check: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

Note that mandatory parameters (see the specific LrScheduler signature in the link above) must be provided. For example, StepLR requires step_size.

Attributes:

- name ({'ReduceLROnPlateau', 'StepLR'}): Name of the learning rate scheduler.
- parameters (dict): Parameters of the learning rate scheduler (see torch documentation).

name = Field(default='ReduceLROnPlateau') class-attribute instance-attribute

Name of the learning rate scheduler. Supported schedulers are defined in SupportedScheduler.

parameters = Field(default={}, validate_default=True) class-attribute instance-attribute

Parameters of the learning rate scheduler. See the PyTorch documentation for more details.

filter_parameters(user_params, values) classmethod

Filter parameters based on the learning rate scheduler's signature.

Parameters:

- user_params (dict, required): User parameters.
- values (ValidationInfo, required): Pydantic field validation info, used to get the scheduler name.

Returns:

- dict: Filtered scheduler parameters.

Raises:

- ValueError: If the scheduler is StepLR and the step_size parameter is not specified.
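Signature-based filtering of this kind can be illustrated with inspect.signature from the standard library. In this sketch, StepLRStub and ReduceLROnPlateauStub are hypothetical stand-ins that mirror a few arguments of the torch schedulers (so the example runs without PyTorch), and SCHEDULERS and filter_scheduler_parameters are illustrative names; the actual CAREamics validator may be implemented differently.

```python
import inspect
from typing import Any, Dict


class StepLRStub:
    """Hypothetical stand-in mirroring the main arguments of torch's StepLR."""

    def __init__(self, optimizer: Any, step_size: int, gamma: float = 0.1, last_epoch: int = -1) -> None:
        self.optimizer, self.step_size, self.gamma, self.last_epoch = optimizer, step_size, gamma, last_epoch


class ReduceLROnPlateauStub:
    """Hypothetical stand-in mirroring a few arguments of torch's ReduceLROnPlateau."""

    def __init__(self, optimizer: Any, mode: str = "min", factor: float = 0.1, patience: int = 10) -> None:
        self.optimizer, self.mode, self.factor, self.patience = optimizer, mode, factor, patience


# Hypothetical registry from scheduler name to class.
SCHEDULERS: Dict[str, type] = {"StepLR": StepLRStub, "ReduceLROnPlateau": ReduceLROnPlateauStub}


def filter_scheduler_parameters(name: str, user_params: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the keys accepted by the scheduler's constructor, minus `optimizer`."""
    # inspect.signature on a class reports its __init__ parameters without `self`.
    accepted = inspect.signature(SCHEDULERS[name]).parameters
    filtered = {
        key: value
        for key, value in user_params.items()
        if key in accepted and key != "optimizer"  # the optimizer is supplied at build time
    }
    if name == "StepLR" and "step_size" not in filtered:
        raise ValueError("StepLR requires the 'step_size' parameter.")
    return filtered


params = {"step_size": 10, "gamma": 0.5, "patience": 3}
print(filter_scheduler_parameters("StepLR", params))             # {'step_size': 10, 'gamma': 0.5}
print(filter_scheduler_parameters("ReduceLROnPlateau", params))  # {'patience': 3}
```

The same user dictionary yields different parameters depending on the scheduler name, which is why the validator needs access to the name through the validation info.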

OptimizerConfig

Bases: BaseModel

Torch optimizer Pydantic model.

Only parameters supported by the corresponding torch optimizer will be taken into account. For more details, check: https://pytorch.org/docs/stable/optim.html#algorithms

Note that mandatory parameters (see the specific Optimizer signature in the link above) must be provided. For example, SGD requires lr.

Attributes:

- name ({'Adam', 'SGD', 'Adamax', 'AdamW'}): Name of the optimizer.
- parameters (dict): Parameters of the optimizer (see torch documentation).

name = Field(default='Adam', validate_default=True) class-attribute instance-attribute

Name of the optimizer. Supported optimizers are defined in SupportedOptimizer.

parameters = Field(default={}, validate_default=True) class-attribute instance-attribute

Parameters of the optimizer. See the PyTorch documentation for more details.

filter_parameters(user_params, values) classmethod

Validate optimizer parameters.

This method filters out unknown parameters, given the optimizer name.

Parameters:

- user_params (dict, required): Parameters passed on to the torch optimizer.
- values (ValidationInfo, required): Pydantic field validation info, used to get the optimizer name.

Returns:

- dict: Filtered optimizer parameters.

Raises:

- ValueError: If the optimizer name is not specified.

sgd_lr_parameter()

Check that SGD optimizer has the mandatory lr parameter specified.

This check is only needed for PyTorch < 2.2.

Returns:

- Self: Validated optimizer.

Raises:

- ValueError: If the optimizer is SGD and the lr parameter is not specified.
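A stand-alone sketch of this version-dependent check follows. check_sgd_lr and parse_major_minor are hypothetical helpers, and the PyTorch version is passed in as a plain string (in practice it would presumably come from torch.__version__) so that the example runs without PyTorch installed.

```python
from typing import Any, Dict, Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings such as '2.1.0' or '2.3.1+cu121'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def check_sgd_lr(name: str, parameters: Dict[str, Any], torch_version: str) -> None:
    """Raise ValueError when SGD lacks 'lr' on a PyTorch release older than 2.2."""
    needs_explicit_lr = parse_major_minor(torch_version) < (2, 2)
    if name == "SGD" and "lr" not in parameters and needs_explicit_lr:
        raise ValueError("SGD requires the 'lr' parameter for PyTorch < 2.2.")


check_sgd_lr("SGD", {"lr": 0.01}, "2.1.0")  # passes: lr is provided
check_sgd_lr("SGD", {}, "2.3.1+cu121")      # passes: the check only applies below 2.2
```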

TrainingConfig

Bases: BaseModel

Parameters related to the training.

Mandatory parameters are:

- num_epochs: number of epochs, greater than 0.
- batch_size: batch size, greater than 0.
- augmentation: whether to use data augmentation or not (True or False).
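The constraints on these mandatory parameters can be sketched with a standard-library dataclass. MandatoryTrainingParams is a hypothetical stand-in for illustration only; in the real configuration these checks are presumably handled by Pydantic validation.

```python
from dataclasses import dataclass


@dataclass
class MandatoryTrainingParams:
    """Hypothetical stand-in mirroring the mandatory training parameters."""

    num_epochs: int
    batch_size: int
    augmentation: bool

    def __post_init__(self) -> None:
        # Both counts must be strictly positive.
        if self.num_epochs <= 0:
            raise ValueError("num_epochs must be greater than 0")
        if self.batch_size <= 0:
            raise ValueError("batch_size must be greater than 0")
        # Only the booleans True or False are accepted.
        if not isinstance(self.augmentation, bool):
            raise TypeError("augmentation must be True or False")


params = MandatoryTrainingParams(num_epochs=100, batch_size=16, augmentation=True)
```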

Attributes:

- num_epochs (int): Number of epochs, greater than 0.

checkpoint_callback = CheckpointConfig() class-attribute instance-attribute

Checkpoint callback configuration, following the PyTorch Lightning ModelCheckpoint callback.

early_stopping_callback = Field(default=None, validate_default=True) class-attribute instance-attribute

Early stopping callback configuration, following the PyTorch Lightning EarlyStopping callback.

lightning_trainer_config = Field(default={}) class-attribute instance-attribute

Configuration for the PyTorch Lightning Trainer, following the PyTorch Lightning Trainer class.

logger = None class-attribute instance-attribute

Logger to use during training. If None, no logger will be used. Available loggers are defined in SupportedLogger.

has_logger()

Check if the logger is defined.

Returns:

- bool: Whether the logger is defined or not.