
# module_utils

Training utilities for Lightning modules.

## configure_optimizers(model, optimizer_name, optimizer_parameters, lr_scheduler_name, lr_scheduler_parameters, monitor='val_loss')

Configure optimizer and learning rate scheduler.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The model whose parameters will be optimized. | *required* |
| `optimizer_name` | `str` | The name of the optimizer to use. | *required* |
| `optimizer_parameters` | `dict[str, Any]` | Parameters to pass to the optimizer constructor. | *required* |
| `lr_scheduler_name` | `str` | The name of the learning rate scheduler to use. | *required* |
| `lr_scheduler_parameters` | `dict[str, Any]` | Parameters to pass to the learning rate scheduler constructor. | *required* |
| `monitor` | `str` | The metric to monitor for the learning rate scheduler. | `'val_loss'` |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, Any]` | A dictionary containing the optimizer and learning rate scheduler configuration. |

Source code in `src/careamics/lightning/dataset_ng/lightning_modules/module_utils.py`:

```python
def configure_optimizers(
    model: nn.Module,
    optimizer_name: str,
    optimizer_parameters: dict[str, Any],
    lr_scheduler_name: str,
    lr_scheduler_parameters: dict[str, Any],
    monitor: str = "val_loss",
) -> dict[str, Any]:
    """Configure optimizer and learning rate scheduler.

    Parameters
    ----------
    model : nn.Module
        The model whose parameters will be optimized.
    optimizer_name : str
        The name of the optimizer to use.
    optimizer_parameters : dict[str, Any]
        Parameters to pass to the optimizer constructor.
    lr_scheduler_name : str
        The name of the learning rate scheduler to use.
    lr_scheduler_parameters : dict[str, Any]
        Parameters to pass to the learning rate scheduler constructor.
    monitor : str, optional
        The metric to monitor for the learning rate scheduler, by default "val_loss".

    Returns
    -------
    dict[str, Any]
        A dictionary containing the optimizer and learning rate scheduler configuration.
    """
    optimizer_func = get_optimizer(optimizer_name)
    optimizer = optimizer_func(  # type: ignore[operator]
        model.parameters(), **optimizer_parameters
    )

    scheduler_func = get_scheduler(lr_scheduler_name)
    scheduler = scheduler_func(optimizer, **lr_scheduler_parameters)  # type: ignore[operator]

    return {
        "optimizer": optimizer,
        "lr_scheduler": scheduler,
        "monitor": monitor,
    }
```
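
A minimal usage sketch: calling this helper from a `LightningModule`'s `configure_optimizers` hook. `DemoModule`, the `"Adam"`/`"ReduceLROnPlateau"` names, and the parameter values are illustrative assumptions, not part of this module.

```python
import lightning as L
import torch.nn as nn

from careamics.lightning.dataset_ng.lightning_modules.module_utils import (
    configure_optimizers as build_optim_config,  # aliased to avoid shadowing the hook
)


class DemoModule(L.LightningModule):
    """Hypothetical module; only the optimizer wiring is shown."""

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(8, 1)

    def configure_optimizers(self) -> dict:
        # Lightning consumes the returned {"optimizer", "lr_scheduler", "monitor"} dict.
        return build_optim_config(
            model=self.model,
            optimizer_name="Adam",
            optimizer_parameters={"lr": 1e-3},
            lr_scheduler_name="ReduceLROnPlateau",
            lr_scheduler_parameters={"factor": 0.5, "patience": 5},
            monitor="val_loss",
        )
```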

## load_best_checkpoint(module)

Load the best checkpoint from the trainer's checkpoint callback.

Parameters:

Name Type Description Default
module LightningModule

The Lightning module to load the checkpoint into.

required

Returns:

| Type | Description |
| --- | --- |
| `bool` | True if checkpoint was loaded, False otherwise. |

Source code in `src/careamics/lightning/dataset_ng/lightning_modules/module_utils.py`:

```python
def load_best_checkpoint(module: L.LightningModule) -> bool:
    """Load the best checkpoint from the trainer's checkpoint callback.

    Parameters
    ----------
    module : L.LightningModule
        The Lightning module to load the checkpoint into.

    Returns
    -------
    bool
        True if checkpoint was loaded, False otherwise.
    """
    if (
        not hasattr(module.trainer, "checkpoint_callback")
        or module.trainer.checkpoint_callback is None
    ):
        logger.warning("No checkpoint callback found, cannot load best checkpoint.")
        return False

    best_model_path = module.trainer.checkpoint_callback.best_model_path  # type: ignore[attr-defined]
    if best_model_path and best_model_path != "":
        logger.info(f"Loading best checkpoint from: {best_model_path}")
        model_state = torch.load(best_model_path, weights_only=True)["state_dict"]
        module.load_state_dict(model_state)
        return True
    else:
        logger.warning("No best checkpoint found.")
        return False
```
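
A sketch of one place this could be called: restoring the best weights at the end of training, assuming a `ModelCheckpoint` callback is attached to the `Trainer`. The `on_train_end` placement is an assumption, not prescribed by this module.

```python
import lightning as L

from careamics.lightning.dataset_ng.lightning_modules.module_utils import (
    load_best_checkpoint,
)


class DemoModule(L.LightningModule):
    # training_step / configure_optimizers elided for brevity.

    def on_train_end(self) -> None:
        # Swap in the best weights tracked by the ModelCheckpoint callback;
        # returns False (and keeps the current weights) if none was saved.
        if not load_best_checkpoint(self):
            print("No best checkpoint found; keeping final-epoch weights.")
```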

## log_training_stats(module, loss, batch_size)

Log training loss and learning rate.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module` | `LightningModule` | The Lightning module to log stats for. | *required* |
| `loss` | `Any` | The loss value for the current training step. | *required* |
| `batch_size` | `int` | The size of the batch used in the current training step. | *required* |

Source code in `src/careamics/lightning/dataset_ng/lightning_modules/module_utils.py`:

```python
def log_training_stats(module: L.LightningModule, loss: Any, batch_size: int) -> None:
    """Log training loss and learning rate.

    Parameters
    ----------
    module : L.LightningModule
        The Lightning module to log stats for.
    loss : Any
        The loss value for the current training step.
    batch_size : int
        The size of the batch used in the current training step.
    """
    module.log(
        "train_loss",
        loss,
        on_step=True,
        on_epoch=True,
        prog_bar=True,
        logger=True,
        batch_size=batch_size,
    )

    optimizer = module.optimizers()
    if isinstance(optimizer, list):
        current_lr = optimizer[0].param_groups[0]["lr"]
    else:
        current_lr = optimizer.param_groups[0]["lr"]
    module.log(
        "learning_rate",
        current_lr,
        on_step=False,
        on_epoch=True,
        logger=True,
        batch_size=batch_size,
    )
```
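
A minimal sketch of calling this from a `training_step`; `DemoModule` and the MSE loss are illustrative assumptions. Note that the learning-rate lookup via `module.optimizers()` requires an attached `Trainer`, so the helper is only meaningful inside a training run.

```python
import lightning as L
import torch.nn as nn

from careamics.lightning.dataset_ng.lightning_modules.module_utils import (
    log_training_stats,
)


class DemoModule(L.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(8, 1)
        self.loss_fn = nn.MSELoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_fn(self.model(x), y)
        # Logs "train_loss" per step and per epoch, and "learning_rate" per epoch.
        log_training_stats(self, loss, batch_size=x.shape[0])
        return loss
```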

## log_validation_stats(module, loss, batch_size, metrics)

Log validation loss and metrics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module` | `LightningModule` | The Lightning module to log stats for. | *required* |
| `loss` | `Any` | The loss value for the current validation step. | *required* |
| `batch_size` | `int` | The size of the batch used in the current validation step. | *required* |
| `metrics` | `MetricCollection` | The metrics collection to log. | *required* |

Source code in `src/careamics/lightning/dataset_ng/lightning_modules/module_utils.py`:

```python
def log_validation_stats(
    module: L.LightningModule,
    loss: Any,
    batch_size: int,
    metrics: MetricCollection,
) -> None:
    """Log validation loss and metrics.

    Parameters
    ----------
    module : L.LightningModule
        The Lightning module to log stats for.
    loss : Any
        The loss value for the current validation step.
    batch_size : int
        The size of the batch used in the current validation step.
    metrics : MetricCollection
        The metrics collection to log.
    """
    module.log(
        "val_loss",
        loss,
        on_step=False,
        on_epoch=True,
        prog_bar=True,
        logger=True,
        batch_size=batch_size,
    )
    module.log_dict(metrics, on_step=False, on_epoch=True, batch_size=batch_size)
```
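
A sketch of the validation counterpart, assuming a `torchmetrics.MetricCollection` held on the module; the `val_` prefix and the `MeanSquaredError` metric are illustrative assumptions.

```python
import lightning as L
import torch.nn as nn
from torchmetrics import MeanSquaredError, MetricCollection

from careamics.lightning.dataset_ng.lightning_modules.module_utils import (
    log_validation_stats,
)


class DemoModule(L.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(8, 1)
        self.loss_fn = nn.MSELoss()
        # The prefix determines the logged keys, e.g. "val_MeanSquaredError".
        self.val_metrics = MetricCollection([MeanSquaredError()], prefix="val_")

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self.model(x)
        self.val_metrics.update(pred, y)
        # Logs "val_loss" and every metric in the collection at epoch end.
        log_validation_stats(
            self,
            self.loss_fn(pred, y),
            batch_size=x.shape[0],
            metrics=self.val_metrics,
        )
```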