# training_model

Training configuration.

## TrainingConfig

Bases: `BaseModel`

Parameters related to the training.

Mandatory parameters are:

- `num_epochs`: number of epochs, greater than 0.
- `batch_size`: batch size, greater than 0.
- `augmentation`: whether to use data augmentation or not (`True` or `False`).
Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `num_epochs` | `int` | Number of epochs, greater than 0. |
Source code in `src/careamics/config/training_model.py`
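As a hedged illustration of how such a Pydantic model behaves, the sketch below mirrors a few of the documented fields; it is a minimal stand-in, not the actual careamics class, and omits most attributes:

```python
from typing import Literal, Optional, Union

from pydantic import BaseModel, Field, ValidationError


class TrainingConfigSketch(BaseModel):
    """Minimal stand-in mirroring a few documented TrainingConfig fields."""

    num_epochs: int = Field(default=20, ge=1)
    max_steps: int = Field(default=-1, ge=-1)
    accumulate_grad_batches: int = Field(default=1, ge=1)
    precision: Literal["64", "32", "16-mixed", "bf16-mixed"] = "32"
    gradient_clip_val: Optional[Union[int, float]] = None
    gradient_clip_algorithm: Literal["value", "norm"] = "norm"
    logger: Optional[Literal["wandb", "tensorboard"]] = None


# Defaults come from the Field(...) declarations.
config = TrainingConfigSketch()
print(config.num_epochs, config.precision)  # 20 32

# Constraints such as ge=1 are enforced at construction time.
try:
    TrainingConfigSketch(num_epochs=0)
except ValidationError:
    print("num_epochs=0 rejected")
```

Passing an unknown literal (e.g. `precision="8"`) or a value violating a `ge` bound raises a `ValidationError` in the same way.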
`accumulate_grad_batches: int = Field(default=1, ge=1)`

Number of batches to accumulate gradients over before stepping the optimizer.
`check_val_every_n_epoch: int = Field(default=1, ge=1)`

Frequency, in epochs, at which the validation step is run.
`checkpoint_callback: CheckpointModel = CheckpointModel()`

Checkpoint callback configuration, following the PyTorch Lightning `ModelCheckpoint` callback.
`early_stopping_callback: Optional[EarlyStoppingModel] = Field(default=None, validate_default=True)`

Early stopping callback configuration, following the PyTorch Lightning `EarlyStopping` callback.
`enable_progress_bar: bool = Field(default=True)`

Whether to enable the progress bar.
`gradient_clip_algorithm: Literal['value', 'norm'] = 'norm'`

The algorithm to use for gradient clipping (see the PyTorch Lightning `Trainer` documentation).
`gradient_clip_val: Optional[Union[int, float]] = None`

The value to which to clip the gradient.
`logger: Optional[Literal['wandb', 'tensorboard']] = None`

Logger to use during training. If `None`, no logger will be used. Available loggers are defined in `SupportedLogger`.
`max_steps: int = Field(default=-1, ge=-1)`

Maximum number of steps to train for. -1 means no limit.
`num_epochs: int = Field(default=20, ge=1)`

Number of epochs, greater than 0.
`precision: Literal['64', '32', '16-mixed', 'bf16-mixed'] = Field(default='32')`

Numerical precision used during training.
`__str__()`

Pretty string representing the configuration.

Returns:

| Type | Description |
| --- | --- |
| `str` | Pretty string. |
`has_logger()`

Check if the logger is defined.

Returns:

| Type | Description |
| --- | --- |
| `bool` | Whether the logger is defined or not. |
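The documented behavior of `has_logger()` can be sketched as follows; this is a plausible implementation matching the description above, not the actual careamics source:

```python
from typing import Literal, Optional

from pydantic import BaseModel


class LoggerSketch(BaseModel):
    """Stand-in carrying only the logger field from the documentation."""

    logger: Optional[Literal["wandb", "tensorboard"]] = None

    def has_logger(self) -> bool:
        # A logger is "defined" when the optional field is set.
        return self.logger is not None


print(LoggerSketch().has_logger())                # False
print(LoggerSketch(logger="wandb").has_logger())  # True
```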
`validate_max_steps(max_steps)` (classmethod)

Validate the max_steps parameter.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `max_steps` | `int` | Maximum number of steps to train for. -1 means no limit. | required |

Returns:

| Type | Description |
| --- | --- |
| `int` | Validated max_steps. |
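The general shape of such a classmethod validator in Pydantic v2 is sketched below. The body is illustrative only, since the documentation does not show what check `validate_max_steps` actually performs; here it is assumed, purely for illustration, that 0 should be rejected:

```python
from pydantic import BaseModel, Field, field_validator


class MaxStepsSketch(BaseModel):
    """Stand-in showing the field_validator pattern for max_steps."""

    max_steps: int = Field(default=-1, ge=-1)

    @field_validator("max_steps")
    @classmethod
    def validate_max_steps(cls, max_steps: int) -> int:
        # Hypothetical check: forbid 0, which would mean training for no
        # steps at all; -1 keeps the documented "no limit" meaning.
        if max_steps == 0:
            raise ValueError("max_steps must be -1 (no limit) or positive")
        return max_steps


print(MaxStepsSketch(max_steps=1000).max_steps)  # 1000
```

A `ValueError` raised inside a `field_validator` surfaces as a `ValidationError` when the model is constructed.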