Skip to content

pixel_manipulation_torch

N2V manipulation functions for PyTorch.

median_manipulate_torch(batch, mask_pixel_percentage, subpatch_size=11, struct_params=None, rng=None)

Manipulate pixels by replacing them with the median of their surrounding subpatch.

N2V2 version: manipulated pixels are selected randomly away from a grid, with an approximately uniform probability of being selected across the whole patch.

If struct_params is not None, an additional structN2V mask is applied to the data, replacing the pixels in the mask with random values (excluding the pixel already manipulated).

Parameters:

Name Type Description Default
batch Tensor

Batch of image patches; the first dimension is the batch dimension, followed by the spatial dimensions, e.g. (b, y, x) or (b, z, y, x).

required
mask_pixel_percentage float

Approximate percentage of pixels to be masked.

required
subpatch_size int

Size of the subpatch the new pixel value is sampled from, by default 11.

11
struct_params StructMaskParameters or None

Parameters for the structN2V mask (axis and span).

None
rng default_generator or None

Random number generator, by default None.

None

Returns:

Type Description
tuple[Tensor, Tensor]

tuple containing the manipulated patch and the mask.

Source code in src/careamics/transforms/pixel_manipulation_torch.py
def median_manipulate_torch(
    batch: torch.Tensor,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    struct_params: StructMaskParameters | None = None,
    rng: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Manipulate pixels by replacing them with the median of their surrounding subpatch.

    N2V2 version: manipulated pixels are selected randomly away from a grid, with an
    approximately uniform probability of being selected across the whole patch.

    If `struct_params` is not None, an additional structN2V mask is applied to the data,
    replacing the pixels in the mask with random values (excluding the pixel already
    manipulated).

    Parameters
    ----------
    batch : torch.Tensor
        Batch of image patches. The first dimension is treated as the batch
        dimension and the remaining ones as spatial dimensions, e.g. (b, y, x) or
        (b, z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the new pixel value is sampled from, by default 11.
    struct_params : StructMaskParameters or None, optional
        Parameters for the structN2V mask (axis and span).
    rng : torch.Generator or None, optional
        Random number generator, by default None. If None, a new generator is
        created on `batch.device` and seeded non-deterministically.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Tuple containing the manipulated batch and the mask. The mask is a
        `torch.uint8` tensor with the same shape as `batch`, equal to 1 where the
        median replacement changed the pixel value and 0 elsewhere.
    """
    # -- Implementation summary
    # 1. Generate coordinates that correspond to the pixels chosen for masking.
    # 2. Subpatches are extracted, where the coordinate to mask is at the center.
    # 3. The medians of these subpatches are calculated, but we do not want to include
    #    the original pixel in the calculation so we mask it. In the case of StructN2V,
    #    we do not include any pixels in the struct mask in the median calculation.

    if rng is None:
        rng = torch.Generator(device=batch.device)
        # A freshly constructed generator always starts from the same fixed default
        # seed; reseed it so successive calls do not produce identical masks.
        rng.seed()

    # resulting center coord shape: (num_coordinates, batch + num_spatial_dims)
    subpatch_center_coordinates = _get_stratified_coords_torch(
        mask_pixel_percentage, batch.shape, rng
    )
    # pixel coordinates of all the subpatches
    # shape: (num_coordinates, subpatch_size, subpatch_size, ...)
    subpatch_coords = _get_subpatch_coords(
        subpatch_center_coordinates, subpatch_size, batch.shape
    )

    # this indexes and stacks all the subpatches along the first dimension
    # subpatches shape: (num_coordinates, subpatch_size, subpatch_size, ...)
    subpatches = batch[tuple(subpatch_coords)]

    # number of spatial dimensions (everything except the leading batch dimension)
    ndims = batch.ndim - 1
    # subpatch mask to exclude values from median calculation
    if struct_params is None:
        subpatch_mask = _create_center_pixel_mask(ndims, subpatch_size, batch.device)
    else:
        subpatch_mask = _create_struct_mask(
            ndims, subpatch_size, struct_params, batch.device
        )
    # boolean-mask indexing flattens the kept subpatch values into dim 1
    subpatches_masked = subpatches[:, subpatch_mask]

    medians = subpatches_masked.median(dim=1).values  # (num_coordinates,)

    # Update the output tensor with medians
    output_batch = batch.clone()
    output_batch[tuple(subpatch_center_coordinates.T)] = medians
    # NOTE: a selected pixel whose median equals its original value is 0 here,
    # because the mask records changed values rather than selected coordinates.
    # It is computed before the struct mask so structN2V pixels are not included.
    mask = (batch != output_batch).to(torch.uint8)

    if struct_params is not None:
        output_batch = _apply_struct_mask_torch(
            output_batch, subpatch_center_coordinates, struct_params, rng
        )

    return output_batch, mask

uniform_manipulate_torch(patch, mask_pixel_percentage, subpatch_size=11, remove_center=True, struct_params=None, rng=None)

Manipulate pixels by replacing them with neighboring pixel values.

TODO: add more details, especially about batch.

Manipulated pixels are selected away from a grid, with an approximately uniform probability of being selected across the whole patch, and each one is replaced by a value sampled uniformly from its surrounding subpatch. If struct_params is not None, an additional structN2V mask is applied to the data, replacing the pixels in the mask with random values (excluding the pixel already manipulated).

Parameters:

Name Type Description Default
patch Tensor

Batch of image patches; the first dimension is the batch dimension, followed by the spatial dimensions, e.g. (b, y, x) or (b, z, y, x). TODO: document channel handling.

required
mask_pixel_percentage float

Approximate percentage of pixels to be masked.

required
subpatch_size int

Size of the subpatch the new pixel value is sampled from, by default 11.

11
remove_center bool

Whether to remove the center pixel from the subpatch, by default True.

True
struct_params StructMaskParameters or None

Parameters for the structN2V mask (axis and span).

None
rng default_generator or None

Random number generator.

None

Returns:

Type Description
tuple[Tensor, Tensor]

tuple containing the manipulated patch and the corresponding mask.

Source code in src/careamics/transforms/pixel_manipulation_torch.py
def uniform_manipulate_torch(
    patch: torch.Tensor,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    remove_center: bool = True,
    struct_params: StructMaskParameters | None = None,
    rng: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Manipulate pixels by replacing them with neighboring pixel values.

    # TODO add more details, especially about batch

    Manipulated pixels are selected away from a grid, with an approximately uniform
    probability of being selected across the whole patch. Each one is replaced by
    the value found at a random offset from it. The offset is drawn uniformly and
    independently per spatial axis from the subpatch span.
    If `struct_params` is not None, an additional structN2V mask is applied to the
    data, replacing the pixels in the mask with random values (excluding the pixel
    already manipulated).

    Parameters
    ----------
    patch : torch.Tensor
        Batch of image patches. The first dimension is treated as the batch
        dimension and the remaining ones as spatial dimensions, e.g. (b, y, x) or
        (b, z, y, x). # TODO channel.
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the new pixel value is sampled from, by default 11.
    remove_center : bool
        Whether to exclude a zero offset along each spatial axis when sampling the
        replacement, by default True.
    struct_params : StructMaskParameters or None
        Parameters for the structN2V mask (axis and span).
    rng : torch.Generator or None
        Random number generator. If None, a new generator is created on
        `patch.device` and seeded non-deterministically.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Tuple containing the manipulated patch and the corresponding mask. The mask
        is a `torch.uint8` tensor with the same shape as `patch`, equal to 1 where
        the replacement changed the pixel value and 0 elsewhere.

    Raises
    ------
    ValueError
        If `subpatch_size` leaves no candidate offsets (e.g. `subpatch_size` <= 1
        with `remove_center=True`).
    """
    if rng is None:
        rng = torch.Generator(device=patch.device)
        # A freshly constructed generator always starts from the same fixed default
        # seed; reseed it so successive calls do not produce identical masks.
        rng.seed()

    # candidate offsets along each spatial axis, centred on 0
    # TODO refactor with non negative indices?
    roi_span_full = torch.arange(
        -(subpatch_size // 2),
        subpatch_size // 2 + 1,
        dtype=torch.int32,
        device=patch.device,
    )
    # optionally forbid a zero offset so the pixel is not replaced by itself
    roi_span = roi_span_full[roi_span_full != 0] if remove_center else roi_span_full
    if roi_span.numel() == 0:
        raise ValueError(
            f"subpatch_size={subpatch_size} with remove_center={remove_center} "
            "leaves no candidate offsets to sample the replacement value from."
        )

    # create a copy of the patch
    transformed_patch = patch.clone()

    # get the coordinates of the pixels to be masked
    subpatch_centers = _get_stratified_coords_torch(
        mask_pixel_percentage, patch.shape, rng
    )
    subpatch_centers = subpatch_centers.to(device=patch.device)

    # Draw an index into `roi_span`, not an offset value: this makes every
    # candidate offset equally likely. Indexing with the offset values directly
    # would wrap negative values around and bias the draw.
    # One less coord dim: the batch coordinate must not receive an increment.
    offset_indices = torch.randint(
        low=0,
        high=roi_span.numel(),
        size=(subpatch_centers.shape[0], subpatch_centers.shape[1] - 1),
        generator=rng,
        device=patch.device,
    )
    random_increment = roi_span[offset_indices]

    # compute the replacement pixel coordinates, clamped to the spatial bounds
    upper_bound = torch.tensor(patch.shape[1:], device=patch.device) - 1
    replacement_coords = subpatch_centers.clone()
    # only add random increment to the spatial dimensions, not the batch dimension
    replacement_coords[:, 1:] = torch.clamp(
        replacement_coords[:, 1:] + random_increment,
        min=torch.zeros_like(upper_bound),
        max=upper_bound,
    )

    # replace the pixels in the patch
    # tuples and transpose are needed for proper indexing
    replacement_pixels = patch[tuple(replacement_coords.T)]
    transformed_patch[tuple(subpatch_centers.T)] = replacement_pixels

    # create a mask representing the masked pixels (computed before structN2V)
    mask = (transformed_patch != patch).to(dtype=torch.uint8)

    # apply structN2V mask if needed
    if struct_params is not None:
        transformed_patch = _apply_struct_mask_torch(
            transformed_patch, subpatch_centers, struct_params, rng
        )

    return transformed_patch, mask