class WriteTilesZarr:
    """Zarr tile writer strategy.

    This writer creates zarr files, groups and arrays as needed and writes tiles
    into the appropriate locations.

    The currently open store, group and array are cached between calls so that
    consecutive tiles of the same image reuse the same handles. Each cached
    handle is tagged with the key it was opened for (output store path, group
    path, array name). Whenever a key changes, every handle nested below it is
    dropped and re-resolved, so a tile is never written into an array that
    belongs to a previously processed store or group.
    """

    def __init__(self) -> None:
        """Constructor."""
        self.current_store: zarr.Group | None = None
        self.current_group: zarr.Group | None = None
        self.current_array: zarr.Array | None = None
        # Keys the cached handles above were opened for. They are tracked
        # explicitly rather than derived from zarr attributes (`store_path`,
        # `name`, `basename`), whose string formats are not guaranteed to match
        # the components returned by `decipher_zarr_uri`.
        self._store_key: str | None = None
        self._group_key: str | None = None
        self._array_key: str | None = None

    def _create_zarr(self, store: str | Path) -> None:
        """Open an existing zarr store, or create a new one if it does not exist.

        Parameters
        ----------
        store : str | Path
            Path to the zarr store.

        Raises
        ------
        RuntimeError
            If an existing store at `store` does not open as a zarr group.
        """
        if not Path(store).exists():
            self.current_store = zarr.create_group(store)
        else:
            open_store = zarr.open(store)
            if not isinstance(open_store, zarr.Group):
                raise RuntimeError(f"Zarr store at {store} is not a group.")
            self.current_store = open_store
        print(f"Store: {Path(store).absolute()}")

    def _create_group(self, group_path: str) -> None:
        """Open, or create, a group in the current zarr store.

        Parameters
        ----------
        group_path : str
            Path to the group within the zarr store.

        Raises
        ------
        RuntimeError
            If the zarr store has not been initialized, or if the node found at
            `group_path` is not a group.
        """
        if self.current_store is None:
            raise RuntimeError("Zarr store not initialized.")
        if group_path not in self.current_store:
            self.current_group = self.current_store.create_group(group_path)
        else:
            current_group = self.current_store[group_path]
            if not isinstance(current_group, zarr.Group):
                raise RuntimeError(f"Zarr group at {group_path} is not a group.")
            self.current_group = current_group

    def _create_array(
        self,
        array_name: str,
        axes: str,
        data_shape: Sequence[int],
        shards: tuple[int, ...] | None,
        chunks: tuple[int, ...] | None,
    ) -> None:
        """Open, or create, a float32 array in the current zarr group.

        Parameters
        ----------
        array_name : str
            Name of the array within the zarr group.
        axes : str
            Axes string in SC(Z)YX format with original data order.
        data_shape : Sequence[int]
            Shape of the array.
        shards : tuple[int, ...] or None
            Shard size for the array.
        chunks : tuple[int, ...] or None
            Chunk size for the array. If `None`, chunks are derived with
            `_auto_chunks`.

        Raises
        ------
        RuntimeError
            If the zarr group has not been initialized, or if the node found at
            `array_name` is not an array.
        ValueError
            If `chunks` does not match the length of the updated shape, or if
            `shards` does not match the length of the chunks.
        """
        if self.current_group is None:
            raise RuntimeError("Zarr group not initialized.")

        if array_name in self.current_group:
            current_array = self.current_group[array_name]
            if not isinstance(current_array, zarr.Array):
                raise RuntimeError(f"Zarr array at {array_name} is not an array.")
            self.current_array = current_array
            return

        # get shape without non-existing axes (S or C)
        updated_shape = _update_data_shape(axes, data_shape)
        if chunks is not None and len(updated_shape) != len(chunks):
            raise ValueError(
                f"Shape {updated_shape} and chunks {chunks} have different lengths."
            )
        if chunks is None:
            chunks = _auto_chunks(axes, data_shape)
            # TODO if we auto_chunks, we probably want to auto shards as well
            # there is shards="auto" in zarr, where array.target_shard_size_bytes
            # needs to be used (see zarr-python docs)
        if shards is not None and len(chunks) != len(shards):
            raise ValueError(
                f"Chunks {chunks} and shards {shards} have different lengths."
            )
        self.current_array = self.current_group.create_array(
            name=array_name,
            shape=updated_shape,
            shards=shards,
            chunks=chunks,
            dtype=float32,
        )

    def _open_target(
        self,
        output_store_path: str | Path,
        parent_path: str,
        array_name: str,
        region: ImageRegionData,
    ) -> None:
        """Make the store, group and array that `region` is written to current.

        Cached handles are reused only if they were opened for the same key.
        Opening a new store discards the cached group and array, and opening a
        new group discards the cached array.

        Parameters
        ----------
        output_store_path : str | Path
            Path of the output zarr store.
        parent_path : str
            Path of the group within the output store.
        array_name : str
            Name of the array within the group.
        region : ImageRegionData
            Region whose shape, axes and chunk/shard metadata are used if the
            array has to be created.
        """
        store_key = str(output_store_path)
        if self.current_store is None or self._store_key != store_key:
            self._create_zarr(output_store_path)
            self._store_key = store_key
            # Group and array handles belong to the previous store.
            self.current_group = None
            self._group_key = None
            self.current_array = None
            self._array_key = None

        if self.current_group is None or self._group_key != parent_path:
            self._create_group(parent_path)
            self._group_key = parent_path
            # The array handle belongs to the previous group.
            self.current_array = None
            self._array_key = None

        if self.current_array is None or self._array_key != array_name:
            # data_shape, chunks and shards are in SC(Z)YX order since they are
            # reshaped in the zarr image stack loader.
            # If the source is not a Zarr file, chunks and shards will be `None`.
            chunks: tuple[int, ...] | None = region.additional_metadata.get(
                "chunks", None
            )
            shards: tuple[int, ...] | None = region.additional_metadata.get(
                "shards", None
            )
            self._create_array(
                array_name, region.axes, region.data_shape, shards, chunks
            )
            self._array_key = array_name

    def write_tile(self, dirpath: Path, region: ImageRegionData) -> None:
        """Write the cropped tile of `region` into the output zarr array.

        The output store, group and array are derived from the zarr URI in
        `region.source`, and are opened or created on demand.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        region : ImageRegionData
            Image region data containing tile information.

        Raises
        ------
        NotImplementedError
            If `region.source` is not a valid zarr URI.
        RuntimeError
            If the zarr array could not be initialized.
        """
        if not is_valid_uri(region.source):
            raise NotImplementedError(
                f"Invalid zarr URI: {region.source}. Currently, only predicting from "
                f"Zarr files is supported when writing Zarr tiles."
            )
        store_path, parent_path, array_name = decipher_zarr_uri(region.source)
        output_store_path = _add_output_key(dirpath, store_path)

        self._open_target(output_store_path, parent_path, array_name, region)

        assert is_tile_specs(region.region_spec)  # for mypy
        tile_spec: TileSpecs = region.region_spec
        crop_coords = tile_spec["crop_coords"]
        crop_size = tile_spec["crop_size"]
        stitch_coords = tile_spec["stitch_coords"]
        sample_idx = tile_spec["sample_idx"]

        # TODO there is duplicated code in stitch_prediction
        # The leading Ellipsis leaves any non-spatial leading dims untouched.
        crop_slices: tuple[builtins.ellipsis | slice | int, ...] = (
            ...,
            *[
                slice(start, start + length)
                for start, length in zip(crop_coords, crop_size, strict=True)
            ],
        )
        stitch_slices: tuple[builtins.ellipsis | slice | int, ...] = (
            ...,
            *[
                slice(start, start + length)
                for start, length in zip(stitch_coords, crop_size, strict=True)
            ],
        )

        if self.current_array is None:
            raise RuntimeError("Zarr array not initialized.")

        # region.data has shape C(Z)YX, broadcast can fail with singleton dims
        crop = region.data[crop_slices]
        if region.data.shape[0] == 1 and "C" not in region.axes:
            # singleton C dim, need to remove it before writing
            # unless it was present in the original axes
            crop = crop[0]
        if "S" in region.axes:
            if "C" in region.axes:
                # keep the Ellipsis so it spans the C axis after the sample index
                stitch_slices = (sample_idx, *stitch_slices[0:])
            else:
                # no C axis in the output: drop the Ellipsis
                stitch_slices = (sample_idx, *stitch_slices[1:])
        self.current_array[stitch_slices] = crop

    def write_batch(
        self,
        dirpath: Path,
        predictions: list[ImageRegionData],
    ) -> None:
        """
        Write all tiles to a Zarr file.

        Tiles are written in order with `write_tile`, so the store, group and
        array handles are reused between consecutive tiles of the same image.

        Parameters
        ----------
        dirpath : Path
            Path to directory to save predictions to.
        predictions : list[ImageRegionData]
            Decollated predictions.
        """
        for region in predictions:
            self.write_tile(dirpath, region)