Training Configuration

Training configuration.

SelfSupervisedCheckpointing dataclass

Presets for checkpointing Noise2Noise and Noise2Void.

Because self-supervised algorithms evaluate the loss against noisy pixels, the loss value is not a good measure of performance after the first few epochs. It therefore cannot be used to select the best-performing model.

This preset saves a checkpoint every 10 epochs, as well as the last one.

auto_insert_metric_name = False class-attribute instance-attribute

Do not insert the monitored value in the checkpoint name.

every_n_epochs = 10 class-attribute instance-attribute

Save a checkpoint every 10 epochs.

save_last = True class-attribute instance-attribute

Save the last checkpoint.

save_top_k = -1 class-attribute instance-attribute

Save all checkpoints. A checkpoint is saved every every_n_epochs epochs.
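
The fields above correspond to keyword arguments of a checkpoint callback. As a minimal sketch, the stand-in dataclass below (`SelfSupervisedPreset` is a hypothetical name that mirrors the documented fields; it is not the library class) shows how such a preset can be converted into a keyword dictionary:

```python
from dataclasses import asdict, dataclass


@dataclass
class SelfSupervisedPreset:
    """Stand-in mirroring the documented fields; not the library class."""

    every_n_epochs: int = 10  # save a checkpoint every 10 epochs
    save_top_k: int = -1  # keep every saved checkpoint
    save_last: bool = True  # also keep the last checkpoint
    auto_insert_metric_name: bool = False  # no metric in the file name


# Convert the preset into plain keyword arguments.
checkpoint_kwargs = asdict(SelfSupervisedPreset())
print(checkpoint_kwargs)
```

Assuming PyTorch Lightning is installed, a dictionary of this shape could then be unpacked into its `ModelCheckpoint` callback.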

SupervisedCheckpointing dataclass

Presets for checkpointing CARE.

This preset saves the top 3 best performing checkpoints based on val_loss, as well as the last one.

auto_insert_metric_name = False class-attribute instance-attribute

Do not insert the monitored value in the checkpoint name.

mode = 'min' class-attribute instance-attribute

Top checkpoints are selected by minimum val_loss.

monitor = 'val_loss' class-attribute instance-attribute

Monitor val_loss.

save_last = True class-attribute instance-attribute

Save the last checkpoint.

save_top_k = 3 class-attribute instance-attribute

Save the top 3 best performing checkpoints.
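
To illustrate what selecting the top 3 checkpoints by minimum val_loss means, here is a small sketch. The helper `top_k_checkpoints` and the loss values are made up for illustration; the actual selection is performed by the checkpoint callback, not by this function.

```python
def top_k_checkpoints(val_losses, k=3, mode="min"):
    """Return the epochs of the k best checkpoints, best first."""
    reverse = mode == "max"  # "min" keeps the lowest values
    ranked = sorted(val_losses.items(), key=lambda item: item[1], reverse=reverse)
    return [epoch for epoch, _ in ranked[:k]]


# Illustrative validation losses keyed by epoch (not real results).
losses = {0: 0.90, 1: 0.42, 2: 0.55, 3: 0.38, 4: 0.61}
kept = top_k_checkpoints(losses, k=3, mode="min")
print(kept)
```

With these illustrative values, the epochs with the three lowest losses are retained; the last checkpoint would be kept in addition because save_last is True.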

TrainingConfig

Bases: BaseModel

Parameters related to the training.

By default, checkpoint_params and early_stopping_params have presets based on whether the algorithm is supervised (CARE) or not (Noise2Void and by extension Noise2Noise). In the case of CARE, the top 3 checkpoints are saved based on val_loss. For the self-supervised algorithms, checkpoints are saved every 10 epochs. In both cases, the last checkpoint is saved. Early stopping is disabled for self-supervised algorithms.

Attributes:

  • trainer_params (dict) –

    Parameters passed to the PyTorch Lightning Trainer class.

  • logger (Literal['wandb', 'tensorboard'] | None) –

    Additional Logger to use during training. If None, no logger will be used. Note that the CAREamist uses the csv logger regardless of the value of this field.

  • checkpoint_params (dict[str, Any]) –

    Checkpoint callback parameters, following PyTorch Lightning Checkpoint callback.

  • early_stopping_params (dict[str, Any] | None) –

    Early stopping callback parameters, following PyTorch Lightning EarlyStopping callback.

checkpoint_params = Field(default_factory=dict) class-attribute instance-attribute

Checkpoint callback parameters, following PyTorch Lightning Checkpoint callback.

early_stopping_params = Field(default_factory=dict) class-attribute instance-attribute

Early stopping callback parameters, following PyTorch Lightning EarlyStopping callback.

logger = None class-attribute instance-attribute

Logger to use during training. If None, no logger will be used. Available loggers are defined in SupportedLogger.

trainer_params = Field(default={}) class-attribute instance-attribute

Parameters passed to the PyTorch Lightning Trainer class.

__str__()

Pretty string representing the configuration.

Returns:

  • str

    Pretty string.

valid_ckpt_parameters(user_params) classmethod

Validate parameters based on the checkpoint callback signature.

Parameters:

  • user_params (dict) –

    User parameters.

Returns:

  • dict

    Validated checkpoint parameters.

Raises:

  • ValueError

    If there are unknown parameters for the checkpoint callback.
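
Validation against a callback signature can be sketched with the standard `inspect` module. This is an assumption about the general technique, not the library's implementation: `FakeCheckpoint` and `validate_against_signature` are hypothetical names, and the stand-in class only accepts a few of the parameters a real checkpoint callback would.

```python
import inspect


class FakeCheckpoint:
    """Hypothetical stand-in for a checkpoint callback."""

    def __init__(self, monitor=None, mode="min", save_top_k=1, save_last=None):
        self.monitor = monitor
        self.mode = mode
        self.save_top_k = save_top_k
        self.save_last = save_last


def validate_against_signature(user_params, target):
    """Return user_params unchanged, or raise if a key is not accepted."""
    allowed = set(inspect.signature(target).parameters)
    unknown = sorted(set(user_params) - allowed)
    if unknown:
        raise ValueError(
            f"Unknown parameters for {target.__name__}: {unknown}. "
            f"Allowed parameters: {sorted(allowed)}."
        )
    return user_params


validated = validate_against_signature({"save_top_k": 3}, FakeCheckpoint)
print(validated)
```

A misspelled key such as `save_topk` would raise a ValueError in this sketch, which is the behaviour described under "Raises" above.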

valid_early_stopping_parameters(user_params) classmethod

Validate parameters based on the early stopping callback signature.

Parameters:

  • user_params (dict) –

    User parameters.

Returns:

  • dict

    Validated early stopping parameters.

Raises:

  • ValueError

    If there are unknown parameters for the early stopping callback.

valid_trainer_parameters(user_params) classmethod

Validate parameters based on the PyTorch Lightning Trainer signature.

Parameters:

  • user_params (dict) –

    User parameters.

Returns:

  • dict

    Validated trainer parameters.

Raises:

  • ValueError

    If there are unknown parameters for the PyTorch Lightning Trainer.

default_training_dict(algorithm, trainer_params=None, logger='none', checkpoint_params=None, early_stopping_params=None, monitor_metric='val_loss')

Default training configuration constructor.

This function sets default training parameters based on the algorithm configuration. If the user provides any of the parameters, they will take precedence over the defaults.

Parameters:

  • algorithm (('care', 'n2n', 'n2v')) –

    Algorithm type, used to select the default checkpointing preset.

  • trainer_params (dict, default: None ) –

    Parameters for Lightning Trainer class, by default None.

  • logger (('wandb', 'tensorboard', 'none'), default: "none" ) –

    Logger to use, by default "none".

  • checkpoint_params (dict, default: None ) –

    Parameters for the checkpoint callback, by default None. If None, then default parameters are applied based on the algorithm.

  • early_stopping_params (dict, default: None ) –

    Parameters for the early stopping callback, by default None. If None, then default parameters are applied based on the algorithm.

  • monitor_metric (str, default: 'val_loss' ) –

    Metric to monitor for early stopping, by default "val_loss".

Returns:

  • dict

    Training configuration dictionary with the specified parameters.
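
The way defaults depend on the algorithm can be sketched as follows. This is a simplified illustration, not the library code: `sketch_training_dict` is a hypothetical name, the trainer_params and logger arguments are left out, the early stopping default for CARE (`{"monitor": monitor_metric}`) is an assumption, and a user-supplied dictionary is assumed to replace the preset wholesale (the text above only states that user values take precedence).

```python
def sketch_training_dict(
    algorithm,
    checkpoint_params=None,
    early_stopping_params=None,
    monitor_metric="val_loss",
):
    """Pick defaults from the algorithm unless the user supplied values."""
    supervised = algorithm == "care"

    if checkpoint_params is None:
        if supervised:
            # Mirrors the documented supervised preset.
            checkpoint_params = {
                "monitor": "val_loss",
                "mode": "min",
                "save_top_k": 3,
                "save_last": True,
                "auto_insert_metric_name": False,
            }
        else:
            # Mirrors the documented self-supervised preset.
            checkpoint_params = {
                "every_n_epochs": 10,
                "save_top_k": -1,
                "save_last": True,
                "auto_insert_metric_name": False,
            }

    if early_stopping_params is None and supervised:
        # Assumed default; early stopping stays disabled for n2v and n2n.
        early_stopping_params = {"monitor": monitor_metric}

    return {
        "checkpoint_params": checkpoint_params,
        "early_stopping_params": early_stopping_params,
    }


n2v_defaults = sketch_training_dict("n2v")
care_custom = sketch_training_dict("care", checkpoint_params={"save_top_k": 5})
print(n2v_defaults)
print(care_custom)
```

In this sketch, the Noise2Void call receives the every-10-epochs preset with no early stopping, while the CARE call keeps the user's checkpoint settings and only fills in the early stopping default.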

default_training_factory(validated_dict)

Default training configuration constructor.

Parameters:

  • validated_dict (dict) –

    Validated configuration dictionary, used to set default training parameters based on the algorithm configuration. This is expected to be passed by Pydantic when calling the default constructor.

Returns:

  • TrainingConfig

    Training configuration with the specified parameters.