Skip to content

get_module

Factory functions for lightning modules.

create_module(algorithm_config)

Initialize the correct Lightning module from an algorithm config.

Parameters:

Name Type Description Default
algorithm_config UNetBasedAlgorithm

The pydantic model with algorithm specific parameters.

required

Returns:

Type Description
CAREamicsModule

A lightning module for running one of the algorithms supported by CAREamics.

Raises:

Type Description
NotImplementedError

If the chosen algorithm is not yet supported.

Source code in src/careamics/lightning/dataset_ng/lightning_modules/get_module.py
def create_module(algorithm_config: UNetBasedAlgorithm) -> CAREamicsModule:
    """
    Initialize the correct Lightning module from an algorithm config.

    Parameters
    ----------
    algorithm_config : UNetBasedAlgorithm
        The pydantic model with algorithm specific parameters.

    Returns
    -------
    CAREamicsModule
        A lightning module for running one of the algorithms supported by CAREamics.

    Raises
    ------
    NotImplementedError
        If the chosen algorithm is not yet supported.
    """
    if isinstance(algorithm_config, CAREAlgorithm):
        return CAREModule(algorithm_config)
    elif isinstance(algorithm_config, N2VAlgorithm):
        return N2VModule(algorithm_config)
    else:
        algorithm = algorithm_config.algorithm
        raise NotImplementedError(
            f"Support for {algorithm} has not been implemented yet."
        )

get_module_cls(algorithm)

Get the lightning module class for the specified algorithm.

Parameters:

Name Type Description Default
algorithm SupportedAlgorithm

One of the algorithms supported by CAREamics, e.g. "n2v".

required

Returns:

Type Description
CAREamicsModuleCls

A Lightning module class for running the specified algorithm.

Raises:

Type Description
NotImplementedError

If the chosen algorithm is not yet supported.

Source code in src/careamics/lightning/dataset_ng/lightning_modules/get_module.py
def get_module_cls(algorithm: SupportedAlgorithm) -> CAREamicsModuleCls:
    """
    Get the lightning module class for the specified `algorithm`.

    Parameters
    ----------
    algorithm : SupportedAlgorithm
        One of the algorithms supported by CAREamics, e.g. `"n2v"`.

    Returns
    -------
    CAREamicsModuleCls
        A Lightning module class for running the specified `algorithm`.

    Raises
    ------
    NotImplementedError
        If the chosen algorithm is not get supported.
    """
    match algorithm:
        case SupportedAlgorithm.CARE:
            return CAREModule
        case SupportedAlgorithm.N2V:
            return N2VModule
        case _:
            raise NotImplementedError(
                f"Support for {algorithm.value} has not been implemented yet."
            )

load_module_from_checkpoint(checkpoint_path)

Load a trained CAREamics module from checkpoint.

Automatically detects the algorithm type from the checkpoint and loads the appropriate module with trained weights.

Parameters:

Name Type Description Default
checkpoint_path Path

Path to the PyTorch Lightning checkpoint file.

required

Returns:

Type Description
CAREamicsModule

Lightning module with loaded weights.

Raises:

Type Description
ValueError

If the algorithm type cannot be determined from the checkpoint.

Source code in src/careamics/lightning/dataset_ng/lightning_modules/get_module.py
def load_module_from_checkpoint(checkpoint_path: Path) -> CAREamicsModule:
    """
    Load a trained CAREamics module from checkpoint.

    Automatically detects the algorithm type from the checkpoint and loads
    the appropriate module with trained weights.

    Parameters
    ----------
    checkpoint_path : Path
        Path to the PyTorch Lightning checkpoint file.

    Returns
    -------
    CAREamicsModule
        Lightning module with loaded weights.

    Raises
    ------
    ValueError
        If the algorithm type cannot be determined from the checkpoint.
    NotImplementedError
        If the algorithm recorded in the checkpoint has no Lightning module
        implementation yet (raised by `get_module_cls`).
    """
    # The checkpoint is read onto the CPU only to inspect its stored
    # hyper-parameters; `load_from_checkpoint` below reads the file again.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    try:
        algorithm = checkpoint["hyper_parameters"]["algorithm_config"]["algorithm"]
        algorithm = SupportedAlgorithm(algorithm)
    except (KeyError, TypeError, ValueError) as e:
        # KeyError: a level of the nested lookup is missing.
        # TypeError: a level cannot be indexed by a string key (e.g. None).
        # ValueError: the stored name is not a `SupportedAlgorithm` member.
        raise ValueError(
            f"Could not determine algorithm type from checkpoint at: {checkpoint_path}"
        ) from e

    module_cls = get_module_cls(algorithm)
    return module_cls.load_from_checkpoint(checkpoint_path)