Skip to content

load_checkpoint

Module for loading CAREamics models and configs from a checkpoint.

load_config_from_checkpoint(checkpoint_path) #

Load a CAREamics config from a checkpoint.

Some fields, if missing, will be populated by defaults. Namely, version, training_config and experiment_name.

The default for experiment_name will be "loaded_from_<checkpoint_filename>".

Parameters:

Name Type Description Default
checkpoint_path Path

Path to the PyTorch Lightning checkpoint file.

required

Returns:

Type Description
Configuration

A CAREamics configuration object.

Raises:

Type Description
ValueError

If certain required information is not found in the checkpoint.

Source code in src/careamics/lightning/dataset_ng/load_checkpoint.py
def load_config_from_checkpoint(
    checkpoint_path: Path,
) -> NGConfiguration[AlgorithmConfig]:
    """
    Load a CAREamics config from a checkpoint.

    Some fields, if missing, will be populated by defaults. Namely, `version`,
    `training_config` and `experiment_name`.

    The default for `experiment_name` will be `"loaded_from_<checkpoint_filename>"`.

    Parameters
    ----------
    checkpoint_path : Path
        Path to the PyTorch Lightning checkpoint file.

    Returns
    -------
    NGConfiguration
        A CAREamics configuration object.

    Raises
    ------
    ValueError
        If certain required information is not found in the checkpoint.
    """
    checkpoint: dict[str, Any] = torch.load(checkpoint_path, map_location="cpu")

    # if careamics_info is not included (i.e. it was saved with the lightning API)
    # then version and training_config will be the default from the pydantic models.
    careamics_info = checkpoint.get(
        "careamics_info", {"experiment_name": _create_loaded_exp_name(checkpoint_path)}
    )

    # --- alg config
    # NOTE: dict subscripting only raises KeyError; IndexError cannot occur here.
    try:
        algorithm_config: dict[str, Any] = checkpoint["hyper_parameters"][
            "algorithm_config"
        ]
    except KeyError as e:
        raise ValueError(
            "Could not determine a CAREamics supported algorithm from the provided "
            f"checkpoint at: {checkpoint_path}."
        ) from e

    # --- data config
    try:
        data_config: dict[str, Any] = checkpoint["datamodule_hyper_parameters"][
            "data_config"
        ]
    except KeyError as e:
        raise ValueError(
            "Could not determine the data configuration from the provided "
            f"checkpoint at: {checkpoint_path}."
        ) from e

    # NOTE: it is important for subclasses to appear first in the Union
    # type adapter will check each class until one fits
    type_adapter: TypeAdapter[NGConfiguration[AlgorithmConfig]] = TypeAdapter(
        N2VConfiguration | NGConfiguration
    )
    config = type_adapter.validate_python(
        {
            "algorithm_config": algorithm_config,
            "data_config": data_config,
            **careamics_info,
        }
    )
    return config

load_module_from_checkpoint(checkpoint_path) #

Load a trained CAREamics module from checkpoint.

Automatically detects the algorithm type from the checkpoint and loads the appropriate module with trained weights.

Parameters:

Name Type Description Default
checkpoint_path Path

Path to the PyTorch Lightning checkpoint file.

required

Returns:

Type Description
CAREamicsModule

Lightning module with loaded weights.

Raises:

Type Description
ValueError

If the algorithm type cannot be determined from the checkpoint.

Source code in src/careamics/lightning/dataset_ng/load_checkpoint.py
def load_module_from_checkpoint(checkpoint_path: Path) -> CAREamicsModule:
    """
    Load a trained CAREamics module from checkpoint.

    The algorithm type is read from the checkpoint's hyper-parameters and used
    to select the matching module class, which then restores the trained
    weights via Lightning's ``load_from_checkpoint``.

    Parameters
    ----------
    checkpoint_path : Path
        Path to the PyTorch Lightning checkpoint file.

    Returns
    -------
    CAREamicsModule
        Lightning module with loaded weights.

    Raises
    ------
    ValueError
        If the algorithm type cannot be determined from the checkpoint.
    """
    # only the hyper-parameters are needed here; weights are restored below
    ckpt = torch.load(checkpoint_path, map_location="cpu")

    try:
        raw_algorithm = ckpt["hyper_parameters"]["algorithm_config"]["algorithm"]
        algorithm = SupportedAlgorithm(raw_algorithm)
    except (KeyError, ValueError) as e:
        raise ValueError(
            f"Could not determine algorithm type from checkpoint at: {checkpoint_path}"
        ) from e

    module_cls = get_module_cls(algorithm)
    return module_cls.load_from_checkpoint(checkpoint_path)