Skip to content

model_description

Module used to build BMZ model descriptions.

create_model_description(config, name, general_description, data_description, authors, inputs, outputs, weights_path, torch_version, careamics_version, config_path, env_path, covers, channel_names=None, model_version='0.1.0') #

Create model description.

Parameters:

Name Type Description Default
config Configuration

CAREamics configuration.

required
name str

Name of the model.

required
general_description str

General description of the model.

required
data_description str

Description of the data the model was trained on.

required
authors list[Author]

Authors of the model.

required
inputs Union[Path, str]

Path to input .npy file.

required
outputs Union[Path, str]

Path to output .npy file.

required
weights_path Union[Path, str]

Path to model weights.

required
torch_version str

Pytorch version.

required
careamics_version str

CAREamics version.

required
config_path Union[Path, str]

Path to model configuration.

required
env_path Union[Path, str]

Path to environment file.

required
covers list of pathlib.Path or str

Paths to cover images.

required
channel_names Optional[list[str]]

Channel names, by default None.

None
model_version str

Model version.

"0.1.0"

Returns:

Type Description
ModelDescr

Model description.

Source code in src/careamics/model_io/bioimage/model_description.py
def create_model_description(
    config: Configuration,
    name: str,
    general_description: str,
    data_description: str,
    authors: list[Author],
    inputs: Union[Path, str],
    outputs: Union[Path, str],
    weights_path: Union[Path, str],
    torch_version: str,
    careamics_version: str,
    config_path: Union[Path, str],
    env_path: Union[Path, str],
    covers: list[Union[Path, str]],
    channel_names: Optional[list[str]] = None,
    model_version: str = "0.1.0",
) -> ModelDescr:
    """Assemble a `ModelDescr` from a CAREamics configuration and its files.

    The description is built in three stages: the documentation is generated
    from the configuration, the input/output tensor descriptions are derived
    from the sample arrays stored on disk, and the weights entry is created
    with its architecture and environment file.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration.
    name : str
        Name of the model.
    general_description : str
        General description of the model.
    data_description : str
        Description of the data the model was trained on.
    authors : list[Author]
        Authors of the model.
    inputs : Union[Path, str]
        Path to input .npy file (read with `np.load`).
    outputs : Union[Path, str]
        Path to output .npy file (read with `np.load`).
    weights_path : Union[Path, str]
        Path to model weights.
    torch_version : str
        Pytorch version.
    careamics_version : str
        CAREamics version.
    config_path : Union[Path, str]
        Path to model configuration, added to the model as an attachment.
    env_path : Union[Path, str]
        Path to environment file.
    covers : list of pathlib.Path or str
        Paths to cover images.
    channel_names : Optional[list[str]], optional
        Channel names, by default None.
    model_version : str, default "0.1.0"
        Model version.

    Returns
    -------
    ModelDescr
        Model description.
    """
    # --- documentation (README) generated from the configuration ---
    readme = readme_factory(
        config,
        careamics_version=careamics_version,
        data_description=data_description,
    )

    # --- tensor descriptions, built from the sample arrays on disk ---
    sample_input = np.load(inputs)
    sample_output = np.load(outputs)
    in_descr, out_descr = _create_inputs_ouputs(
        input_array=sample_input,
        output_array=sample_output,
        data_config=config.data_config,
        input_path=inputs,
        output_path=outputs,
        channel_names=channel_names,
    )

    # --- weights: architecture, torch version and environment file ---
    architecture = ArchitectureFromLibraryDescr(
        import_from="careamics.models.unet",
        callable=f"{config.algorithm_config.model.architecture}",
        kwargs=config.algorithm_config.model.model_dump(),
    )
    parsed_torch_version = Version(torch_version)
    environment = EnvironmentFileDescr(source=env_path)
    state_dict_weights = PytorchStateDictWeightsDescr(
        source=weights_path,
        architecture=architecture,
        pytorch_version=parsed_torch_version,
        dependencies=environment,
    )
    weights = WeightsDescr(pytorch_state_dict=state_dict_weights)

    # --- constant metadata attached to every exported model ---
    project_links = [
        "https://github.com/CAREamics/careamics",
        "https://careamics.github.io/latest/",
    ]
    # tolerances stored under the "bioimageio" test_kwargs config entry
    test_settings = {
        "bioimageio": {
            "test_kwargs": {
                "pytorch_state_dict": {
                    "absolute_tolerance": 1e-2,
                    "relative_tolerance": 1e-2,
                }
            }
        }
    }

    # --- overall model description ---
    return ModelDescr(
        name=name,
        authors=authors,
        description=general_description,
        documentation=readme,
        inputs=[in_descr],
        outputs=[out_descr],
        tags=config.get_algorithm_keywords(),
        links=project_links,
        license="BSD-3-Clause",
        config=test_settings,
        version=model_version,
        weights=weights,
        attachments=[FileDescr(source=config_path)],
        cite=config.get_algorithm_citations(),
        covers=covers,
    )

extract_model_path(model_desc) #

Return the relative path to the weights and configuration files.

Parameters:

Name Type Description Default
model_desc ModelDescr

Model description.

required

Returns:

Type Description
tuple of (path, path)

Weights and configuration paths.

Source code in src/careamics/model_io/bioimage/model_description.py
def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]:
    """Locate the weights file and the CAREamics configuration of a model.

    The weights are taken from the `pytorch_state_dict` entry of the model
    description. The configuration is the first attachment whose file name is
    ``careamics.yaml``. Both sources are passed through `resolve_and_extract`
    and the resulting ``.path`` is returned.

    Parameters
    ----------
    model_desc : ModelDescr
        Model description.

    Returns
    -------
    tuple of (path, path)
        Weights and configuration paths.

    Raises
    ------
    ValueError
        If the model description has no `pytorch_state_dict` weights.
    ValueError
        If no attachment is named ``careamics.yaml``.
    """
    state_dict_descr = model_desc.weights.pytorch_state_dict
    if state_dict_descr is None:
        raise ValueError("No model weights found in model description.")
    weights_path = resolve_and_extract(state_dict_descr.source).path

    for attachment in model_desc.attachments:
        source = attachment.source

        # a plain `Path` is used directly, any other source exposes `.path`
        if isinstance(source, Path):
            candidate = source
        else:
            candidate = source.path

        # sources without a path component cannot be the configuration file
        if candidate is not None and Path(candidate).name == "careamics.yaml":
            return weights_path, resolve_and_extract(source).path

    raise ValueError("Configuration file not found.")