training_model

Training configuration.

TrainingConfig #

Bases: BaseModel

Parameters related to the training.

Mandatory parameters are:

- num_epochs: number of epochs, greater than 0.
- batch_size: batch size, greater than 0.
- augmentation: whether to use data augmentation or not (True or False).

Attributes:

num_epochs (int): Number of epochs, greater than 0.

Source code in src/careamics/config/training_model.py
class TrainingConfig(BaseModel):
    """
    Parameters related to the training.

    Mandatory parameters are:
        - num_epochs: number of epochs, greater than 0.
        - batch_size: batch size, greater than 0.
        - augmentation: whether to use data augmentation or not (True or False).

    Attributes
    ----------
    num_epochs : int
        Number of epochs, greater than 0.
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )
    lightning_trainer_config: dict | None = None
    """Configuration for the PyTorch Lightning Trainer, following PyTorch Lightning
    Trainer class"""

    logger: Literal["wandb", "tensorboard"] | None = None
    """Logger to use during training. If None, no logger will be used. Available
    loggers are defined in SupportedLogger."""

    # Only basic callbacks
    checkpoint_callback: CheckpointModel = CheckpointModel()
    """Checkpoint callback configuration, following PyTorch Lightning Checkpoint
    callback."""

    early_stopping_callback: EarlyStoppingModel | None = Field(
        default=None, validate_default=True
    )
    """Early stopping callback configuration, following PyTorch Lightning Checkpoint
    callback."""

    def __str__(self) -> str:
        """Pretty string reprensenting the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def has_logger(self) -> bool:
        """Check if the logger is defined.

        Returns
        -------
        bool
            Whether the logger is defined or not.
        """
        return self.logger is not None
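
Since all fields have defaults, the configuration can be instantiated like any Pydantic model. A minimal sketch, assuming the import path follows the source location shown above:

from careamics.config.training_model import TrainingConfig

# All fields are optional, so a default configuration is valid:
# no logger, default checkpointing, no early stopping.
config = TrainingConfig()

# Fields can also be set at construction time.
config = TrainingConfig(logger="tensorboard")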

checkpoint_callback = CheckpointModel() class-attribute instance-attribute #

Checkpoint callback configuration, following PyTorch Lightning Checkpoint callback.

early_stopping_callback = Field(default=None, validate_default=True) class-attribute instance-attribute #

Early stopping callback configuration, following PyTorch Lightning EarlyStopping callback.
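
As a sketch, both callback models can be customized at construction time. The import path and the field names below (monitor, save_top_k, patience) are assumptions based on the corresponding PyTorch Lightning callbacks and may differ in the actual CheckpointModel and EarlyStoppingModel definitions:

# Hypothetical import path for the callback models.
from careamics.config.callback_model import CheckpointModel, EarlyStoppingModel

from careamics.config.training_model import TrainingConfig

# Field names assumed to mirror Lightning's ModelCheckpoint/EarlyStopping.
config = TrainingConfig(
    checkpoint_callback=CheckpointModel(monitor="val_loss", save_top_k=2),
    early_stopping_callback=EarlyStoppingModel(monitor="val_loss", patience=10),
)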

lightning_trainer_config = None class-attribute instance-attribute #

Configuration for the PyTorch Lightning Trainer, following the PyTorch Lightning Trainer class.
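
Because the field is a plain dict, any keyword accepted by the Lightning Trainer can presumably be passed through. A minimal sketch using standard lightning.pytorch.Trainer arguments:

from careamics.config.training_model import TrainingConfig

# Keys follow the Lightning Trainer signature (max_epochs, accelerator, precision).
config = TrainingConfig(
    lightning_trainer_config={
        "max_epochs": 100,
        "accelerator": "gpu",
        "precision": "16-mixed",
    }
)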

logger = None class-attribute instance-attribute #

Logger to use during training. If None, no logger will be used. Available loggers are defined in SupportedLogger.
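
Because the field is a Literal and the model sets validate_assignment=True, both construction and later assignment are checked. A short sketch:

import pydantic

from careamics.config.training_model import TrainingConfig

config = TrainingConfig(logger="wandb")

try:
    config.logger = "mlflow"  # not one of "wandb" or "tensorboard"
except pydantic.ValidationError as err:
    print(err)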

__str__() #

Pretty string representing the configuration.

Returns:

str: Pretty string.

Source code in src/careamics/config/training_model.py
def __str__(self) -> str:
    """Pretty string reprensenting the configuration.

    Returns
    -------
    str
        Pretty string.
    """
    return pformat(self.model_dump())
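
In practice this means printing the configuration yields a readable, pretty-printed dictionary. A minimal sketch:

from careamics.config.training_model import TrainingConfig

config = TrainingConfig(logger="wandb")

# __str__ pretty-prints the result of model_dump().
print(config)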

has_logger() #

Check if the logger is defined.

Returns:

bool: Whether the logger is defined or not.

Source code in src/careamics/config/training_model.py
def has_logger(self) -> bool:
    """Check if the logger is defined.

    Returns
    -------
    bool
        Whether the logger is defined or not.
    """
    return self.logger is not None
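
For instance, downstream code can branch on whether a logger was configured. A minimal sketch:

from careamics.config.training_model import TrainingConfig

config = TrainingConfig()
assert not config.has_logger()

config.logger = "tensorboard"
assert config.has_logger()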