
Module Utils


Training utilities for Lightning modules.

configure_optimizers(model, optimizer_name, optimizer_parameters, lr_scheduler_name, lr_scheduler_parameters, monitor='val_loss')

Configure optimizer and learning rate scheduler.

Parameters:

  • model (Module) –

    The model whose parameters will be optimized.

  • optimizer_name (str) –

    The name of the optimizer to use.

  • optimizer_parameters (dict[str, Any]) –

    Parameters to pass to the optimizer constructor.

  • lr_scheduler_name (str) –

    The name of the learning rate scheduler to use.

  • lr_scheduler_parameters (dict[str, Any]) –

    Parameters to pass to the learning rate scheduler constructor.

  • monitor (str, default: 'val_loss' ) –

    Name of the logged metric that the learning rate scheduler monitors.

Returns:

  • dict[str, Any]

    A dictionary containing the optimizer and learning rate scheduler configuration.
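The sketch below shows one way such a configuration could be assembled. It is not this module's implementation. It assumes that the names are resolved by looking them up in torch.optim and torch.optim.lr_scheduler, and that the returned dictionary follows PyTorch Lightning's configure_optimizers format, where the "monitor" entry names the logged metric that a metric-driven scheduler such as ReduceLROnPlateau steps on. The function name configure_optimizers_sketch is hypothetical.

```python
from typing import Any

import torch
from torch import nn


def configure_optimizers_sketch(
    model: nn.Module,
    optimizer_name: str,
    optimizer_parameters: dict[str, Any],
    lr_scheduler_name: str,
    lr_scheduler_parameters: dict[str, Any],
    monitor: str = "val_loss",
) -> dict[str, Any]:
    # Assumption: names are class names in PyTorch's own namespaces.
    optimizer_cls = getattr(torch.optim, optimizer_name)
    scheduler_cls = getattr(torch.optim.lr_scheduler, lr_scheduler_name)

    optimizer = optimizer_cls(model.parameters(), **optimizer_parameters)
    scheduler = scheduler_cls(optimizer, **lr_scheduler_parameters)

    # Lightning-style return value: the scheduler entry carries the name of
    # the metric that metric-driven schedulers (e.g. ReduceLROnPlateau) use.
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "monitor": monitor},
    }


config = configure_optimizers_sketch(
    nn.Linear(4, 2),
    optimizer_name="AdamW",
    optimizer_parameters={"lr": 1e-3, "weight_decay": 0.01},
    lr_scheduler_name="ReduceLROnPlateau",
    lr_scheduler_parameters={"factor": 0.5, "patience": 3},
)
```

Inside a LightningModule, a dictionary of this shape would be returned from the module's own configure_optimizers hook.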

get_optimizer(name)

Return the optimizer class given its name.

Parameters:

  • name (str) –

    Optimizer name.

Returns:

  • Optimizer

    Optimizer class.
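A name-to-class lookup of this kind can be sketched with getattr on torch.optim, rejecting names that do not resolve to an Optimizer subclass. This assumes the accepted names are PyTorch's own class names (for example "SGD" or "Adam"); the real function may use a different mapping, and get_optimizer_sketch is a hypothetical name.

```python
import torch
from torch.optim import Optimizer


def get_optimizer_sketch(name: str) -> type[Optimizer]:
    # Look the name up in torch.optim (assumption: PyTorch class names are used).
    cls = getattr(torch.optim, name, None)
    # Reject anything that is not a concrete optimizer class, including
    # submodules such as torch.optim.lr_scheduler and the Optimizer base class.
    if not (isinstance(cls, type) and issubclass(cls, Optimizer)) or cls is Optimizer:
        raise ValueError(f"Unknown optimizer: {name!r}")
    return cls
```

Returning the class rather than an instance leaves construction to the caller, which is the only place the model's parameters are available.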

get_scheduler(name)

Return the scheduler class given its name.

Parameters:

  • name (str) –

    Scheduler name.

Returns:

  • Union

    Scheduler class.
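A matching lookup for schedulers can be sketched the same way against torch.optim.lr_scheduler. The sketch accepts ReduceLROnPlateau explicitly alongside subclasses of LRScheduler, because in older PyTorch releases ReduceLROnPlateau did not inherit from the scheduler base class. It assumes PyTorch 2.0 or later (where the base class is public as LRScheduler) and PyTorch class names; get_scheduler_sketch is a hypothetical name.

```python
from torch.optim import lr_scheduler

# Classes accepted as schedulers (assumes PyTorch >= 2.0 for LRScheduler).
_SCHEDULER_BASES = (lr_scheduler.LRScheduler, lr_scheduler.ReduceLROnPlateau)


def get_scheduler_sketch(name: str) -> type:
    # Look the name up in torch.optim.lr_scheduler (assumption: PyTorch class names).
    cls = getattr(lr_scheduler, name, None)
    # Accept only scheduler classes; the abstract base class itself is rejected.
    if not (isinstance(cls, type) and issubclass(cls, _SCHEDULER_BASES)):
        raise ValueError(f"Unknown scheduler: {name!r}")
    if cls is lr_scheduler.LRScheduler:
        raise ValueError(f"Unknown scheduler: {name!r}")
    return cls
```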

log_training_stats(module, loss, batch_size)

Log training loss and learning rate.

Parameters:

  • module (LightningModule) –

    The Lightning module to log stats for.

  • loss (Any) –

    The loss value for the current training step.

  • batch_size (int) –

    The size of the batch used in the current training step.
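A minimal sketch of such a logging helper, written against LightningModule.log. The metric names "train_loss" and "lr", the on_step/on_epoch flags, and reading the learning rate from the first parameter group of trainer.optimizers[0] are all assumptions rather than this module's documented behavior. RecordingModule is a hypothetical stand-in so the sketch can run without a Trainer.

```python
from __future__ import annotations

from types import SimpleNamespace
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:  # only used for the type hint; not imported at runtime
    from lightning.pytorch import LightningModule


def log_training_stats_sketch(module: LightningModule, loss: Any, batch_size: int) -> None:
    # Hypothetical metric names; batch_size weights the epoch-level mean.
    module.log("train_loss", loss, batch_size=batch_size, on_step=True, on_epoch=True, prog_bar=True)
    # Assumption: read the learning rate from the first optimizer's first param group.
    lr = module.trainer.optimizers[0].param_groups[0]["lr"]
    module.log("lr", lr, batch_size=batch_size, on_step=True, on_epoch=False)


class RecordingModule:
    """Stand-in for a LightningModule that records log() calls instead of logging."""

    def __init__(self, lr: float) -> None:
        self.calls: list[tuple[str, Any, dict[str, Any]]] = []
        optimizer = SimpleNamespace(param_groups=[{"lr": lr}])
        self.trainer = SimpleNamespace(optimizers=[optimizer])

    def log(self, name: str, value: Any, **kwargs: Any) -> None:
        self.calls.append((name, value, kwargs))


module = RecordingModule(lr=0.01)
log_training_stats_sketch(module, loss=0.25, batch_size=32)
```

Passing batch_size explicitly spares Lightning from inferring it from the batch, which can fail for custom batch structures.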

log_validation_stats(module, loss, batch_size, metrics)

Log validation loss and metrics.

Parameters:

  • module (LightningModule) –

    The Lightning module to log stats for.

  • loss (Any) –

    The loss value for the current validation step.

  • batch_size (int) –

    The size of the batch used in the current validation step.

  • metrics (MetricCollection) –

    The metrics collection to log.
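A comparable sketch for validation, assuming the MetricCollection has already been updated with the step's predictions and targets. It logs the loss under the hypothetical key "val_loss", chosen so that configure_optimizers' default monitor would find it, and hands the collection to LightningModule.log_dict, which accepts a MetricCollection as well as a plain mapping. RecordingModule is again a hypothetical stand-in, and a plain dict substitutes for the collection.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:  # type hints only; neither package is imported at runtime
    from lightning.pytorch import LightningModule
    from torchmetrics import MetricCollection


def log_validation_stats_sketch(
    module: LightningModule, loss: Any, batch_size: int, metrics: MetricCollection
) -> None:
    # Epoch-level validation loss under a hypothetical key matching the default monitor.
    module.log("val_loss", loss, batch_size=batch_size, on_step=False, on_epoch=True, prog_bar=True)
    # Each metric in the collection is logged under its own key.
    module.log_dict(metrics, batch_size=batch_size, on_step=False, on_epoch=True)


class RecordingModule:
    """Stand-in for a LightningModule that records logged values instead of logging."""

    def __init__(self) -> None:
        self.logged: dict[str, Any] = {}

    def log(self, name: str, value: Any, **kwargs: Any) -> None:
        self.logged[name] = value

    def log_dict(self, dictionary: Any, **kwargs: Any) -> None:
        self.logged.update(dict(dictionary))


module = RecordingModule()
# A plain dict stands in for an already-updated MetricCollection.
log_validation_stats_sketch(module, loss=0.4, batch_size=16, metrics={"val_acc": 0.9})
```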