Implementing an image stack
This is an advanced tutorial for creating a custom image stack. A custom image stack is useful for extending CAREamics to train on chunked or memory-mapped formats that are not natively supported, when the data cannot all fit into memory. See the Image Stack & Loader Tutorial for a full example of training on custom memory-mapped data.
If the data can fit into memory, see the Custom Read Function Tutorial, which describes a simpler alternative mechanism for training on custom data that can be loaded into memory.
The ImageStack protocol provides an interface so that CAREamics can interact with data stored in different formats.
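As a rough illustration of what implementing a protocol means, the following is a hypothetical, simplified stand-in written with Python's typing.Protocol. It is not CAREamics' actual definition. The name ImageStackLike is made up, and the NumPy types are loosened to Any so the sketch runs without NumPy. The sketch also does not assume that the real protocol is runtime-checkable. It shows that a class satisfies the interface simply by having the right attributes and method, without inheriting from anything.

```python
from __future__ import annotations

from collections.abc import Sequence
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ImageStackLike(Protocol):
    """Hypothetical, simplified stand-in for the ImageStack protocol.

    Types are loosened to Any so this sketch does not need NumPy.
    """

    source: str
    data_shape: Sequence[int]
    data_dtype: Any
    original_data_shape: Sequence[int]
    original_axes: str

    def extract_patch(
        self,
        sample_idx: int,
        channels: Sequence[int] | None,
        coords: Sequence[int],
        patch_size: Sequence[int],
    ) -> Any: ...


class DummyStack:
    """Illustrative class that never mentions ImageStackLike."""

    def __init__(self) -> None:
        # Placeholder values, for illustration only.
        self.source = "dummy"
        self.original_axes = "YX"
        self.original_data_shape = (4, 4)
        self.data_shape = (1, 1, 4, 4)
        self.data_dtype = "float32"

    def extract_patch(self, sample_idx, channels, coords, patch_size):
        return None  # a real stack would return a C(Z)YX array


print(isinstance(DummyStack(), ImageStackLike))  # True
```

Because runtime_checkable only checks that the members exist, isinstance does not validate their types or shapes. A static type checker such as mypy performs the stricter structural check.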
Required Attributes
Any class that implements the ImageStack protocol is required to provide the following attributes. They can be implemented as simple instance attributes or as properties.
source: str
The source is a string that is passed through the prediction pipeline to identify the input that a prediction was produced from. Ideally, it should be unique for each image stack instance; a path to the data is usually a natural choice.
data_shape: Sequence[int]
This is the shape of the data after it has been transformed to the SC(Z)YX axes order. The AxesTransform class provides a convenient way to calculate it.
```python
from careamics.utils.reshape_array import AxesTransform

original_axes = "YXC"
original_data_shape = (512, 620, 2)
data_shape = AxesTransform(original_axes, original_data_shape).transformed_shape
print(data_shape)
```
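To make the SC(Z)YX convention concrete, here is a hypothetical, simplified reimplementation of the shape calculation in plain Python. It is not the AxesTransform implementation. The name sketch_transformed_shape is illustrative, and the sketch assumes that missing S and C axes become singleton dimensions. It also rejects axes such as T, which the real class may handle differently.

```python
from __future__ import annotations

from collections.abc import Sequence


def sketch_transformed_shape(axes: str, shape: Sequence[int]) -> tuple[int, ...]:
    """Hypothetical, simplified version of what AxesTransform computes.

    Assumes the axes only contain S, C, Z, Y and X. Missing S and C axes
    are inserted with size 1; Z is kept only if it is present.
    """
    if len(axes) != len(shape):
        raise ValueError("axes and shape must have the same length")
    if len(set(axes)) != len(axes):
        raise ValueError("axes must not contain duplicates")
    unsupported = set(axes) - set("SCZYX")
    if unsupported:
        raise ValueError(f"unsupported axes: {sorted(unsupported)}")
    sizes = dict(zip(axes, shape))
    # Target order: S, C, optional Z, then Y and X.
    target = "SCZYX" if "Z" in sizes else "SCYX"
    return tuple(sizes.get(ax, 1) for ax in target)


print(sketch_transformed_shape("YXC", (512, 620, 2)))  # (1, 2, 512, 620)
```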
data_dtype: numpy.typing.DTypeLike
This is the data type of the data, expressed as its equivalent NumPy type.
original_data_shape: Sequence[int]
This is the original shape of the data, before any transformations.
original_axes: str
This is the original axes order of the data, before the transformation. The image stack should accept an axes argument when it is initialized and store that value as this attribute.
Required Methods
extract_patch
The full signature can be found in the API reference for extract_patch.
The extract_patch method must return the patch specified by its input parameters, transformed to have C(Z)YX axes.
If the patch extends beyond the image bounds, it should be padded with zeros. This supports predicting with a tile size that is larger than the image, a feature that is not used during training.
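The zero-padding rule can be illustrated with a hypothetical, NumPy-free helper. For each spatial axis, it works out which part of a requested patch lies inside the image and how many zeros are needed on each side. sketch_patch_bounds is not a CAREamics function. It only sketches the bookkeeping behind padding, and it assumes that coords and patch_size are given in (Z)YX order.

```python
from __future__ import annotations

from collections.abc import Sequence


def sketch_patch_bounds(
    coords: Sequence[int],
    patch_size: Sequence[int],
    spatial_shape: Sequence[int],
) -> tuple[list[slice], list[tuple[int, int]]]:
    """Hypothetical helper: clip a patch to the image and compute zero padding.

    Returns one slice per spatial axis covering the part of the requested
    patch that lies inside the image. It also returns one (before, after)
    pair per axis giving how many zeros restore the requested patch size.
    """
    slices: list[slice] = []
    pad_widths: list[tuple[int, int]] = []
    for start, size, dim in zip(coords, patch_size, spatial_shape):
        stop = start + size
        # Clamp the requested interval [start, stop) to the image [0, dim).
        lo = min(max(start, 0), dim)
        hi = max(min(stop, dim), lo)
        valid = hi - lo
        # Zeros needed before the image start, then whatever is still missing.
        before = min(max(-start, 0), size)
        after = size - before - valid
        slices.append(slice(lo, hi))
        pad_widths.append((before, after))
    return slices, pad_widths


# An 8x8 tile requested from a 5x6 image: only 5x6 pixels exist.
print(sketch_patch_bounds((0, 0), (8, 8), (5, 6)))
# ([slice(0, 5, None), slice(0, 6, None)], [(0, 3), (0, 2)])
```

With NumPy, the (before, after) pairs could be passed to numpy.pad, with a leading (0, 0) pair for the channel axis. Its default constant mode pads with zeros.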
Some useful utility functions are:
- get_patch_slices: returns NumPy-style slice objects in the original axis order.
- reshape_patch: transforms the patch from its original axes order to C(Z)YX.
- pad_patch: pads patches that are queried from outside the image bounds.
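The kind of index that get_patch_slices produces can be pictured with the hypothetical helper below, which builds a NumPy-style index tuple in the original axis order. It is not the CAREamics implementation. The name sketch_patch_slices is illustrative, and the helper assumes that coords and patch_size are ordered (Z)YX. It also indexes the sample axis with a plain integer and does not clip the slices to the image bounds.

```python
from __future__ import annotations

from collections.abc import Sequence


def sketch_patch_slices(
    axes: str,
    sample_idx: int,
    channels: Sequence[int] | None,
    coords: Sequence[int],
    patch_size: Sequence[int],
) -> tuple[int | slice | list[int], ...]:
    """Hypothetical sketch: build an index tuple in the original axis order.

    coords and patch_size are assumed to follow the spatial order (Z)YX.
    The slices are not clipped to the image bounds.
    """
    spatial_axes = [ax for ax in "ZYX" if ax in axes]
    if len(spatial_axes) != len(coords) or len(coords) != len(patch_size):
        raise ValueError("coords and patch_size must match the spatial axes")
    spatial = {
        ax: slice(start, start + size)
        for ax, start, size in zip(spatial_axes, coords, patch_size)
    }
    index: list[int | slice | list[int]] = []
    for ax in axes:
        if ax == "S":
            index.append(sample_idx)  # an integer index drops the sample axis
        elif ax == "C":
            # None selects every channel; a list selects specific channels.
            index.append(slice(None) if channels is None else list(channels))
        elif ax in spatial:
            index.append(spatial[ax])
        else:
            raise ValueError(f"unsupported axis: {ax}")
    return tuple(index)


print(sketch_patch_slices("YXC", 0, None, (10, 20), (64, 64)))
# (slice(10, 74, None), slice(20, 84, None), slice(None, None, None))
```

A tuple like this can index a NumPy array or an h5py dataset directly. Note that h5py expects lists of indices, such as the channel list, to be in increasing order.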
Example implementation
All the natively supported image stack implementations can be found in the image_stack package.
This is an additional example for HDF5 data.
```python
from collections.abc import Sequence

import h5py
from numpy.typing import NDArray, DTypeLike

from careamics.utils.reshape_array import reshape_patch, get_patch_slices, AxesTransform
from careamics.dataset.image_stack.image_utils import pad_patch


class HDF5ImageStack:
    def __init__(self, image_data: h5py.Dataset, axes: str):
        self._image_data = image_data
        self.original_axes = axes  # (1)!
        self.original_data_shape = image_data.shape  # (2)!
        self.data_shape = AxesTransform(  # (3)!
            axes, self.original_data_shape
        ).transformed_shape

    @property
    def data_dtype(self) -> DTypeLike:
        return self._image_data.dtype

    @property
    def source(self) -> str:  # (4)!
        return "#".join([self._image_data.file.filename, str(self._image_data.name)])

    def extract_patch(
        self,
        sample_idx: int,
        channels: Sequence[int] | None,
        coords: Sequence[int],
        patch_size: Sequence[int],
    ) -> NDArray:
        """Extract a patch for a given sample and channels within the image stack.

        Parameters
        ----------
        sample_idx : int
            Sample index.
        channels : sequence of int or None
            Channel indices to extract. If `None`, all channels will be extracted.
        coords : sequence of int
            Spatial coordinates of the top-left corner of the patch.
        patch_size : sequence of int
            Size of the patch in each spatial dimension.

        Returns
        -------
        numpy.ndarray
            A patch of the image data from a particular sample with dimensions C(Z)YX.
        """
        patch_slice = get_patch_slices(
            self.original_axes,
            self.original_data_shape,
            sample_idx,
            channels,
            coords,
            patch_size,
        )
        patch_data = self._image_data[patch_slice]  # (5)!
        patch_data = reshape_patch(patch_data, self.original_axes)
        patch = pad_patch(coords, patch_size, self.data_shape, patch_data)
        return patch
```
1. original_axes is implemented as an instance attribute.
2. original_data_shape is implemented as an instance attribute.
3. data_shape is implemented as an instance attribute.
4. We decided to make the source the file path followed by a # followed by the internal dataset path, e.g. /data/hdf5_dataset.h5#/image_0.
5. HDF5 data can be sliced just like NumPy arrays.