Training Configuration
Training configuration.
SelfSupervisedCheckpointing
dataclass
Presets for checkpointing Noise2Noise and Noise2Void.
Because self-supervised algorithms evaluate the loss against noisy pixels, the loss value is not a good measure of performance after a few epochs. It therefore cannot be used to select the best-performing models.
This preset saves a checkpoint every 10 epochs, as well as the last one.
auto_insert_metric_name = False
class-attribute
instance-attribute
Do not insert the monitored value in the checkpoint name.
every_n_epochs = 10
class-attribute
instance-attribute
Save a checkpoint every 10 epochs.
save_last = True
class-attribute
instance-attribute
Save the last checkpoint.
save_top_k = -1
class-attribute
instance-attribute
Save all checkpoints (none are pruned). A checkpoint is written every every_n_epochs epochs.
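The combined effect of these four settings can be sketched with a simplified model. It assumes 1-based epoch counting and ignores the fact that PyTorch Lightning writes the last checkpoint to a separate file. saved_epochs is a hypothetical helper, not part of CAREamics or Lightning.

```python
def saved_epochs(max_epochs: int, every_n_epochs: int = 10, save_last: bool = True) -> list[int]:
    """Return the 1-based epochs that end up with a checkpoint (simplified model)."""
    # save_top_k = -1: every periodic checkpoint is kept, none are pruned.
    epochs = [e for e in range(1, max_epochs + 1) if e % every_n_epochs == 0]
    # save_last = True: the final epoch is always kept as well.
    if save_last and max_epochs > 0 and max_epochs not in epochs:
        epochs.append(max_epochs)
    return epochs


print(saved_epochs(35))  # [10, 20, 30, 35]
```

Because no checkpoint is ranked by the loss, the user picks a model from this periodic series after training.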
SupervisedCheckpointing
dataclass
Presets for checkpointing CARE.
This preset saves the 3 best-performing checkpoints based on val_loss, as well as the last one.
auto_insert_metric_name = False
class-attribute
instance-attribute
Do not insert the monitored value in the checkpoint name.
mode = 'min'
class-attribute
instance-attribute
Top checkpoints are selected by minimum val_loss.
monitor = 'val_loss'
class-attribute
instance-attribute
Monitor val_loss.
save_last = True
class-attribute
instance-attribute
Save the last checkpoint.
save_top_k = 3
class-attribute
instance-attribute
Save the 3 best-performing checkpoints.
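A minimal sketch of how save_top_k = 3 with mode = 'min' ranks checkpoints by the monitored val_loss. top_k_checkpoints and the loss values are made up for illustration; the last checkpoint (save_last = True) is kept in addition to these.

```python
def top_k_checkpoints(val_losses: dict[int, float], k: int = 3, mode: str = "min") -> list[int]:
    """Return the epochs of the k best checkpoints, ranked by the monitored value."""
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    # mode="min": a lower val_loss is better, so sort epochs by ascending loss.
    ranked = sorted(val_losses, key=val_losses.__getitem__, reverse=(mode == "max"))
    return ranked[:k]


# Hypothetical epoch -> val_loss values.
losses = {1: 0.90, 2: 0.55, 3: 0.61, 4: 0.48, 5: 0.52}
print(top_k_checkpoints(losses))  # [4, 5, 2]
```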
TrainingConfig
Bases: BaseModel
Parameters related to the training.
By default, checkpoint_params and early_stopping_params have presets based on
whether the algorithm is supervised (CARE) or not (Noise2Void and by extension
Noise2Noise). In the case of CARE, the top 3 checkpoints are saved based on
val_loss. For the self-supervised algorithms, checkpoints are saved every 10
epochs. In both cases, the last checkpoint is saved. Early stopping is disabled
for self-supervised algorithms.
Attributes:
- trainer_params (dict) – Parameters passed to the PyTorch Lightning Trainer class.
- logger (Literal['wandb', 'tensorboard'] | None) – Additional logger to use during training. If None, no logger will be used. Note that the CAREamist uses the CSV logger regardless of the value of this field.
- checkpoint_params (dict[str, Any]) – Checkpoint callback parameters, following the PyTorch Lightning ModelCheckpoint callback.
- early_stopping_params (dict[str, Any] | None) – Early stopping callback parameters, following the PyTorch Lightning EarlyStopping callback.
checkpoint_params = Field(default_factory=dict)
class-attribute
instance-attribute
Checkpoint callback parameters, following the PyTorch Lightning ModelCheckpoint callback.
early_stopping_params = Field(default_factory=dict)
class-attribute
instance-attribute
Early stopping callback parameters, following the PyTorch Lightning EarlyStopping callback.
logger = None
class-attribute
instance-attribute
Logger to use during training. If None, no logger will be used. Available loggers are defined in SupportedLogger.
trainer_params = Field(default={})
class-attribute
instance-attribute
Parameters passed to the PyTorch Lightning Trainer class.
__str__()
valid_ckpt_parameters(user_params)
classmethod
Validate parameters based on the checkpoint callback signature.
Parameters:
- user_params (dict) – User parameters.
Returns:
- dict – Validated checkpoint parameters.
Raises:
- ValueError – If there are unknown parameters for the checkpoint callback.
valid_early_stopping_parameters(user_params)
classmethod
Validate parameters based on the early stopping callback signature.
Parameters:
- user_params (dict) – User parameters.
Returns:
- dict – Validated early stopping parameters.
Raises:
- ValueError – If there are unknown parameters for the early stopping callback.
valid_trainer_parameters(user_params)
classmethod
Validate parameters based on the PyTorch Lightning Trainer signature.
Parameters:
- user_params (dict) – User parameters.
Returns:
- dict – Validated trainer parameters.
Raises:
- ValueError – If there are unknown parameters for the PyTorch Lightning Trainer.
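The three validators above share one idea: compare the keys supplied by the user with the parameter names accepted by the target callable. The following is a minimal sketch using Python's inspect module. validate_against_signature and FakeEarlyStopping are hypothetical stand-ins, chosen so that the example runs without PyTorch Lightning installed; CAREamics may implement the check differently.

```python
import inspect
from typing import Any, Callable


def validate_against_signature(target: Callable[..., Any], user_params: dict[str, Any]) -> dict[str, Any]:
    """Return user_params unchanged, or raise ValueError on unknown names."""
    params = inspect.signature(target).parameters
    # A target accepting **kwargs takes any name, so nothing can be rejected.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return user_params
    unknown = sorted(set(user_params) - set(params))
    if unknown:
        raise ValueError(f"Unknown parameters for {target.__name__}: {unknown}")
    return user_params


class FakeEarlyStopping:
    """Hypothetical reduced stand-in for Lightning's EarlyStopping (subset of its parameters)."""

    def __init__(self, monitor: str, patience: int = 3, mode: str = "min") -> None:
        self.monitor = monitor
        self.patience = patience
        self.mode = mode


print(validate_against_signature(FakeEarlyStopping, {"monitor": "val_loss", "patience": 10}))
try:
    validate_against_signature(FakeEarlyStopping, {"patiance": 10})  # deliberate typo
except ValueError as err:
    print(err)  # Unknown parameters for FakeEarlyStopping: ['patiance']
```

Checking against the signature catches typos such as "patiance" at configuration time instead of when the callback is instantiated.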
default_training_dict(algorithm, trainer_params=None, logger='none', checkpoint_params=None, early_stopping_params=None, monitor_metric='val_loss')
Default training configuration constructor.
This function sets default training parameters based on the algorithm configuration. If the user provides any of the parameters, they will take precedence over the defaults.
Parameters:
- algorithm ({'care', 'n2n', 'n2v'}) – Algorithm type, used to select the default checkpointing preset.
- trainer_params (dict, default: None) – Parameters for the Lightning Trainer class.
- logger ({'wandb', 'tensorboard', 'none'}, default: 'none') – Logger to use.
- checkpoint_params (dict, default: None) – Parameters for the checkpoint callback. If None, default parameters are applied based on the algorithm.
- early_stopping_params (dict, default: None) – Parameters for the early stopping callback. If None, default parameters are applied based on the algorithm.
- monitor_metric (str, default: 'val_loss') – Metric to monitor for early stopping.
Returns:
- dict – Training configuration dictionary with the specified parameters.
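A hedged sketch of the precedence rule described for default_training_dict: a dictionary supplied by the user replaces the algorithm default wholesale, and a missing one (None) falls back to the presets documented on this page. sketch_training_dict is hypothetical. It omits logger handling and assumes that CARE's default early stopping only sets the monitored metric.

```python
from typing import Any, Optional


def sketch_training_dict(
    algorithm: str,
    trainer_params: Optional[dict[str, Any]] = None,
    checkpoint_params: Optional[dict[str, Any]] = None,
    early_stopping_params: Optional[dict[str, Any]] = None,
    monitor_metric: str = "val_loss",
) -> dict[str, Any]:
    """Hypothetical re-implementation: user values take precedence over defaults."""
    if algorithm not in ("care", "n2n", "n2v"):
        raise ValueError(f"Unknown algorithm: {algorithm!r}")
    supervised = algorithm == "care"

    if checkpoint_params is None:
        if supervised:  # SupervisedCheckpointing preset
            checkpoint_params = {"monitor": "val_loss", "mode": "min", "save_top_k": 3}
        else:  # SelfSupervisedCheckpointing preset
            checkpoint_params = {"every_n_epochs": 10, "save_top_k": -1}
        checkpoint_params.update(save_last=True, auto_insert_metric_name=False)

    if early_stopping_params is None and supervised:
        # Assumption: the CARE default only sets the monitored metric.
        early_stopping_params = {"monitor": monitor_metric}
    # Self-supervised algorithms keep early_stopping_params as None (disabled).

    return {
        "trainer_params": trainer_params if trainer_params is not None else {},
        "checkpoint_params": checkpoint_params,
        "early_stopping_params": early_stopping_params,
    }
```

For example, sketch_training_dict("care", checkpoint_params={"save_top_k": 1}) keeps the user's checkpoint dictionary as-is, while early stopping still receives the CARE default.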
default_training_factory(validated_dict)
Default training configuration constructor.
Parameters:
- validated_dict (dict) – Validated configuration dictionary, used to set default training parameters based on the algorithm configuration. Pydantic is expected to pass this when calling the default constructor.
Returns:
- TrainingConfig – Training configuration with the specified parameters.