module_utils
Training utilities for Lightning modules.
configure_optimizers(model, optimizer_name, optimizer_parameters, lr_scheduler_name, lr_scheduler_parameters, monitor='val_loss')
Configure optimizer and learning rate scheduler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The model whose parameters will be optimized. | required |
| optimizer_name | str | The name of the optimizer to use. | required |
| optimizer_parameters | dict[str, Any] | Parameters to pass to the optimizer constructor. | required |
| lr_scheduler_name | str | The name of the learning rate scheduler to use. | required |
| lr_scheduler_parameters | dict[str, Any] | Parameters to pass to the learning rate scheduler constructor. | required |
| monitor | str | The metric monitored by the learning rate scheduler. | 'val_loss' |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | A dictionary containing the optimizer and learning rate scheduler configuration. |
Source code in src/careamics/lightning/dataset_ng/lightning_modules/module_utils.py
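The sketch below shows one way a name-based optimizer and scheduler factory can produce the dictionary layout that Lightning's `configure_optimizers` accepts. It is an illustration, not CAREamics' implementation: the function name `configure_optimizers_sketch` and the `optim_module`/`scheduler_module` parameters are hypothetical, added so the lookup can be exercised with stand-in namespaces when PyTorch is not installed. By default it resolves names from `torch.optim` and `torch.optim.lr_scheduler`.

```python
from typing import Any, Optional


def configure_optimizers_sketch(
    model: Any,
    optimizer_name: str,
    optimizer_parameters: dict[str, Any],
    lr_scheduler_name: str,
    lr_scheduler_parameters: dict[str, Any],
    monitor: str = "val_loss",
    optim_module: Optional[Any] = None,  # hypothetical, for testing
    scheduler_module: Optional[Any] = None,  # hypothetical, for testing
) -> dict[str, Any]:
    # Default to PyTorch's namespaces; imported lazily so stand-ins can be used.
    if optim_module is None:
        import torch.optim as optim_module
    if scheduler_module is None:
        import torch.optim.lr_scheduler as scheduler_module

    # Resolve classes by name, e.g. "Adam" -> torch.optim.Adam.
    try:
        optimizer_cls = getattr(optim_module, optimizer_name)
        scheduler_cls = getattr(scheduler_module, lr_scheduler_name)
    except AttributeError as err:
        raise ValueError(f"Unknown optimizer or scheduler name: {err}") from err

    optimizer = optimizer_cls(model.parameters(), **optimizer_parameters)
    scheduler = scheduler_cls(optimizer, **lr_scheduler_parameters)

    # Layout accepted by LightningModule.configure_optimizers; "monitor" is
    # consulted by metric-driven schedulers such as ReduceLROnPlateau.
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "monitor": monitor},
    }
```

With PyTorch installed, a call such as `configure_optimizers_sketch(torch.nn.Linear(4, 1), "Adam", {"lr": 1e-4}, "ReduceLROnPlateau", {"patience": 10})` resolves both classes by name and returns the configuration dictionary.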
load_best_checkpoint(module)
Load the best checkpoint from the trainer's checkpoint callback.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | LightningModule | The Lightning module to load the checkpoint into. | required |
Returns:
| Type | Description |
|---|---|
| bool | True if a checkpoint was loaded, False otherwise. |
Source code in src/careamics/lightning/dataset_ng/lightning_modules/module_utils.py
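A minimal sketch of the behaviour described above, assuming the trainer's first `ModelCheckpoint` (exposed as `Trainer.checkpoint_callback`) is the one to use and that the checkpoint stores weights under Lightning's usual `"state_dict"` key. The name `load_best_checkpoint_sketch` and the `load_fn` parameter are hypothetical; `load_fn` only exists so the logic can be tested without PyTorch, and defaults to `torch.load`.

```python
from typing import Any, Callable, Optional


def load_best_checkpoint_sketch(
    module: Any,
    load_fn: Optional[Callable[..., dict]] = None,  # hypothetical, for testing
) -> bool:
    """Load the best checkpoint tracked by the trainer into ``module``."""
    # Trainer.checkpoint_callback is the first ModelCheckpoint, or None.
    callback = module.trainer.checkpoint_callback
    # best_model_path is an empty string until a checkpoint has been saved.
    if callback is None or not callback.best_model_path:
        return False

    if load_fn is None:
        import torch

        def load_fn(path: str, map_location: Any) -> dict:
            # Full Lightning checkpoints hold more than tensors, so
            # weights_only=False is needed; only load files you trust.
            return torch.load(path, map_location=map_location, weights_only=False)

    checkpoint = load_fn(callback.best_model_path, map_location=module.device)
    # Lightning stores the model weights under the "state_dict" key.
    module.load_state_dict(checkpoint["state_dict"])
    return True
```

Returning `False` instead of raising lets a caller fall back to the current in-memory weights when no checkpoint was saved.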
log_training_stats(module, loss, batch_size)
Log training loss and learning rate.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | LightningModule | The Lightning module to log stats for. | required |
| loss | Any | The loss value for the current training step. | required |
| batch_size | int | The size of the batch used in the current training step. | required |
Source code in src/careamics/lightning/dataset_ng/lightning_modules/module_utils.py
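The sketch below illustrates how training loss and learning rate can be logged with `LightningModule.log` and `LightningModule.optimizers()`. The metric keys `"train_loss"` and `"learning_rate"`, the `on_step`/`on_epoch` choices, and the function name `log_training_stats_sketch` are assumptions for illustration, not the keys CAREamics necessarily uses.

```python
from typing import Any


def log_training_stats_sketch(module: Any, loss: Any, batch_size: int) -> None:
    """Log the training loss and current learning rate (illustrative keys)."""
    # Log per step and per epoch; batch_size lets Lightning weight the epoch mean.
    module.log(
        "train_loss", loss,
        on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size,
    )
    # optimizers() returns a single optimizer, or a list when there are several.
    optimizer = module.optimizers()
    if isinstance(optimizer, list):
        optimizer = optimizer[0]
    # The first parameter group's learning rate stands in for the whole model.
    lr = optimizer.param_groups[0]["lr"]
    module.log(
        "learning_rate", lr,
        on_step=False, on_epoch=True, batch_size=batch_size,
    )
```

Passing `batch_size` explicitly avoids Lightning having to infer it from the batch, which can fail for custom batch structures.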
log_validation_stats(module, loss, batch_size, metrics)
Log validation loss and metrics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | LightningModule | The Lightning module to log stats for. | required |
| loss | Any | The loss value for the current validation step. | required |
| batch_size | int | The size of the batch used in the current validation step. | required |
| metrics | MetricCollection | The metrics collection to log. | required |
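A minimal sketch of validation logging, assuming the caller has already updated the torchmetrics `MetricCollection` for the current batch and that the installed Lightning version's `log_dict` accepts a `MetricCollection` directly, so that `compute()` and `reset()` run at epoch end. The key `"val_loss"` matches the default `monitor` of `configure_optimizers`; the function name `log_validation_stats_sketch` is hypothetical, and `metrics` is typed `Any` to avoid importing torchmetrics.

```python
from typing import Any


def log_validation_stats_sketch(
    module: Any, loss: Any, batch_size: int, metrics: Any
) -> None:
    """Log validation loss and metrics, aggregated over the epoch."""
    # Epoch-level only: validation curves are usually compared per epoch,
    # and "val_loss" is the key a scheduler's monitor would look for.
    module.log(
        "val_loss", loss,
        on_step=False, on_epoch=True, prog_bar=True, batch_size=batch_size,
    )
    # Logging the collection object (not computed values) lets Lightning
    # aggregate across batches and reset the metric states at epoch end.
    module.log_dict(metrics, on_step=False, on_epoch=True, batch_size=batch_size)
```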