Skip to content

write_tiles_zarr_strategy

Tile Zarr writing strategy.

WriteTilesZarr #

Zarr tile writer strategy.

This writer creates zarr files, groups and arrays as needed and writes tiles into the appropriate locations.

Source code in src/careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py
class WriteTilesZarr:
    """Zarr tile writer strategy.

    This writer creates zarr files, groups and arrays as needed and writes tiles
    into the appropriate locations.

    The most recently used store, group and array are cached on the instance so
    that consecutive tiles targeting the same array do not re-open them. The cache
    is hierarchical: switching store invalidates the cached group and array, and
    switching group invalidates the cached array.
    """

    def __init__(self) -> None:
        """Constructor."""
        self.current_store: zarr.Group | None = None
        self.current_group: zarr.Group | None = None
        self.current_array: zarr.Array | None = None
        # string form of the path `current_store` was opened from, used by
        # `write_tile` to detect that a tile targets a different store
        self._current_store_path: str | None = None

    def _create_zarr(self, store: str | Path) -> None:
        """Create a new zarr storage, or open it if it already exists.

        On success, the cached group and array are invalidated since they belong
        to the previously opened store.

        Parameters
        ----------
        store : str | Path
            Path to the zarr store.

        Raises
        ------
        RuntimeError
            If an object exists at `store` but is not a zarr group.
        """
        if not Path(store).exists():
            self.current_store = zarr.create_group(store)
        else:
            open_store = zarr.open(store)

            if not isinstance(open_store, zarr.Group):
                raise RuntimeError(f"Zarr store at {store} is not a group.")

            self.current_store = open_store

        # handles cached from a previous store must not be reused: groups and
        # arrays are matched by name only, and names can repeat across stores
        self._current_store_path = str(store)
        self.current_group = None
        self.current_array = None

        print(f"Store: {Path(store).absolute()}")

    def _create_group(self, group_path: str) -> None:
        """Create a new group in an existing zarr storage, or open it if it exists.

        On success, the cached array is invalidated since it belongs to the
        previously opened group.

        Parameters
        ----------
        group_path : str
            Path to the group within the zarr store.

        Raises
        ------
        RuntimeError
            If the zarr store has not been initialized, or if the object at
            `group_path` is not a group.
        """
        if self.current_store is None:
            raise RuntimeError("Zarr store not initialized.")

        if group_path not in self.current_store:
            self.current_group = self.current_store.create_group(group_path)
        else:
            current_group = self.current_store[group_path]
            if not isinstance(current_group, zarr.Group):
                raise RuntimeError(f"Zarr group at {group_path} is not a group.")

            self.current_group = current_group

        # an array cached from another group may share its name with an array of
        # this group, force `write_tile` to look it up again
        self.current_array = None

    def _create_array(
        self,
        array_name: str,
        axes: str,
        data_shape: Sequence[int],
        shards: tuple[int, ...] | None,
        chunks: tuple[int, ...] | None,
    ) -> None:
        """Create a new array in an existing zarr group, or open it if it exists.

        Parameters
        ----------
        array_name : str
            Name of the array within the zarr group.
        axes : str
            Axes string in SC(Z)YX format with original data order.
        data_shape : Sequence[int]
            Shape of the array.
        shards : tuple[int, ...] or None
            Shard size for the array.
        chunks : tuple[int, ...] or None
            Chunk size for the array.

        Raises
        ------
        RuntimeError
            If the zarr group has not been initialized, or if the object named
            `array_name` is not an array.
        ValueError
            If `chunks` and the array shape, or `shards` and `chunks`, have
            different lengths.
        """
        if self.current_group is None:
            raise RuntimeError("Zarr group not initialized.")

        if array_name not in self.current_group:
            # get shape without non-existing axes (S or C)
            updated_shape = _update_data_shape(axes, data_shape)

            if chunks is not None and len(updated_shape) != len(chunks):
                raise ValueError(
                    f"Shape {updated_shape} and chunks {chunks} have different lengths."
                )

            if chunks is None:
                # NOTE(review): receives the original `data_shape`, not
                # `updated_shape`; presumably `_auto_chunks` drops the missing
                # axes itself - confirm
                chunks = _auto_chunks(axes, data_shape)

            # TODO if we auto_chunks, we probably want to auto shards as well
            # there is shards="auto" in zarr, where array.target_shard_size_bytes
            # needs to be used (see zarr-python docs)
            if shards is not None and len(chunks) != len(shards):
                raise ValueError(
                    f"Chunks {chunks} and shards {shards} have different lengths."
                )

            self.current_array = self.current_group.create_array(
                name=array_name,
                shape=updated_shape,
                shards=shards,
                chunks=chunks,
                dtype=float32,
            )
        else:
            current_array = self.current_group[array_name]
            if not isinstance(current_array, zarr.Array):
                raise RuntimeError(f"Zarr array at {array_name} is not an array.")
            self.current_array = current_array

    def write_tile(self, dirpath: Path, region: ImageRegionData) -> None:
        """Write cropped tile to zarr array.

        The output store, group and array are derived from the source URI of the
        region and are created or opened only when they differ from the ones used
        for the previous tile.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        region : ImageRegionData
            Image region data containing tile information.

        Raises
        ------
        NotImplementedError
            If the source of the region is not a valid zarr URI.
        RuntimeError
            If the zarr array could not be initialized.
        """
        if not is_valid_uri(region.source):
            raise NotImplementedError(
                f"Invalid zarr URI: {region.source}. Currently, only predicting from "
                f"Zarr files is supported when writing Zarr tiles."
            )

        store_path, parent_path, array_name = decipher_zarr_uri(region.source)
        output_store_path = _add_output_key(dirpath, store_path)

        # compare against the path the current store was opened from;
        # `_create_zarr` also drops the cached group and array
        if (
            self.current_store is None
            or self._current_store_path != str(output_store_path)
        ):
            self._create_zarr(output_store_path)

        # `_create_group` also drops the cached array
        if self.current_group is None or self.current_group.name != parent_path:
            self._create_group(parent_path)

        if self.current_array is None or self.current_array.basename != array_name:
            # data_shape, chunks and shards are in SC(Z)YX order since they are reshaped
            # in the zarr image stack loader
            # If the source is not a Zarr file, then chunks and shards will be `None`.
            shape = region.data_shape
            chunks: tuple[int, ...] | None = region.additional_metadata.get(
                "chunks", None
            )
            shards: tuple[int, ...] | None = region.additional_metadata.get(
                "shards", None
            )
            self._create_array(array_name, region.axes, shape, shards, chunks)

        assert is_tile_specs(region.region_spec)  # for mypy
        tile_spec: TileSpecs = region.region_spec
        crop_coords = tile_spec["crop_coords"]
        crop_size = tile_spec["crop_size"]
        stitch_coords = tile_spec["stitch_coords"]

        # compute sample slice
        sample_idx = tile_spec["sample_idx"]

        # TODO there is duplicated code in stitch_prediction
        # the leading ellipsis stands for every non-spatial leading dimension
        crop_slices: tuple[builtins.ellipsis | slice | int, ...] = (
            ...,
            *[
                slice(start, start + length)
                for start, length in zip(crop_coords, crop_size, strict=True)
            ],
        )
        stitch_slices: tuple[builtins.ellipsis | slice | int, ...] = (
            ...,
            *[
                slice(start, start + length)
                for start, length in zip(stitch_coords, crop_size, strict=True)
            ],
        )

        if self.current_array is None:
            raise RuntimeError("Zarr array not initialized.")

        # region.data has shape C(Z)YX, broadcast can fail with singleton dims
        crop = region.data[crop_slices]

        if region.data.shape[0] == 1 and "C" not in region.axes:
            # singleton C dim, need to remove it before writing
            # unless it was present in the original axes
            crop = crop[0]

        if "S" in region.axes:
            if "C" in region.axes:
                # keep the ellipsis so that it covers the channel dimension
                stitch_slices = (sample_idx, *stitch_slices)
            else:
                # no channel dimension in the array, drop the ellipsis
                stitch_slices = (sample_idx, *stitch_slices[1:])

        self.current_array[stitch_slices] = crop

    def write_batch(
        self,
        dirpath: Path,
        predictions: list[ImageRegionData],
    ) -> None:
        """
        Write all tiles to a Zarr file.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        predictions : list[ImageRegionData]
            Decollated predictions.
        """
        for region in predictions:
            self.write_tile(dirpath, region)

__init__() #

Constructor.

Source code in src/careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py
def __init__(self) -> None:
    """Initialise the writer without any open zarr store, group or array."""
    # All three handles start empty; `write_tile` fills them in through the
    # `_create_*` helpers the first time they are needed.
    self.current_store: zarr.Group | None = None
    self.current_group: zarr.Group | None = None
    self.current_array: zarr.Array | None = None

write_batch(dirpath, predictions) #

Write all tiles to a Zarr file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dirpath` | `Path` | Path to directory to save predictions to. | required |
| `predictions` | `list[ImageRegionData]` | Decollated predictions. | required |

Source code in src/careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py
def write_batch(
    self,
    dirpath: Path,
    predictions: list[ImageRegionData],
) -> None:
    """
    Write every tile of a batch to Zarr, one tile at a time.

    Each prediction is passed to `write_tile` in the order it appears in
    `predictions`.

    Parameters
    ----------
    dirpath : Path
        Path to directory to save predictions to.
    predictions : list[ImageRegionData]
        Decollated predictions.
    """
    for tile in predictions:
        self.write_tile(dirpath, tile)

write_tile(dirpath, region) #

Write cropped tile to zarr array.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dirpath` | `Path` | Path to directory to save predictions to. | required |
| `region` | `ImageRegionData` | Image region data containing tile information. | required |

Source code in src/careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py
def write_tile(self, dirpath: Path, region: ImageRegionData) -> None:
    """Write cropped tile to zarr array.

    The output store, group and array are derived from the source URI of the
    region and are created or opened only when they differ from the ones used for
    the previous tile. Switching store invalidates the cached group and array,
    and switching group invalidates the cached array, so that a handle into a
    previous store or group is never reused because of a matching name.

    Parameters
    ----------
    dirpath : Path
        Path to directory to save predictions to.
    region : ImageRegionData
        Image region data containing tile information.

    Raises
    ------
    NotImplementedError
        If the source of the region is not a valid zarr URI.
    RuntimeError
        If the zarr array could not be initialized.
    """
    if not is_valid_uri(region.source):
        raise NotImplementedError(
            f"Invalid zarr URI: {region.source}. Currently, only predicting from "
            f"Zarr files is supported when writing Zarr tiles."
        )

    store_path, parent_path, array_name = decipher_zarr_uri(region.source)
    output_store_path = _add_output_key(dirpath, store_path)

    # Remember which path the current store was opened from and compare against
    # that. The attribute is created lazily here, hence the `getattr` default.
    target_store = str(output_store_path)
    if (
        self.current_store is None
        or getattr(self, "_current_store_path", None) != target_store
    ):
        self._create_zarr(output_store_path)
        self._current_store_path = target_store

        # handles into the previous store are stale
        self.current_group = None
        self.current_array = None

    if self.current_group is None or self.current_group.name != parent_path:
        self._create_group(parent_path)

        # an array cached from the previous group is stale
        self.current_array = None

    if self.current_array is None or self.current_array.basename != array_name:
        # data_shape, chunks and shards are in SC(Z)YX order since they are reshaped
        # in the zarr image stack loader
        # If the source is not a Zarr file, then chunks and shards will be `None`.
        shape = region.data_shape
        chunks: tuple[int, ...] | None = region.additional_metadata.get(
            "chunks", None
        )
        shards: tuple[int, ...] | None = region.additional_metadata.get(
            "shards", None
        )
        self._create_array(array_name, region.axes, shape, shards, chunks)

    assert is_tile_specs(region.region_spec)  # for mypy
    tile_spec: TileSpecs = region.region_spec
    crop_coords = tile_spec["crop_coords"]
    crop_size = tile_spec["crop_size"]
    stitch_coords = tile_spec["stitch_coords"]

    # compute sample slice
    sample_idx = tile_spec["sample_idx"]

    # TODO there is duplicated code in stitch_prediction
    # the leading ellipsis stands for every non-spatial leading dimension
    crop_slices: tuple[builtins.ellipsis | slice | int, ...] = (
        ...,
        *[
            slice(start, start + length)
            for start, length in zip(crop_coords, crop_size, strict=True)
        ],
    )
    stitch_slices: tuple[builtins.ellipsis | slice | int, ...] = (
        ...,
        *[
            slice(start, start + length)
            for start, length in zip(stitch_coords, crop_size, strict=True)
        ],
    )

    if self.current_array is None:
        raise RuntimeError("Zarr array not initialized.")

    # region.data has shape C(Z)YX, broadcast can fail with singleton dims
    crop = region.data[crop_slices]

    if region.data.shape[0] == 1 and "C" not in region.axes:
        # singleton C dim, need to remove it before writing
        # unless it was present in the original axes
        crop = crop[0]

    if "S" in region.axes:
        if "C" in region.axes:
            # keep the ellipsis so that it covers the channel dimension
            stitch_slices = (sample_idx, *stitch_slices)
        else:
            # no channel dimension in the array, drop the ellipsis
            stitch_slices = (sample_idx, *stitch_slices[1:])

    self.current_array[stitch_slices] = crop