Deep learning model factory.
Supported models are defined in careamics.config.SupportedArchitecture.
Parameters:
Name | Type | Description | Default |
model_configuration | Union[UNetModel, LVAEModel] | Model configuration. | required |
Returns:
Type | Description |
torch.nn.Module | Instantiated model. |
Raises:
Type | Description |
NotImplementedError | If the requested architecture is not implemented. |
Source code in src/careamics/models/model_factory.py
def model_factory(
    model_configuration: Union[UNetModel, LVAEModel],
) -> torch.nn.Module:
    """
    Deep learning model factory.

    Supported models are defined in careamics.config.SupportedArchitecture.

    The configuration's ``architecture`` field selects which model to build.
    The full ``model_dump()`` of the configuration is then forwarded to that
    model's constructor as keyword arguments.

    Parameters
    ----------
    model_configuration : Union[UNetModel, LVAEModel]
        Model configuration.

    Returns
    -------
    torch.nn.Module
        Instantiated model (``UNet`` or ``LVAE``).

    Raises
    ------
    NotImplementedError
        If the requested architecture is not implemented.
    """
    architecture = model_configuration.architecture

    if architecture == SupportedArchitecture.UNET:
        return UNet(**model_configuration.model_dump())

    if architecture == SupportedArchitecture.LVAE:
        return LVAE(**model_configuration.model_dump())

    # Any other architecture value has no model implementation here.
    raise NotImplementedError(
        f"Model {architecture} is not implemented or unknown."
    )
|