Skip to content

index_manager

GridIndexManager dataclass #

Source code in src/careamics/lvae_training/dataset/utils/index_manager.py
@dataclass
class GridIndexManager:
    """
    Converts between flat dataset indices, per-dimension grid indices and
    grid / patch start coordinates for data that is tiled into a regular grid.

    The flat dataset index is row-major over the per-dimension grid indices:
    the last dimension varies fastest (see `grid_count`).
    """

    data_shape: tuple  # extent of the data along every dimension
    grid_shape: tuple  # extent of one grid cell along every dimension
    patch_shape: tuple  # extent of one patch along every dimension (>= grid_shape)
    tiling_mode: TilingMode  # selects how the data boundary is handled

    # The patch is centered on the grid cell. The grid size is not used in
    # training (random patches are drawn every time); it is only used during
    # validation / test, where it controls the overlap between patches.
    # At the borders the data is simply cropped so that it is perfectly divisible.

    def __post_init__(self):
        """
        Validates that all shapes have the same number of dimensions and that
        the patch exceeds the grid by an even, non-negative amount everywhere.

        Raises a ValueError when the patch is smaller than the grid or the
        difference is odd in some dimension.
        """
        assert len(self.data_shape) == len(
            self.grid_shape
        ), f"Data shape:{self.data_shape} and grid size:{self.grid_shape} must have the same dimension"
        assert len(self.data_shape) == len(
            self.patch_shape
        ), f"Data shape:{self.data_shape} and patch shape:{self.patch_shape} must have the same dimension"
        innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
        for dim, pad in enumerate(innerpad):
            if pad < 0:
                raise ValueError(
                    f"Patch shape:{self.patch_shape} must be greater than or equal to grid shape:{self.grid_shape} in dimension {dim}"
                )
            # an odd difference could not be split equally on both sides of the grid
            if pad % 2 != 0:
                raise ValueError(
                    f"Patch shape:{self.patch_shape} must have even padding in dimension {dim}"
                )

    def patch_offset(self):
        """
        Returns, per dimension, half of the difference between patch and grid
        size, i.e. the distance from the patch start to the grid start.
        """
        return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2

    def get_individual_dim_grid_count(self, dim: int):
        """
        Returns the number of grids along the specified dimension, ignoring all other dimensions.
        """
        assert dim < len(
            self.data_shape
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            # untiled dimension: every coordinate is its own grid
            return self.data_shape[dim]
        elif self.tiling_mode == TilingMode.PadBoundary:
            return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
            return int(
                np.ceil((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
            )
        else:
            # every other mode (TrimBoundary in particular) keeps only grids
            # whose patch fits entirely inside the data
            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
            return int(
                np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
            )

    def total_grid_count(self):
        """
        Returns the total number of grids in the dataset.
        """
        return self.grid_count(0) * self.get_individual_dim_grid_count(0)

    def grid_count(self, dim: int):
        """
        Returns the number of grids covered by one index step along the
        specified dimension, i.e. the product of the grid counts of all later
        dimensions (1 for the last dimension).
        """
        assert dim < len(
            self.data_shape
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"
        if dim == len(self.data_shape) - 1:
            return 1

        return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)

    def get_grid_index(self, dim: int, coordinate: int):
        """
        Returns the index of the grid that the coordinate maps to in the specified dimension.

        The result is a Python int for every tiling mode. Raises a ValueError
        for an unsupported tiling mode.
        """
        assert dim < len(
            self.data_shape
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"
        assert (
            coordinate < self.data_shape[dim]
        ), f"Coordinate {coordinate} is out of bounds for data shape {self.data_shape}"

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return coordinate
        elif self.tiling_mode == TilingMode.PadBoundary:
            # np.floor returns a float; an index has to be an integer
            return int(np.floor(coordinate / self.grid_shape[dim]))
        elif self.tiling_mode == TilingMode.TrimBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            # can be <0 if coordinate is in [0,grid_shape[dim]]
            return max(
                0, int(np.floor((coordinate - excess_size) / self.grid_shape[dim]))
            )
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            if coordinate + self.grid_shape[dim] + excess_size == self.data_shape[dim]:
                # start coordinate of the shifted last grid
                return self.get_individual_dim_grid_count(dim) - 1
            else:
                # can be <0 if coordinate is in [0,grid_shape[dim]]
                return max(
                    0,
                    int(np.floor((coordinate - excess_size) / self.grid_shape[dim])),
                )

        else:
            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def dataset_idx_from_grid_idx(self, grid_idx: tuple):
        """
        Returns the flat dataset index for a tuple of per-dimension grid indices.
        """
        assert len(grid_idx) == len(
            self.data_shape
        ), f"Dimension indices {grid_idx} must have the same dimension as data shape {self.data_shape}"
        index = 0
        for dim, dim_index in enumerate(grid_idx):
            index += dim_index * self.grid_count(dim)
        return index

    def get_patch_location_from_dataset_idx(self, dataset_idx: int):
        """
        Returns the start location of the patch belonging to the grid with the
        given dataset index (grid start minus the patch offset).
        """
        grid_location = self.get_location_from_dataset_idx(dataset_idx)
        offset = self.patch_offset()
        return tuple(np.array(grid_location) - np.array(offset))

    def get_dataset_idx_from_grid_location(self, location: tuple):
        """
        Returns the flat dataset index of the grid that the given location maps to.
        """
        assert len(location) == len(
            self.data_shape
        ), f"Location {location} must have the same dimension as data shape {self.data_shape}"
        grid_idx = [
            self.get_grid_index(dim, location[dim]) for dim in range(len(location))
        ]
        return self.dataset_idx_from_grid_idx(tuple(grid_idx))

    def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
        """
        Returns the grid-start coordinate of the grid in the specified dimension.

        Raises a ValueError for an unsupported tiling mode.
        """
        assert dim < len(
            self.data_shape
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"
        assert dim_index < self.get_individual_dim_grid_count(
            dim
        ), f"Dimension index {dim_index} is out of bounds for data shape {self.data_shape}"

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return dim_index
        elif self.tiling_mode == TilingMode.PadBoundary:
            return dim_index * self.grid_shape[dim]
        elif self.tiling_mode == TilingMode.TrimBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            return dim_index * self.grid_shape[dim] + excess_size
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            if dim_index < self.get_individual_dim_grid_count(dim) - 1:
                return dim_index * self.grid_shape[dim] + excess_size
            else:
                # on boundary. grid should be placed such that the patch covers the entire data.
                return self.data_shape[dim] - self.grid_shape[dim] - excess_size
        else:
            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def get_location_from_dataset_idx(self, dataset_idx: int):
        """
        Returns the start location of the grid in the dataset.
        """
        grid_idx = []
        for dim in range(len(self.data_shape)):
            # peel off the index of this dimension, keep the remainder for the later ones
            grid_idx.append(dataset_idx // self.grid_count(dim))
            dataset_idx = dataset_idx % self.grid_count(dim)
        location = [
            self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
            for dim in range(len(self.data_shape))
        ]
        return tuple(location)

    def on_boundary(self, dataset_idx: int, dim: int, only_end: bool = False):
        """
        Returns True if the grid is on the boundary in the specified dimension.

        With `only_end=True` only the last grid along the dimension counts as
        boundary; otherwise the first grid counts as well.
        """
        assert dim < len(
            self.data_shape
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"

        if dim > 0:
            # drop the contribution of all earlier dimensions
            dataset_idx = dataset_idx % self.grid_count(dim - 1)

        dim_index = dataset_idx // self.grid_count(dim)
        if only_end:
            return dim_index == self.get_individual_dim_grid_count(dim) - 1

        return (
            dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
        )

    def next_grid_along_dim(self, dataset_idx: int, dim: int):
        """
        Returns the dataset index one grid step further along the specified
        dimension, or None if that index is past the last grid of the dataset.
        """
        assert dim < len(
            self.data_shape
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"
        new_idx = dataset_idx + self.grid_count(dim)
        if new_idx >= self.total_grid_count():
            return None
        return new_idx

    def prev_grid_along_dim(self, dataset_idx: int, dim: int):
        """
        Returns the dataset index one grid step back along the specified
        dimension, or None if that index would be negative.
        """
        assert dim < len(
            self.data_shape
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"
        new_idx = dataset_idx - self.grid_count(dim)
        if new_idx < 0:
            return None
        # the computed index was previously dropped, so callers always got None
        return new_idx
dataset_idx_from_grid_idx(grid_idx) #

Returns the index of the grid in the dataset.

Source code in src/careamics/lvae_training/dataset/utils/index_manager.py
def dataset_idx_from_grid_idx(self, grid_idx: tuple):
    """
    Folds a tuple of per-dimension grid indices into one flat dataset index.

    Each index is weighted by `grid_count` of its dimension and the weighted
    values are summed.
    """
    n_dims = len(self.data_shape)
    assert len(grid_idx) == n_dims, f"Dimension indices {grid_idx} must have the same dimension as data shape {self.data_shape}"
    return sum(
        dim_index * self.grid_count(dim) for dim, dim_index in enumerate(grid_idx)
    )

get_grid_index(dim, coordinate) #

Returns the index of the grid in the specified dimension.

Source code in src/careamics/lvae_training/dataset/utils/index_manager.py
def get_grid_index(self, dim: int, coordinate: int):
    """
    Returns the index of the grid that the coordinate maps to in the specified dimension.

    The result is a Python int for every tiling mode. Raises a ValueError for
    an unsupported tiling mode.
    """
    assert dim < len(
        self.data_shape
    ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
    assert dim >= 0, "Dimension must be greater than or equal to 0"
    assert (
        coordinate < self.data_shape[dim]
    ), f"Coordinate {coordinate} is out of bounds for data shape {self.data_shape}"

    if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
        # untiled dimension: the coordinate is the grid index
        return coordinate
    elif self.tiling_mode == TilingMode.PadBoundary:  # self.trim_boundary is False:
        # np.floor returns a float; an index has to be an integer
        return int(np.floor(coordinate / self.grid_shape[dim]))
    elif self.tiling_mode == TilingMode.TrimBoundary:
        excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
        # can be <0 if coordinate is in [0,grid_shape[dim]]
        return max(0, int(np.floor((coordinate - excess_size) / self.grid_shape[dim])))
    elif self.tiling_mode == TilingMode.ShiftBoundary:
        excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
        if coordinate + self.grid_shape[dim] + excess_size == self.data_shape[dim]:
            # start coordinate of the shifted last grid
            return self.get_individual_dim_grid_count(dim) - 1
        else:
            # can be <0 if coordinate is in [0,grid_shape[dim]]
            return max(
                0, int(np.floor((coordinate - excess_size) / self.grid_shape[dim]))
            )

    else:
        raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

get_gridstart_location_from_dim_index(dim, dim_index) #

Returns the grid-start coordinate of the grid in the specified dimension.

Source code in src/careamics/lvae_training/dataset/utils/index_manager.py
def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
    """
    Maps a grid index along one dimension to the coordinate where that grid starts.

    Raises a ValueError when the tiling mode is not one of PadBoundary,
    TrimBoundary or ShiftBoundary (unless the dimension is untiled).
    """
    n_dims = len(self.data_shape)
    assert dim < n_dims, f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
    assert dim >= 0, "Dimension must be greater than or equal to 0"
    n_grids = self.get_individual_dim_grid_count(dim)
    assert dim_index < n_grids, f"Dimension index {dim_index} is out of bounds for data shape {self.data_shape}"

    grid_size = self.grid_shape[dim]
    patch_size = self.patch_shape[dim]

    # untiled dimension: grid index and coordinate coincide
    if grid_size == 1 and patch_size == 1:
        return dim_index

    mode = self.tiling_mode
    if mode == TilingMode.PadBoundary:
        return dim_index * grid_size

    if mode == TilingMode.TrimBoundary:
        half_pad = (patch_size - grid_size) // 2
        return dim_index * grid_size + half_pad

    if mode == TilingMode.ShiftBoundary:
        half_pad = (patch_size - grid_size) // 2
        is_last = not (dim_index < self.get_individual_dim_grid_count(dim) - 1)
        if is_last:
            # last grid is shifted so that its patch ends exactly at the data end
            return self.data_shape[dim] - grid_size - half_pad
        return dim_index * grid_size + half_pad

    raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

get_individual_dim_grid_count(dim) #

Returns the number of grids along the specified dimension, ignoring all other dimensions.

Source code in src/careamics/lvae_training/dataset/utils/index_manager.py
def get_individual_dim_grid_count(self, dim: int):
    """
    Counts the grids along a single dimension, independent of the other dimensions.

    An untiled dimension (grid and patch size both 1) has one grid per
    coordinate. Otherwise the count depends on the tiling mode; any mode other
    than PadBoundary and ShiftBoundary rounds down.
    """
    n_dims = len(self.data_shape)
    assert dim < n_dims, f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
    assert dim >= 0, "Dimension must be greater than or equal to 0"

    grid_size = self.grid_shape[dim]
    patch_size = self.patch_shape[dim]
    extent = self.data_shape[dim]

    if grid_size == 1 and patch_size == 1:
        return extent

    if self.tiling_mode == TilingMode.PadBoundary:
        return int(np.ceil(extent / grid_size))

    # patch area that sticks out beyond the grid (both sides together)
    overhang = patch_size - grid_size
    rounding = np.ceil if self.tiling_mode == TilingMode.ShiftBoundary else np.floor
    return int(rounding((extent - overhang) / grid_size))

get_location_from_dataset_idx(dataset_idx) #

Returns the start location of the grid in the dataset.

Source code in src/careamics/lvae_training/dataset/utils/index_manager.py
def get_location_from_dataset_idx(self, dataset_idx: int):
    """
    Converts a flat dataset index into the tuple of grid-start coordinates.

    The flat index is first split into one grid index per dimension, then each
    grid index is turned into its grid-start coordinate.
    """
    dims = range(len(self.data_shape))

    per_dim_index = []
    remainder = dataset_idx
    for dim in dims:
        # quotient is this dimension's grid index, the rest belongs to later dimensions
        quotient, remainder = divmod(remainder, self.grid_count(dim))
        per_dim_index.append(quotient)

    return tuple(
        self.get_gridstart_location_from_dim_index(dim, per_dim_index[dim])
        for dim in dims
    )

get_patch_location_from_dataset_idx(dataset_idx) #

Returns the patch location of the grid in the dataset.

Source code in src/careamics/lvae_training/dataset/utils/index_manager.py
def get_patch_location_from_dataset_idx(self, dataset_idx: int):
    """
    Returns the patch start location for a dataset index: the grid start
    location shifted back by the patch offset in every dimension.
    """
    grid_start = np.array(self.get_location_from_dataset_idx(dataset_idx))
    shift = np.array(self.patch_offset())
    patch_start = grid_start - shift
    return tuple(patch_start)

grid_count(dim) #

Returns the number of grids covered by one index step along the specified dimension, i.e. the product of the grid counts of all later dimensions (1 for the last dimension).

Source code in src/careamics/lvae_training/dataset/utils/index_manager.py
def grid_count(self, dim: int):
    """
    Returns the number of grids covered by one index step along the specified
    dimension: the product of the per-dimension grid counts of all later
    dimensions, which is 1 for the last dimension.
    """
    n_dims = len(self.data_shape)
    assert dim < n_dims, f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
    assert dim >= 0, "Dimension must be greater than or equal to 0"

    # iterative form of: count(dim + 1) * grid_count(dim + 1)
    stride = 1
    for later_dim in range(dim + 1, n_dims):
        stride *= self.get_individual_dim_grid_count(later_dim)
    return stride

next_grid_along_dim(dataset_idx, dim) #

Returns the dataset index one grid step further along the specified dimension, or None if that index is past the last grid of the dataset.

Source code in src/careamics/lvae_training/dataset/utils/index_manager.py
def next_grid_along_dim(self, dataset_idx: int, dim: int):
    """
    Returns the dataset index one grid step further along the specified
    dimension, or None when that index is not below the total grid count.
    """
    n_dims = len(self.data_shape)
    assert dim < n_dims, f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
    assert dim >= 0, "Dimension must be greater than or equal to 0"

    candidate = dataset_idx + self.grid_count(dim)
    return candidate if candidate < self.total_grid_count() else None

on_boundary(dataset_idx, dim, only_end=False) #

Returns True if the grid is on the boundary in the specified dimension.

Source code in src/careamics/lvae_training/dataset/utils/index_manager.py
def on_boundary(self, dataset_idx: int, dim: int, only_end: bool = False):
    """
    Tells whether the grid with the given dataset index is the first or the
    last grid along the specified dimension.

    With `only_end=True` only the last grid counts as boundary.
    """
    n_dims = len(self.data_shape)
    assert dim < n_dims, f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
    assert dim >= 0, "Dimension must be greater than or equal to 0"

    within = dataset_idx
    if dim > 0:
        # strip the part of the index that belongs to earlier dimensions
        within = within % self.grid_count(dim - 1)

    position = within // self.grid_count(dim)

    if not only_end and position == 0:
        return True
    return position == self.get_individual_dim_grid_count(dim) - 1

prev_grid_along_dim(dataset_idx, dim) #

Returns the index of the grid in the specified dimension in the specified direction.

Source code in src/careamics/lvae_training/dataset/utils/index_manager.py
def prev_grid_along_dim(self, dataset_idx: int, dim: int):
    """
    Returns the dataset index one grid step back along the specified
    dimension, or None if that index would be negative.
    """
    assert dim < len(
        self.data_shape
    ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
    assert dim >= 0, "Dimension must be greater than or equal to 0"
    new_idx = dataset_idx - self.grid_count(dim)
    if new_idx < 0:
        return None
    # the computed index was previously dropped, so callers always got None
    return new_idx

total_grid_count() #

Returns the total number of grids in the dataset.

Source code in src/careamics/lvae_training/dataset/utils/index_manager.py
def total_grid_count(self):
    """
    Returns the number of grids in the whole dataset: the grids covered by one
    step along dimension 0 times the number of grids along dimension 0.
    """
    grids_per_step = self.grid_count(0)
    steps = self.get_individual_dim_grid_count(0)
    return grids_per_step * steps