
model_factory

Model factory.

Model creation factory functions.

model_factory(model_configuration)

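Deep learning model factory: instantiate the model described by a configuration.

Supported architectures are defined in careamics.config.SupportedArchitecture. This factory handles UNet, LVAE and registered custom models.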
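The factory selects a model by comparing model_configuration.architecture with members of SupportedArchitecture using `==`. If, as assumed here, SupportedArchitecture is a `str`-based `Enum` and the configuration stores the architecture as a plain string, the comparison succeeds because every member is also a `str`. The member names and values below are hypothetical stand-ins, not the library's actual definitions.

```python
from enum import Enum


class SupportedArchitecture(str, Enum):
    """Hypothetical stand-in for careamics.config.SupportedArchitecture."""

    UNET = "UNet"
    LVAE = "LVAE"
    CUSTOM = "custom"


# A configuration that stores the architecture as a plain string.
architecture_from_config = "UNet"

# Because the enum subclasses str, a member compares equal to its value.
print(architecture_from_config == SupportedArchitecture.UNET)  # True
print(architecture_from_config == SupportedArchitecture.LVAE)  # False

# Collect the known values, e.g. to validate a string before dispatching.
known = {member.value for member in SupportedArchitecture}
print("Transformer" in known)  # False
```

Without the `str` mixin, a plain `Enum` member would never equal a string, and every string-valued configuration would fall through to the error branch.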

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_configuration | Union[UNetModel, LVAEModel, CustomModel] | Model configuration. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.nn.Module | Instantiated model. |

Raises:

| Type | Description |
| --- | --- |
| NotImplementedError | If the requested architecture is not implemented. |
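The dispatch and the error path documented above can be illustrated with a self-contained sketch. It relies on hypothetical stand-ins: dataclass configurations whose `dataclasses.asdict` output plays the role of Pydantic's `model_dump()`, and stub classes in place of `UNet` and `LVAE` so that PyTorch is not required. Field names such as `depth` and `z_dims` are illustrative only. Two points carry over from the source: the factory returns an instance rather than a class, and the dumped dictionary still contains the `architecture` key, so this sketch assumes each model constructor accepts extra keyword arguments.

```python
from dataclasses import asdict, dataclass
from typing import Any, Union


# Hypothetical configuration stand-ins; asdict() mimics Pydantic's model_dump().
@dataclass
class UNetConfig:
    architecture: str = "UNet"
    depth: int = 2
    in_channels: int = 1


@dataclass
class LVAEConfig:
    architecture: str = "LVAE"
    z_dims: int = 32


# Hypothetical model stand-ins. The dumped dict still holds "architecture",
# so each constructor absorbs unused keys through **kwargs.
class UNetStub:
    def __init__(self, depth: int, in_channels: int, **kwargs: Any) -> None:
        self.depth = depth
        self.in_channels = in_channels


class LVAEStub:
    def __init__(self, z_dims: int, **kwargs: Any) -> None:
        self.z_dims = z_dims


def model_factory_sketch(config: Union[UNetConfig, LVAEConfig]) -> object:
    """Return an instantiated model (not a class) for the given config."""
    if config.architecture == "UNet":
        return UNetStub(**asdict(config))
    elif config.architecture == "LVAE":
        return LVAEStub(**asdict(config))
    else:
        raise NotImplementedError(
            f"Model {config.architecture} is not implemented or unknown."
        )


model = model_factory_sketch(UNetConfig(depth=3))
print(type(model).__name__, model.depth)  # UNetStub 3

try:
    model_factory_sketch(UNetConfig(architecture="Transformer"))
except NotImplementedError as err:
    print(err)  # Model Transformer is not implemented or unknown.
```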

Source code in src/careamics/models/model_factory.py
def model_factory(
    model_configuration: Union[UNetModel, LVAEModel, CustomModel],
) -> torch.nn.Module:
    """
    Deep learning model factory.

    Supported models are defined in careamics.config.SupportedArchitecture.

    Parameters
    ----------
    model_configuration : Union[UNetModel, LVAEModel, CustomModel]
        Model configuration.

    Returns
    -------
    torch.nn.Module
        Instantiated model.

    Raises
    ------
    NotImplementedError
        If the requested architecture is not implemented.
    """
    if model_configuration.architecture == SupportedArchitecture.UNET:
        return UNet(**model_configuration.model_dump())
    elif model_configuration.architecture == SupportedArchitecture.LVAE:
        return LVAE(**model_configuration.model_dump())
    elif model_configuration.architecture == SupportedArchitecture.CUSTOM:
        assert isinstance(model_configuration, CustomModel)
        model = get_custom_model(model_configuration.name)
        return model(**model_configuration.model_dump())
    else:
        raise NotImplementedError(
            f"Model {model_configuration.architecture} is not implemented or unknown."
        )