
data_stats_callback

Data statistics callback.

DataStatsCallback

Bases: Callback

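Callback that updates the model's data statistics from the datamodule.

This callback ensures that, before training starts, the model has access to the data statistics (mean and standard deviation) computed by the datamodule. Specifically, when the trainer sets up the fit stage, it forwards the target mean and standard deviation to the model's noise-model likelihood.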

Source code in src/careamics/lightning/callbacks/data_stats_callback.py
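import pytorch_lightning as L  # assumed import; not shown in the rendered excerpt
from pytorch_lightning.callbacks import Callback  # assumed import


class DataStatsCallback(Callback):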
    """Callback to update model's data statistics from datamodule.

    This callback ensures that the model has access to the data statistics (mean, std)
    calculated by the datamodule before training starts.
    """

    def setup(self, trainer: L.Trainer, module: L.LightningModule, stage: str) -> None:
        """Called when trainer is setting up.

        Parameters
        ----------
        trainer : Lightning.Trainer
            PyTorch Lightning trainer.
        module : Lightning.LightningModule
            Lightning module.
        stage : str
            Current stage (fit, validate, test, or predict).
        """
        if stage == "fit":
            # Get data statistics from datamodule
            (data_mean, data_std), _ = trainer.datamodule.get_data_stats()

            # Set data statistics in the model's likelihood module
            module.noise_model_likelihood.set_data_stats(
                data_mean=data_mean["target"], data_std=data_std["target"]
            )
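The class source above relies on two duck-typed contracts: trainer.datamodule.get_data_stats() must return a nested tuple whose first element is a (mean, std) pair of dictionaries containing a "target" entry (the second element is ignored), and the module must expose a noise_model_likelihood object with a set_data_stats(data_mean=..., data_std=...) method. The framework-free sketch below illustrates that flow under those assumptions. The stub classes, the helper apply_data_stats, the "input" key, and the numeric values are hypothetical stand-ins, not CAREamics APIs.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional


class StubDataModule:
    """Hypothetical datamodule returning statistics in the nested-tuple shape."""

    def get_data_stats(self) -> tuple[tuple[dict[str, float], dict[str, float]], Any]:
        # ((means, stds), <second element, ignored by the callback>)
        # The "input" entries and all values are illustrative assumptions.
        means = {"input": 0.5, "target": 10.0}
        stds = {"input": 0.1, "target": 2.0}
        return (means, stds), None


@dataclass
class StubLikelihood:
    """Hypothetical likelihood that only records the statistics it receives."""

    data_mean: Optional[float] = None
    data_std: Optional[float] = None

    def set_data_stats(self, data_mean: float, data_std: float) -> None:
        self.data_mean = data_mean
        self.data_std = data_std


@dataclass
class StubModule:
    noise_model_likelihood: StubLikelihood


@dataclass
class StubTrainer:
    datamodule: StubDataModule


def apply_data_stats(trainer: StubTrainer, module: StubModule, stage: str) -> None:
    """Hypothetical mirror of DataStatsCallback.setup's logic, without Lightning."""
    if stage == "fit":
        # Unpack ((means, stds), _) and forward only the "target" statistics.
        (data_mean, data_std), _ = trainer.datamodule.get_data_stats()
        module.noise_model_likelihood.set_data_stats(
            data_mean=data_mean["target"], data_std=data_std["target"]
        )


likelihood = StubLikelihood()
apply_data_stats(StubTrainer(StubDataModule()), StubModule(likelihood), "fit")
print(likelihood.data_mean, likelihood.data_std)  # 10.0 2.0
```

Any other entries in the statistics dictionaries are left unused: only the "target" mean and standard deviation reach the likelihood.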

setup(trainer, module, stage)

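Called when the trainer is setting up. In the fit stage only, it retrieves the data statistics from the datamodule with get_data_stats() and passes the target mean and standard deviation to the model's noise_model_likelihood.set_data_stats(). For any other stage (validate, test, or predict) it does nothing.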

Parameters:

Name      Type             Description                                        Default
trainer   Trainer          PyTorch Lightning trainer.                         required
module    LightningModule  Lightning module.                                  required
stage     str              Current stage (fit, validate, test, or predict).   required
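Because setup only acts when stage == "fit", the validate, test, and predict stages should leave both the datamodule and the likelihood untouched. One way to check this without building a real trainer is with unittest.mock.MagicMock. The sketch below mirrors the method body in a local function, setup_like_callback (a hypothetical name), instead of importing CAREamics, and the mocked return value is an assumed example of the nested ((mean, std), _) shape.

```python
from unittest.mock import MagicMock


def setup_like_callback(trainer, module, stage: str) -> None:
    """Hypothetical local mirror of DataStatsCallback.setup, for testing."""
    if stage == "fit":
        (data_mean, data_std), _ = trainer.datamodule.get_data_stats()
        module.noise_model_likelihood.set_data_stats(
            data_mean=data_mean["target"], data_std=data_std["target"]
        )


def make_mocks():
    """Build a mock trainer and module; attribute chains are auto-created."""
    trainer, module = MagicMock(), MagicMock()
    # Assumed return shape: ((means, stds), <ignored>).
    trainer.datamodule.get_data_stats.return_value = (
        ({"target": 3.0}, {"target": 0.5}),
        None,
    )
    return trainer, module


# Non-fit stages must not touch the datamodule or the likelihood.
for stage in ("validate", "test", "predict"):
    trainer, module = make_mocks()
    setup_like_callback(trainer, module, stage)
    trainer.datamodule.get_data_stats.assert_not_called()
    module.noise_model_likelihood.set_data_stats.assert_not_called()

# The fit stage forwards only the "target" statistics, as keyword arguments.
trainer, module = make_mocks()
setup_like_callback(trainer, module, "fit")
module.noise_model_likelihood.set_data_stats.assert_called_once_with(
    data_mean=3.0, data_std=0.5
)
print("stage gating behaves as expected")
```

Since the real method only uses attribute access on trainer and module, the same mocks should also work when passed directly to DataStatsCallback().setup(...) in a test that has CAREamics installed.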