Skip to content

data_factory

Convenience functions to create NG data configurations.

create_ng_data_configuration(data_type, axes, patch_size, batch_size, augmentations=None, normalization=None, channels=None, in_memory=None, num_workers=0, train_dataloader_params=None, val_dataloader_params=None, pred_dataloader_params=None, seed=None) #

Create a training NGDataConfig.

Note that num_workers is applied to all dataloaders unless explicitly overridden in the respective dataloader parameters.

Parameters:

Name Type Description Default
data_type (array, tiff, zarr, czi, custom)

Type of the data.

"array"
axes str

Axes of the data.

required
patch_size list of int

Size of the patches along the spatial dimensions.

required
batch_size int

Batch size.

required
augmentations list of transforms or None

List of transforms to apply. If None, default augmentations are applied (flip in X and Y, rotations by 90 degrees in the XY plane).

None
normalization dict

Normalization configuration dictionary. If None, defaults to mean_std normalization with automatically computed statistics.

None
channels Sequence of int

List of channels to use. If None, all channels are used.

None
in_memory bool

Whether to load all data into memory. This is only supported for 'array', 'tiff' and 'custom' data types. If None, defaults to True for 'array', 'tiff' and custom, and False for 'zarr' and 'czi' data types. Must be True for array.

None
num_workers int

Number of workers for data loading.

0
train_dataloader_params dict

Parameters for the training dataloader, see PyTorch notes, by default None.

None
val_dataloader_params dict

Parameters for the validation dataloader, see PyTorch notes, by default None.

None
pred_dataloader_params dict

Parameters for the test dataloader, see PyTorch notes, by default None.

None
seed int

Random seed for reproducibility. If None, seed is generated automatically.

None

Returns:

Type Description
NGDataConfig

Next-Generation Data model with the specified parameters.

Source code in src/careamics/config/ng_factories/data_factory.py
def create_ng_data_configuration(
    data_type: Literal["array", "tiff", "zarr", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
    normalization: dict | None = None,
    channels: Sequence[int] | None = None,
    in_memory: bool | None = None,
    num_workers: int = 0,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
    pred_dataloader_params: dict[str, Any] | None = None,
    seed: int | None = None,
) -> NGDataConfig:
    """
    Create a training NGDataConfig.

    Note that `num_workers` is applied to all dataloaders unless explicitly overridden
    in the respective dataloader parameters.

    Caller-provided dataloader parameter dictionaries are copied before defaults are
    filled in, so the arguments are never mutated.

    Parameters
    ----------
    data_type : {"array", "tiff", "zarr", "czi", "custom"}
        Type of the data.
    axes : str
        Axes of the data.
    patch_size : list of int
        Size of the patches along the spatial dimensions.
    batch_size : int
        Batch size.
    augmentations : list of transforms or None, default=None
        List of transforms to apply. If `None`, default augmentations are applied
        (flip in X and Y, rotations by 90 degrees in the XY plane).
    normalization : dict, default=None
        Normalization configuration dictionary. If None, defaults to mean_std
        normalization with automatically computed statistics.
    channels : Sequence of int, default=None
        List of channels to use. If `None`, all channels are used.
    in_memory : bool, default=None
        Whether to load all data into memory. This is only supported for 'array',
        'tiff' and 'custom' data types. If `None`, defaults to `True` for 'array',
        'tiff' and `custom`, and `False` for 'zarr' and 'czi' data types. Must be `True`
        for `array`.
    num_workers : int, default=0
        Number of workers for data loading.
    train_dataloader_params : dict
        Parameters for the training dataloader, see PyTorch notes, by default None.
    val_dataloader_params : dict
        Parameters for the validation dataloader, see PyTorch notes, by default None.
    pred_dataloader_params : dict
        Parameters for the test dataloader, see PyTorch notes, by default None.
    seed : int, default=None
        Random seed for reproducibility. If `None`, seed is generated automatically.

    Returns
    -------
    NGDataConfig
        Next-Generation Data model with the specified parameters.
    """
    if seed is None:
        seed = generate_random_seed()

    if augmentations is None:
        augmentations = list_spatial_augmentations(seed=seed)

    # data model
    data: dict[str, Any] = {
        "mode": "training",
        "data_type": data_type,
        "axes": axes,
        "batch_size": batch_size,
        "channels": channels,
        "transforms": augmentations,
        "seed": seed,
        "normalization": (
            normalization if normalization is not None else {"name": "mean_std"}
        ),
    }

    if in_memory is not None:
        data["in_memory"] = in_memory

    # copy the caller's dicts before filling in defaults so that the input
    # arguments are not mutated as a side effect
    if train_dataloader_params is not None:
        train_params = dict(train_dataloader_params)
    else:
        train_params = {}
    # the presence of the `shuffle` key in the training dataloader parameters
    # is enforced by the NGDataConfig class
    train_params.setdefault("shuffle", True)
    train_params.setdefault("num_workers", num_workers)
    data["train_dataloader_params"] = train_params

    if val_dataloader_params is not None:
        val_params = dict(val_dataloader_params)
    else:
        # validation data is not shuffled by default
        val_params = {"shuffle": False}
    val_params.setdefault("num_workers", num_workers)
    data["val_dataloader_params"] = val_params

    if pred_dataloader_params is not None:
        pred_params = dict(pred_dataloader_params)
    else:
        # prediction data is not shuffled by default
        pred_params = {"shuffle": False}
    pred_params.setdefault("num_workers", num_workers)
    data["pred_dataloader_params"] = pred_params

    # add training patching
    data["patching"] = {
        "name": "random",
        "patch_size": patch_size,
    }

    return NGDataConfig(**data)

list_spatial_augmentations(augmentations=None, seed=None) #

List the augmentations to apply.

Parameters:

Name Type Description Default
augmentations list of transforms

List of transforms to apply, either both or one of XYFlipConfig and XYRandomRotate90Config.

None
seed int

Random seed for reproducibility.

None

Returns:

Type Description
list of transforms

List of transforms to apply.

Raises:

Type Description
ValueError

If the transforms are not XYFlipConfig or XYRandomRotate90Config.

ValueError

If there are duplicate transforms.

Source code in src/careamics/config/ng_factories/data_factory.py
def list_spatial_augmentations(
    augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
    seed: int | None = None,
) -> list[SPATIAL_TRANSFORMS_UNION]:
    """
    List the augmentations to apply.

    Parameters
    ----------
    augmentations : list of transforms, optional
        List of transforms to apply, either both or one of XYFlipConfig and
        XYRandomRotate90Config.
    seed : int, optional
        Random seed for reproducibility.

    Returns
    -------
    list of transforms
        List of transforms to apply.

    Raises
    ------
    ValueError
        If the transforms are not XYFlipConfig or XYRandomRotate90Config.
    ValueError
        If there are duplicate transforms.
    """
    # no user-provided transforms: fall back to the default flip + rotation pair
    if augmentations is None:
        return [
            XYFlipConfig(seed=seed),
            XYRandomRotate90Config(seed=seed),
        ]

    # every transform must be one of the two accepted pydantic models
    for transform in augmentations:
        if not isinstance(transform, (XYFlipConfig, XYRandomRotate90Config)):
            raise ValueError(
                "Accepted transforms are either XYFlipConfig or "
                "XYRandomRotate90Config."
            )

    # reject lists containing the same transform type more than once
    transform_types = [type(transform) for transform in augmentations]
    if len(transform_types) != len(set(transform_types)):
        raise ValueError("Duplicate transforms are not allowed.")

    return augmentations