Skip to content

restore_original_shape

Utility to restore predictions to original shape and dimension order.

restore_original_shape(array, original_axes, original_data_shape) #

Restore array to original shape and dimension order.

Parameters:

Name Type Description Default
array ndarray

Array in SC(Z)YX format.

required
original_axes str

Original axes order of the data.

required
original_data_shape Sequence[int]

Original shape of the data.

required

Returns:

Type Description
ndarray

Array reshaped to match axes and original_data_shape.

Source code in src/careamics/lightning/dataset_ng/prediction/restore_original_shape.py
def restore_original_shape(
    array: NDArray,
    original_axes: str,
    original_data_shape: Sequence[int],
) -> NDArray:
    """
    Restore array to original shape and dimension order.

    Parameters
    ----------
    array : numpy.ndarray
        Array in SC(Z)YX format.
    original_axes : str
        Original axes order of the data.
    original_data_shape : Sequence[int]
        Original shape of the data.

    Returns
    -------
    numpy.ndarray
        Array reshaped to match axes and original_data_shape.
    """
    if len(array.shape) not in (4, 5):
        raise ValueError(
            f"Expected array with 4 or 5 dimensions (SC(Z)YX), got {len(array.shape)}."
        )

    # current axes from array shape (S and C always present)
    current_axes = "SCZYX" if len(array.shape) == 5 else "SCYX"

    # handle special CZI case where T is used as Z
    if "T" in original_axes and "Z" not in original_axes and len(array.shape) == 5:
        original_axes = original_axes.replace("T", "Z")

    # unflatten S dimension
    merged_dims = [dim for dim in original_axes if dim not in current_axes]

    if merged_dims:
        unflattened_sizes = []
        unflattened_dims = []

        if "S" in original_axes:
            s_size = original_data_shape[original_axes.index("S")]
            unflattened_sizes.append(s_size)
            unflattened_dims.append("S")

        for dim in merged_dims:
            dim_size = original_data_shape[original_axes.index(dim)]
            unflattened_sizes.append(dim_size)
            unflattened_dims.append(dim)

        # replace S dimension with unflattened dimensions
        s_idx = current_axes.index("S")  # TODO always 0
        new_shape = list(array.shape)
        new_shape[s_idx : s_idx + 1] = unflattened_sizes
        array = array.reshape(new_shape)

        # update current axes
        current_axes = (
            current_axes[:s_idx] + "".join(unflattened_dims) + current_axes[s_idx + 1 :]
        )

    # remove singleton C if not in original axes
    if "C" not in original_axes:
        c_idx = current_axes.index("C")
        if array.shape[c_idx] == 1:
            array = np.squeeze(array, axis=c_idx)
            current_axes = current_axes[:c_idx] + current_axes[c_idx + 1 :]

    # same for singleton S
    if "S" in current_axes and "S" not in original_axes:
        s_idx = current_axes.index("S")
        if array.shape[s_idx] == 1:
            array = np.squeeze(array, axis=s_idx)
            current_axes = current_axes[:s_idx] + current_axes[s_idx + 1 :]

    # reorder to match original axes
    if current_axes != original_axes:
        source_order = [current_axes.index(axis) for axis in original_axes]
        target_order = list(range(len(original_axes)))
        array = np.moveaxis(array, source_order, target_order)

    return array