Optimizer Configs

Optimizers and schedulers Pydantic models.

LrSchedulerConfig

Bases: BaseModel

Torch learning rate scheduler Pydantic model.

Only parameters supported by the corresponding torch lr scheduler will be taken into account. For more details, check: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

Note that mandatory parameters (see the specific LrScheduler signature in the link above) must be provided. For example, StepLR requires step_size.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| name | {'ReduceLROnPlateau', 'StepLR'} | Name of the learning rate scheduler. |
| parameters | dict | Parameters of the learning rate scheduler (see torch documentation). |

name = Field(default='ReduceLROnPlateau') class-attribute instance-attribute

Name of the learning rate scheduler. Supported schedulers are defined in SupportedScheduler.

parameters = Field(default={}, validate_default=True) class-attribute instance-attribute

Parameters of the learning rate scheduler. See the PyTorch documentation for more details.

filter_parameters(user_params, values) classmethod

Filter parameters based on the learning rate scheduler's signature.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| user_params | dict | User parameters. | required |
| values | ValidationInfo | Pydantic field validation info, used to get the scheduler name. | required |

Returns:

| Type | Description |
| --- | --- |
| dict | Filtered scheduler parameters. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If the scheduler is StepLR and the step_size parameter is not specified. |
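
The filtering described above can be illustrated with a minimal, standard-library-only sketch. It is not the library's implementation: the `StepLR` and `ReduceLROnPlateau` classes below are hypothetical stand-ins that only mimic a subset of the corresponding torch constructor signatures, and `SCHEDULERS` and `filter_scheduler_parameters` are names invented for this example.

```python
import inspect


# Hypothetical stand-ins mimicking a subset of the torch scheduler signatures.
class StepLR:
    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
        pass


class ReduceLROnPlateau:
    def __init__(self, optimizer, mode="min", factor=0.1, patience=10):
        pass


# Hypothetical registry mapping a scheduler name to its class.
SCHEDULERS = {"StepLR": StepLR, "ReduceLROnPlateau": ReduceLROnPlateau}


def filter_scheduler_parameters(name, user_params):
    """Keep only the parameters that appear in the scheduler's signature."""
    accepted = inspect.signature(SCHEDULERS[name]).parameters
    filtered = {k: v for k, v in user_params.items() if k in accepted}

    # StepLR has no default for step_size, so it must be provided.
    if name == "StepLR" and "step_size" not in filtered:
        raise ValueError("StepLR requires the `step_size` parameter.")
    return filtered
```

With these stand-ins, passing `{"step_size": 5, "patience": 3}` for `StepLR` keeps only `step_size`, because `patience` is not part of that signature.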

OptimizerConfig

Bases: BaseModel

Torch optimizer Pydantic model.

Only parameters supported by the corresponding torch optimizer will be taken into account. For more details, check: https://pytorch.org/docs/stable/optim.html#algorithms

Note that mandatory parameters (see the specific Optimizer signature in the link above) must be provided. For example, SGD requires lr.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| name | {'Adam', 'SGD', 'Adamax', 'AdamW'} | Name of the optimizer. |
| parameters | dict | Parameters of the optimizer (see torch documentation). |

name = Field(default='Adam', validate_default=True) class-attribute instance-attribute

Name of the optimizer. Supported optimizers are defined in SupportedOptimizer.

parameters = Field(default={}, validate_default=True) class-attribute instance-attribute

Parameters of the optimizer. See the PyTorch documentation for more details.

filter_parameters(user_params, values) classmethod

Validate optimizer parameters.

This method filters out unknown parameters, given the optimizer name.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| user_params | dict | Parameters passed on to the torch optimizer. | required |
| values | ValidationInfo | Pydantic field validation info, used to get the optimizer name. | required |

Returns:

| Type | Description |
| --- | --- |
| dict | Filtered optimizer parameters. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If the optimizer name is not specified. |
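
The ValueError above relates to how Pydantic field validators see other fields: the `data` attribute of ValidationInfo holds only the fields that were already validated successfully, so `name` is absent if it failed its own validation. The sketch below illustrates that lookup under stated assumptions: `get_optimizer_name` is a hypothetical helper, and a `SimpleNamespace` stands in for ValidationInfo so the example runs without Pydantic.

```python
from types import SimpleNamespace


def get_optimizer_name(info):
    """Return the validated optimizer name, or raise if it is missing.

    `info` stands in for Pydantic's ValidationInfo; only `.data` is used.
    """
    if "name" not in info.data:
        raise ValueError(
            "Cannot validate optimizer parameters without a valid `name`."
        )
    return info.data["name"]


# Stand-in for the case where `name` validated correctly.
valid_info = SimpleNamespace(data={"name": "Adam"})
optimizer_name = get_optimizer_name(valid_info)
```

Because `name` is declared before `parameters`, it is validated first, which is what makes this lookup possible inside the `parameters` validator.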

sgd_lr_parameter()

Check that SGD optimizer has the mandatory lr parameter specified.

This check is specific to PyTorch < 2.2, where lr has no default value in the SGD signature.

Returns:

| Type | Description |
| --- | --- |
| Self | Validated optimizer. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If the optimizer is SGD and the lr parameter is not specified. |
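
A check of this kind runs once both `name` and `parameters` are available. The following is a rough sketch, not the library's code: `check_sgd_lr` and `parse_major_minor` are hypothetical helpers, and gating on an explicit `torch_version` string is an assumption made for illustration (the source only states that the check targets PyTorch < 2.2).

```python
def parse_major_minor(version):
    """Extract (major, minor) from a version string such as '2.1.0+cu118'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def check_sgd_lr(name, parameters, torch_version):
    """Raise if SGD is selected without `lr` on a PyTorch older than 2.2.

    `torch_version` is passed in explicitly (assumption) so the sketch
    does not need torch installed.
    """
    needs_lr = parse_major_minor(torch_version) < (2, 2)
    if name == "SGD" and needs_lr and "lr" not in parameters:
        raise ValueError("SGD optimizer requires the `lr` parameter.")
    return parameters
```

Other optimizers pass through unchanged, since the constraint only concerns SGD.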