Skip to content

write_tiles_zarr_strategy

Tile Zarr writing strategy.

WriteTilesZarr

Zarr tile writer strategy.

This writer creates zarr files, groups and arrays as needed and writes tiles into the appropriate locations.

Source code in src/careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py
class WriteTilesZarr:
    """Zarr tile writer strategy.

    This writer creates zarr files, groups and arrays as needed and writes tiles
    into the appropriate locations.

    The store, group and array that were last written to are cached so that
    consecutive tiles of the same image do not re-open the zarr hierarchy.
    When the output store changes, the cached group and array are discarded.
    When the group changes, the cached array is discarded. This ensures a tile
    is never written into a same-named group or array that belongs to a
    previously visited store or group.
    """

    def __init__(self) -> None:
        """Constructor."""
        self.current_store: zarr.Group | None = None
        self.current_group: zarr.Group | None = None
        self.current_array: zarr.Array | None = None

        # Keys identifying the cached store/group, exactly as they were passed
        # to `_create_zarr` / `_create_group`. Comparing these keys, rather
        # than zarr-reported paths, keeps the cache checks independent of how
        # zarr formats its own store and group names.
        self._store_path: str | Path | None = None
        self._group_path: str | None = None

    def _create_zarr(self, store: str | Path) -> None:
        """Create a new zarr storage, or open it if it already exists.

        Parameters
        ----------
        store : str | Path
            Path to the zarr store.

        Raises
        ------
        RuntimeError
            If a zarr store already exists at `store` but its root is not a group.
        """
        if not Path(store).exists():
            self.current_store = zarr.create_group(store)
        else:
            open_store = zarr.open(store)

            if not isinstance(open_store, zarr.Group):
                raise RuntimeError(f"Zarr store at {store} is not a group.")

            self.current_store = open_store

        print(f"Store: {Path(store).absolute()}")

    def _create_group(self, group_path: str) -> None:
        """Create a new group in an existing zarr storage, or open it if present.

        Parameters
        ----------
        group_path : str
            Path to the group within the zarr store.

        Raises
        ------
        RuntimeError
            If the zarr store has not been initialized, or if `group_path`
            already exists in the store but is not a group.
        """
        if self.current_store is None:
            raise RuntimeError("Zarr store not initialized.")

        if group_path not in self.current_store:
            self.current_group = self.current_store.create_group(group_path)
        else:
            current_group = self.current_store[group_path]
            if not isinstance(current_group, zarr.Group):
                raise RuntimeError(f"Zarr group at {group_path} is not a group.")

            self.current_group = current_group

    def _create_array(
        self,
        array_name: str,
        shape: Sequence[int],
        chunks: Sequence[int],
    ) -> None:
        """Create a new array in an existing zarr group, or open it if present.

        A newly created array has all singleton dimensions of `shape` removed
        and uses dtype `float32`.

        Parameters
        ----------
        array_name : str
            Name of the array within the zarr group.
        shape : Sequence[int]
            Shape of the array.
        chunks : Sequence[int]
            Chunk size for the array. Must have as many entries as `shape`
            has non-singleton dimensions.

        Raises
        ------
        RuntimeError
            If the zarr group has not been initialized, or if `array_name`
            already exists in the group but is not an array.
        ValueError
            If a new array is needed and `chunks` is `(1,)`, or if its length
            does not match the number of non-singleton dimensions of `shape`.
        """
        if self.current_group is None:
            raise RuntimeError("Zarr group not initialized.")

        if array_name not in self.current_group:
            # Singleton dimensions are dropped, mirroring the `.squeeze()`
            # applied to the tile data in `write_tile`.
            shape = [i for i in shape if i != 1]

            # `tuple(...)` so that list-typed chunks such as `[1]` are caught too.
            if tuple(chunks) == (1,):  # guard against the ImageRegionData default
                raise ValueError("Chunks cannot be (1,).")

            if len(shape) != len(chunks):
                raise ValueError(
                    f"Shape {shape} and chunks {chunks} have different lengths."
                )

            self.current_array = self.current_group.create_array(
                name=array_name, shape=shape, chunks=tuple(chunks), dtype=float32
            )
        else:
            current_array = self.current_group[array_name]
            if not isinstance(current_array, zarr.Array):
                raise RuntimeError(f"Zarr array at {array_name} is not an array.")
            self.current_array = current_array

    def write_tile(self, dirpath: Path, region: ImageRegionData) -> None:
        """Write cropped tile to zarr array.

        The target store, group and array are derived from `region.source`.
        Each level is (re)opened only when it differs from the one used for
        the previous tile.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        region : ImageRegionData
            Image region data containing tile information.

        Raises
        ------
        RuntimeError
            If the zarr store, group or array could not be initialized.
        ValueError
            Propagated from `_create_array` when the region's shape and
            chunks are inconsistent.
        """
        store_path, parent_path, array_name = decipher_zarr_uri(region.source)
        output_store_path = _add_output_key(dirpath, store_path)

        # A new store invalidates the cached group and array: they belong to
        # the previous store and must not be reused, even if a same-named
        # group or array exists in the new one.
        if self.current_store is None or self._store_path != output_store_path:
            self._create_zarr(output_store_path)
            self._store_path = output_store_path
            self.current_group = None
            self._group_path = None
            self.current_array = None

        # Likewise, one level down: a new group invalidates the cached array.
        if self.current_group is None or self._group_path != parent_path:
            self._create_group(parent_path)
            self._group_path = parent_path
            self.current_array = None

        if self.current_array is None or self.current_array.basename != array_name:
            self._create_array(array_name, region.data_shape, region.chunks)

        if self.current_array is None:
            raise RuntimeError("Zarr array not initialized.")

        tile_spec: TileSpecs = region.region_spec  # type: ignore[assignment]
        crop_coords = tile_spec["crop_coords"]
        crop_size = tile_spec["crop_size"]
        stitch_coords = tile_spec["stitch_coords"]

        # The leading Ellipsis leaves any leading (non-spatial) axes untouched;
        # only the trailing axes are cropped and placed.
        crop_slices: tuple[builtins.ellipsis | slice, ...] = (
            ...,
            *[
                slice(start, start + length)
                for start, length in zip(crop_coords, crop_size, strict=True)
            ],
        )
        stitch_slices: tuple[builtins.ellipsis | slice, ...] = (
            ...,
            *[
                slice(start, start + length)
                for start, length in zip(stitch_coords, crop_size, strict=True)
            ],
        )

        self.current_array[stitch_slices] = region.data.squeeze()[crop_slices]

    def write_batch(
        self,
        dirpath: Path,
        predictions: list[ImageRegionData],
    ) -> None:
        """
        Write all tiles to a Zarr file.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        predictions : list[ImageRegionData]
            Decollated predictions.
        """
        for region in predictions:
            self.write_tile(dirpath, region)

__init__()

Constructor.

Source code in src/careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py
def __init__(self) -> None:
    """Start the writer with no zarr store, group or array open.

    The three handles are filled in on demand by the ``_create_zarr``,
    ``_create_group`` and ``_create_array`` helpers.
    """
    # Handles on the zarr hierarchy currently being written to.
    self.current_store: zarr.Group | None
    self.current_group: zarr.Group | None
    self.current_array: zarr.Array | None
    self.current_store = self.current_group = self.current_array = None

write_batch(dirpath, predictions)

Write all tiles to a Zarr file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dirpath` | `Path` | Path to directory to save predictions to. | required |
| `predictions` | `list[ImageRegionData]` | Decollated predictions. | required |
Source code in src/careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py
def write_batch(
    self,
    dirpath: Path,
    predictions: list[ImageRegionData],
) -> None:
    """
    Write every tile of a batch into its Zarr array.

    Tiles are handed to ``write_tile`` one at a time, in the order in which
    they appear in ``predictions``.

    Parameters
    ----------
    dirpath : Path
        Path to directory to save predictions to.
    predictions : list[ImageRegionData]
        Decollated predictions.
    """
    for tile_region in predictions:
        # `write_tile` works out the target store/group/array from the tile.
        self.write_tile(dirpath, tile_region)

write_tile(dirpath, region)

Write cropped tile to zarr array.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dirpath` | `Path` | Path to directory to save predictions to. | required |
| `region` | `ImageRegionData` | Image region data containing tile information. | required |
Source code in src/careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py
def write_tile(self, dirpath: Path, region: ImageRegionData) -> None:
    """Write cropped tile to zarr array.

    The target store, group and array are derived from `region.source`.
    Each level is (re)opened only when it differs from the one used for the
    previous tile. A change of store discards the cached group and array,
    and a change of group discards the cached array, so a tile is never
    written into a same-named group or array from a previous store or group.

    Parameters
    ----------
    dirpath : Path
        Path to directory to save predictions to.
    region : ImageRegionData
        Image region data containing tile information.

    Raises
    ------
    RuntimeError
        If the zarr store, group or array could not be initialized.
    """
    store_path, parent_path, array_name = decipher_zarr_uri(region.source)
    output_store_path = _add_output_key(dirpath, store_path)

    # The store/group keys last opened are tracked on the instance under
    # private names. `getattr` is used because the constructor does not
    # necessarily initialise them.
    if (
        self.current_store is None
        or getattr(self, "_store_path", None) != output_store_path
    ):
        self._create_zarr(output_store_path)
        self._store_path = output_store_path
        # Group and array handles from the previous store are no longer valid.
        self.current_group = None
        self._group_path = None
        self.current_array = None

    if self.current_group is None or getattr(self, "_group_path", None) != parent_path:
        self._create_group(parent_path)
        self._group_path = parent_path
        # An array handle from the previous group is no longer valid.
        self.current_array = None

    if self.current_array is None or self.current_array.basename != array_name:
        self._create_array(array_name, region.data_shape, region.chunks)

    if self.current_array is None:
        raise RuntimeError("Zarr array not initialized.")

    tile_spec: TileSpecs = region.region_spec  # type: ignore[assignment]
    crop_coords = tile_spec["crop_coords"]
    crop_size = tile_spec["crop_size"]
    stitch_coords = tile_spec["stitch_coords"]

    # The leading Ellipsis leaves any leading (non-spatial) axes untouched;
    # only the trailing axes are cropped and placed.
    crop_slices: tuple[builtins.ellipsis | slice, ...] = (
        ...,
        *[
            slice(start, start + length)
            for start, length in zip(crop_coords, crop_size, strict=True)
        ],
    )
    stitch_slices: tuple[builtins.ellipsis | slice, ...] = (
        ...,
        *[
            slice(start, start + length)
            for start, length in zip(stitch_coords, crop_size, strict=True)
        ],
    )

    self.current_array[stitch_slices] = region.data.squeeze()[crop_slices]