# Optimizer Configs

Pydantic models for optimizers and learning rate schedulers.
## LrSchedulerConfig

Bases: `BaseModel`

Pydantic model for a torch learning rate scheduler.

Only parameters supported by the corresponding torch lr scheduler are taken into account. For more details, see: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

Note that mandatory parameters (see the specific scheduler signature in the link above) must be provided. For example, `StepLR` requires `step_size`.
Attributes:

| Name | Type | Description |
|---|---|---|
| `name` | `{'ReduceLROnPlateau', 'StepLR'}` | Name of the learning rate scheduler. |
| `parameters` | `dict` | Parameters of the learning rate scheduler (see torch documentation). |
### name

`name = Field(default='ReduceLROnPlateau')` *(class-attribute, instance-attribute)*

Name of the learning rate scheduler; supported schedulers are defined in `SupportedScheduler`.
### parameters

`parameters = Field(default={}, validate_default=True)` *(class-attribute, instance-attribute)*

Parameters of the learning rate scheduler; see the PyTorch documentation for more details.
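As an illustrative sketch, a `StepLR` scheduler configuration could be expressed as a plain dict with the two documented fields (this is not tied to any particular import path of the package):

```python
# Illustrative scheduler configuration using the two documented fields.
# StepLR's mandatory step_size must appear in `parameters`; keys unknown
# to the scheduler would be dropped by filter_parameters.
scheduler_config = {
    "name": "StepLR",  # one of {'ReduceLROnPlateau', 'StepLR'}
    "parameters": {"step_size": 10, "gamma": 0.5},
}

print(scheduler_config["parameters"]["step_size"])
```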
### filter_parameters(user_params, values)

*(classmethod)*

Filter parameters based on the learning rate scheduler's signature.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `user_params` | `dict` | User parameters. | *required* |
| `values` | `ValidationInfo` | Pydantic field validation info, used to get the scheduler name. | *required* |
Returns:

| Type | Description |
|---|---|
| `dict` | Filtered scheduler parameters. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If the scheduler is `StepLR` and the `step_size` parameter is not specified. |
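A minimal sketch of the signature-based filtering described above, using `inspect.signature`. The `StepLR` class below is a hypothetical stub standing in for the real torch scheduler; the library's actual method inspects the torch class directly:

```python
import inspect

class StepLR:
    """Stand-in with the same mandatory step_size as torch's StepLR."""
    def __init__(self, optimizer, step_size, gamma=0.1):
        self.step_size = step_size

def filter_scheduler_parameters(scheduler_cls, user_params: dict) -> dict:
    # Keep only the keys accepted by the scheduler's __init__.
    allowed = set(inspect.signature(scheduler_cls.__init__).parameters) - {"self"}
    filtered = {k: v for k, v in user_params.items() if k in allowed}
    # Mirror of the documented check: StepLR must receive step_size.
    if scheduler_cls.__name__ == "StepLR" and "step_size" not in filtered:
        raise ValueError("StepLR requires a `step_size` parameter.")
    return filtered

print(filter_scheduler_parameters(StepLR, {"step_size": 10, "typo": 1}))
# {'step_size': 10}
```

Unknown keys (here `typo`) are silently dropped, while a missing mandatory parameter raises `ValueError`.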
## OptimizerConfig

Bases: `BaseModel`

Pydantic model for a torch optimizer.

Only parameters supported by the corresponding torch optimizer are taken into account. For more details, see: https://pytorch.org/docs/stable/optim.html#algorithms

Note that mandatory parameters (see the specific optimizer signature in the link above) must be provided. For example, `SGD` requires `lr`.
Attributes:

| Name | Type | Description |
|---|---|---|
| `name` | `{'Adam', 'SGD', 'Adamax', 'AdamW'}` | Name of the optimizer. |
| `parameters` | `dict` | Parameters of the optimizer (see torch documentation). |
### name

`name = Field(default='Adam', validate_default=True)` *(class-attribute, instance-attribute)*

Name of the optimizer; supported optimizers are defined in `SupportedOptimizer`.
### parameters

`parameters = Field(default={}, validate_default=True)` *(class-attribute, instance-attribute)*

Parameters of the optimizer; see the PyTorch documentation for more details.
### filter_parameters(user_params, values)

*(classmethod)*

Validate optimizer parameters.

This method filters out unknown parameters, given the optimizer name.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `user_params` | `dict` | Parameters passed on to the torch optimizer. | *required* |
| `values` | `ValidationInfo` | Pydantic field validation info, used to get the optimizer name. | *required* |
Returns:

| Type | Description |
|---|---|
| `dict` | Filtered optimizer parameters. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If the optimizer name is not specified. |
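The same filtering idea can be sketched end to end for an optimizer: look the class up by name, drop unsupported keys, and instantiate. `Adam` and the `OPTIMIZERS` registry below are hypothetical stubs mimicking the shape of `torch.optim`, not the library's actual code:

```python
import inspect

class Adam:
    """Stand-in mimicking the signature shape of torch.optim.Adam."""
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999)):
        self.lr = lr

OPTIMIZERS = {"Adam": Adam}  # hypothetical name-to-class registry

def build_optimizer(name: str, user_params: dict, model_params=()):
    if not name:
        raise ValueError("Optimizer name must be specified.")
    cls = OPTIMIZERS[name]
    # Drop keys the optimizer's __init__ does not accept.
    allowed = set(inspect.signature(cls.__init__).parameters) - {"self", "params"}
    filtered = {k: v for k, v in user_params.items() if k in allowed}
    return cls(model_params, **filtered)

opt = build_optimizer("Adam", {"lr": 0.01, "unknown_flag": True})
print(opt.lr)
# 0.01
```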
### sgd_lr_parameter()

Check that the SGD optimizer has the mandatory `lr` parameter specified.

This is specific to PyTorch < 2.2.
Returns:

| Type | Description |
|---|---|
| `Self` | Validated optimizer. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If the optimizer is `SGD` and the `lr` parameter is not specified. |
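The validator's logic can be sketched as a standalone check (a simplification of the documented behavior, not the library's actual implementation):

```python
def check_sgd_lr(name: str, parameters: dict) -> dict:
    # SGD has no default learning rate in PyTorch < 2.2, so `lr`
    # must be supplied explicitly when the optimizer is SGD.
    if name == "SGD" and "lr" not in parameters:
        raise ValueError("SGD requires an `lr` parameter.")
    return parameters

print(check_sgd_lr("SGD", {"lr": 0.1}))
# {'lr': 0.1}
```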