Bases: Callback
Callback to update model's data statistics from datamodule.
This callback ensures that the model has access to the data statistics (mean, std) calculated by the datamodule before training starts.
Source code in src/careamics/lightning/callbacks/data_stats_callback.py

    class DataStatsCallback(Callback):
        """Callback to update model's data statistics from datamodule.

        This callback ensures that the model has access to the data statistics (mean, std)
        calculated by the datamodule before training starts.
        """

        def setup(self, trainer: L.Trainer, module: L.LightningModule, stage: str) -> None:
            """Called when trainer is setting up.

            Parameters
            ----------
            trainer : Lightning.Trainer
                PyTorch Lightning trainer.
            module : Lightning.LightningModule
                Lightning module.
            stage : str
                Current stage (fit, validate, test, or predict).
            """
            if stage == "fit":
                # Get data statistics from datamodule
                (data_mean, data_std), _ = trainer.datamodule.get_data_stats()

                # Set data statistics in the model's likelihood module
                module.noise_model_likelihood.set_data_stats(
                    data_mean=data_mean["target"], data_std=data_std["target"]
                )

setup(trainer, module, stage)
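Called when the trainer is setting up. If stage is "fit", the target mean and std are read from the datamodule and set on the model's noise_model_likelihood; for every other stage the method does nothing.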
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer. | required |
| module | LightningModule | Lightning module. | required |
| stage | str | Current stage (fit, validate, test, or predict). | required |
Source code in src/careamics/lightning/callbacks/data_stats_callback.py

    def setup(self, trainer: L.Trainer, module: L.LightningModule, stage: str) -> None:
        """Called when trainer is setting up.

        Parameters
        ----------
        trainer : Lightning.Trainer
            PyTorch Lightning trainer.
        module : Lightning.LightningModule
            Lightning module.
        stage : str
            Current stage (fit, validate, test, or predict).
        """
        if stage == "fit":
            # Get data statistics from datamodule
            (data_mean, data_std), _ = trainer.datamodule.get_data_stats()

            # Set data statistics in the model's likelihood module
            module.noise_model_likelihood.set_data_stats(
                data_mean=data_mean["target"], data_std=data_std["target"]
            )