Prediction

A package for the PredictionWriterCallback class and utilities.

ImageWriteStrategy

Bases: WriteStrategy

A strategy for writing whole image predictions (i.e. un-tiled predictions).

Predictions are cached until all samples (the S dimension) for a given data_idx have been collected; they are then combined and written. This prevents files from being overwritten when the size of the S dimension exceeds the batch size.
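The caching behaviour described above can be sketched in plain Python. This is not the library's implementation: `Region`, its `sample_idx` and `n_samples` fields, and `CachingImageWriter` are hypothetical, and how the real class learns the size of the S dimension is an assumption of this sketch.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Region:
    """Hypothetical stand-in for ImageRegionData (only the fields used here)."""

    data_idx: int  # which image the sample belongs to
    sample_idx: int  # position along the S dimension (assumed field)
    n_samples: int  # total size of the S dimension for this image (assumed field)
    data: list  # prediction for this sample


class CachingImageWriter:
    """Sketch: cache samples per data_idx and write once all S samples arrived."""

    def __init__(self, write):
        self.write = write  # callable(data_idx, samples) standing in for write_func
        self.image_cache = defaultdict(list)

    def write_batch(self, predictions):
        for region in predictions:
            cached = self.image_cache[region.data_idx]
            cached.append(region)
            if len(cached) == region.n_samples:
                # all samples collected: order along S, combine, write exactly once
                cached.sort(key=lambda r: r.sample_idx)
                self.write(region.data_idx, [r.data for r in cached])
                del self.image_cache[region.data_idx]
```

With S = 3 and a batch size of 2, image 0 is only written after the second batch, so no partial file is ever produced and later overwritten.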

Parameters:

  • write_func (WriteFunc) –

    Function used to save predictions.

  • write_extension (str) –

    Extension added to prediction file paths.

  • write_func_kwargs (dict of {str: Any}) –

    Extra kwargs to pass to write_func.

Attributes:

  • write_func (WriteFunc) –

    Function used to save predictions.

  • write_extension (str) –

    Extension added to prediction file paths.

  • write_func_kwargs (dict of {str: Any}) –

    Extra kwargs to pass to write_func.

  • image_cache (dict of {int: list of ImageRegionData}) –

    Cache for predictions across batches, keyed by data_idx.

__init__(write_func, write_extension, write_func_kwargs)

A strategy for writing image predictions (i.e. un-tiled predictions).

Parameters:

  • write_func (WriteFunc) –

    Function used to save predictions.

  • write_extension (str) –

    Extension added to prediction file paths.

  • write_func_kwargs (dict of {str: Any}) –

    Extra kwargs to pass to write_func.

write_batch(dirpath, predictions)

Cache predictions and save full images.

Predictions are cached by data_idx until all samples (S dimension) are collected, then combined and written.

Parameters:

  • dirpath (Path) –

    Path to directory to save predictions to.

  • predictions (list[ImageRegionData]) –

    Decollated predictions.

PredictionWriterCallback

Bases: BasePredictionWriter

PyTorch Lightning callback to save predictions.

A WriteStrategy must be provided at instantiation or later via set_writing_strategy. This allows passing the callback to the Lightning Trainer before knowing what writing strategy (e.g. tiling or file type) will be used.

By default the prediction writer is enabled, but it can be disabled by passing enable_writing=False at instantiation or by calling enable_writing(False). Subsequently, writing can be re-enabled by calling enable_writing(True).
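The enable/disable switch and the deferred write strategy described above can be sketched without Lightning. `SketchWriterCallback` is hypothetical; raising `RuntimeError` when no strategy is set, and assigning `write_strategy` directly instead of calling `set_writing_strategy`, are assumptions of this sketch rather than documented library behaviour.

```python
from pathlib import Path
from typing import Optional, Protocol


class WriteStrategy(Protocol):
    def write_batch(self, dirpath: Path, predictions: list) -> None: ...


class SketchWriterCallback:
    """Hypothetical, Lightning-free sketch of the enable/strategy logic."""

    def __init__(
        self,
        dirpath="predictions",
        write_strategy: Optional[WriteStrategy] = None,
        enable_writing: bool = True,
    ):
        # relative paths are resolved against the current working directory
        self.dirpath = Path(dirpath).resolve()
        self.write_strategy = write_strategy  # may be set later
        self.is_enabled = enable_writing

    def enable_writing(self, enable_writing: bool) -> None:
        self.is_enabled = enable_writing

    def write_on_batch_end(self, predictions: list) -> None:
        if not self.is_enabled:
            return  # writing disabled: nothing is saved
        if self.write_strategy is None:
            raise RuntimeError("No write strategy set; call set_writing_strategy first.")
        self.write_strategy.write_batch(self.dirpath, predictions)
```

Keeping the strategy optional is what lets the callback be handed to the Trainer first and configured once the tiling and file type are known.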

Parameters:

  • dirpath (Path or str, default: "predictions" ) –

    The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

  • write_strategy (WriteStrategy or None, default: None ) –

    A strategy for writing predictions.

  • enable_writing (bool, default: True ) –

    Whether writing predictions should be enabled by default.

Attributes:

  • is_enabled (bool) –

    Whether writing predictions is enabled.

  • dirpath (pathlib.Path, default="") –

    The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

  • write_strategy (WriteStrategy or None) –

    A strategy for writing predictions.

__init__(dirpath='', write_strategy=None, enable_writing=True)

Constructor.

A WriteStrategy must be provided at instantiation or later via set_writing_strategy.

Parameters:

  • dirpath (Path or str, default: "predictions" ) –

    The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

  • write_strategy (WriteStrategy or None, default: None ) –

    A strategy for writing predictions.

  • enable_writing (bool, default: True ) –

    Whether writing predictions should be enabled by default.

enable_writing(enable_writing)

Enable or disable writing.

Parameters:

  • enable_writing (bool) –

    Whether writing predictions should be enabled.

set_writing_strategy(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None)

Set the writing strategy.

Must be called before writing predictions if no WriteStrategy was provided at instantiation.

Parameters:

  • write_type (SupportedWriteType) –

    The type of writing to perform.

  • tiled (bool) –

    Whether the predictions will be tiled.

  • write_func (WriteFunc or None, default: None ) –

    A custom writing function.

  • write_extension (str or None, default: None ) –

    The file extension to use when writing files.

  • write_func_kwargs (dict of str to Any, default: None ) –

    Additional keyword arguments to pass to write_func.

setup(trainer, pl_module, stage)

Create the prediction output directory when prediction begins.

Lightning calls this hook when fit, validate, test, predict, or tune begins; the output directory is only created for the predict stage.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning trainer.

  • pl_module (LightningModule) –

    PyTorch Lightning module.

  • stage (str) –

    Stage being run, e.g. 'predict', 'fit' or 'validate'.
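A minimal sketch of the behaviour described for setup, assuming the directory is created with `Path.mkdir(parents=True, exist_ok=True)`; `setup_output_dir` is a hypothetical helper, not part of the library.

```python
import tempfile
from pathlib import Path


def setup_output_dir(dirpath: Path, stage: str) -> None:
    """Hypothetical helper: create dirpath only for the predict stage."""
    if stage == "predict":
        # parents=True creates missing parents; exist_ok=True tolerates reruns
        dirpath.mkdir(parents=True, exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "run1" / "predictions"
    setup_output_dir(out, "fit")
    print(out.exists())  # False
    setup_output_dir(out, "predict")
    print(out.exists())  # True
```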

write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)

Write predictions at the end of a batch.

Writing method is determined by the attribute write_strategy.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning trainer.

  • pl_module (LightningModule) –

    PyTorch Lightning module.

  • prediction (ImageRegionData) –

    Prediction outputs of batch.

  • batch_indices (sequence of Any) –

    Batch indices.

  • batch (ImageRegionData) –

    Input batch.

  • batch_idx (int) –

    Batch index.

  • dataloader_idx (int) –

    Dataloader index.

TileWriteStrategy

Bases: WriteStrategy

A write strategy that will cache tiles.

Tiles are cached until every tile of an image has been predicted; the stitched prediction is then saved.
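The cache-then-stitch idea can be illustrated in one dimension. This is a sketch under stated assumptions: `Tile`, its `crop`/`stitch`/`last_tile` fields and `TileCacheSketch` are hypothetical and do not reproduce TileInformation; tiles are assumed to arrive in order, with a flag marking the last tile of each image.

```python
from dataclasses import dataclass


@dataclass
class Tile:
    """Hypothetical 1D tile; field names are illustrative only."""

    data: list  # predicted values for the whole tile, including overlap
    crop: tuple  # (start, end) of the region kept from this tile
    stitch: tuple  # (start, end) where the kept region goes in the image
    last_tile: bool  # True for the final tile of an image


class TileCacheSketch:
    def __init__(self):
        self.tile_cache = []
        self.written = []  # stitched images, standing in for files on disk

    def write_batch(self, tiles):
        for tile in tiles:
            self.tile_cache.append(tile)
            if tile.last_tile:
                self.written.append(self._stitch(self.tile_cache))
                self.tile_cache = []  # reset for the next image

    @staticmethod
    def _stitch(tiles):
        size = max(t.stitch[1] for t in tiles)
        image = [None] * size
        for t in tiles:
            # drop the overlap, then place the kept values in the image
            image[t.stitch[0]:t.stitch[1]] = t.data[t.crop[0]:t.crop[1]]
        return image
```

For an image of length 6 split into two overlapping tiles of length 4, keeping `[0:3]` of the first tile and `[1:4]` of the second recovers the full image once the second (last) tile arrives, even if it comes in a later batch.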

Parameters:

  • write_func (WriteFunc) –

    Function used to save predictions.

  • write_extension (str) –

    Extension added to prediction file paths.

  • write_func_kwargs (dict of {str: Any}) –

    Extra kwargs to pass to write_func.

Attributes:

  • write_func (WriteFunc) –

    Function used to save predictions.

  • write_extension (str) –

    Extension added to prediction file paths.

  • write_func_kwargs (dict of {str: Any}) –

    Extra kwargs to pass to write_func.

  • tile_cache (dict of {int: list of ImageRegionData}) –

    Tiles cached for stitching prediction.

  • tile_info_cache (list of TileInformation) –

    Cached tile information for stitching prediction.

__init__(write_func, write_extension, write_func_kwargs)

A write strategy that will cache tiles.

Tiles are cached until every tile of an image has been predicted; the stitched prediction is then saved.

Parameters:

  • write_func (WriteFunc) –

    Function used to save predictions.

  • write_extension (str) –

    Extension added to prediction file paths.

  • write_func_kwargs (dict of {str: Any}) –

    Extra kwargs to pass to write_func.

write_batch(dirpath, predictions)

Cache tiles until the last tile is predicted, then save the stitched image.

Parameters:

  • dirpath (Path) –

    Path to directory to save predictions to.

  • predictions (list[ImageRegionData]) –

    Decollated predictions.

WriteStrategy

Bases: Protocol

Protocol for write strategy classes.

write_batch(dirpath, predictions)

Classes implementing the WriteStrategy protocol must define this method to write a batch.

Parameters:

  • dirpath (Path) –

    Path to directory to save predictions to.

  • predictions (list[ImageRegionData]) –

    Decollated predictions.
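Because WriteStrategy is a `typing.Protocol`, a class conforms by defining `write_batch` with a matching signature; no inheritance is needed. The sketch below redefines a look-alike protocol so it runs on its own; `WriteStrategyLike` and `PrintStrategy` are hypothetical, and `@runtime_checkable` is added here only to make the structural check visible with `isinstance`.

```python
from pathlib import Path
from typing import Protocol, runtime_checkable


@runtime_checkable
class WriteStrategyLike(Protocol):
    """Mirror of the documented protocol, redefined to stay self-contained."""

    def write_batch(self, dirpath: Path, predictions: list) -> None: ...


class PrintStrategy:
    """Satisfies the protocol structurally: it never subclasses it."""

    def write_batch(self, dirpath: Path, predictions: list) -> None:
        for i, _ in enumerate(predictions):
            print(f"would write prediction {i} to {dirpath}")


print(isinstance(PrintStrategy(), WriteStrategyLike))  # True
print(isinstance(object(), WriteStrategyLike))  # False
```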

ZarrTileWriteStrategy

Bases: WriteStrategy

Zarr tile writer strategy.

This writer creates zarr files, groups and arrays as needed and writes tiles into the appropriate locations.

__init__()

Constructor.

write_batch(dirpath, predictions)

Write all tiles to a Zarr file.

Parameters:

  • dirpath (Path) –

    Path to directory to save predictions to.

  • predictions (list[ImageRegionData]) –

    Decollated predictions.

write_tile(dirpath, region)

Write a cropped tile to its Zarr array.

Parameters:

  • dirpath (Path) –

    Path to directory to save predictions to.

  • region (ImageRegionData) –

    Image region data containing tile information.

create_write_file_path(dirpath, file_path, write_extension, postfix='')

Create the file name for the output file.

Takes the original file path, changes the directory to dirpath and changes the extension to write_extension.

Parameters:

  • dirpath (Path) –

    The output directory to write file to.

  • file_path (Path) –

    The original file path.

  • write_extension (str) –

    The extension that output files should have.

  • postfix (str, default: "" ) –

    String appended to the file name before the extension.

Returns:

  • Path

    The output file path.
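The path transformation described above can be sketched with `pathlib`. `create_write_file_path_sketch` is hypothetical, and it assumes write_extension includes the leading dot.

```python
from pathlib import Path


def create_write_file_path_sketch(
    dirpath: Path, file_path: Path, write_extension: str, postfix: str = ""
) -> Path:
    # keep only the stem of the original name, then add postfix and new extension
    return dirpath / f"{file_path.stem}{postfix}{write_extension}"


out = create_write_file_path_sketch(Path("out"), Path("/data/raw/img_01.tif"), ".tiff", "_pred")
print(out.as_posix())  # out/img_01_pred.tiff
```

Note that `Path.stem` only strips the last suffix, so in this sketch `img.ome.tiff` would become `img.ome` plus the new extension; how the library handles multi-part suffixes is not stated here.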

create_write_strategy(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None)

Create a write strategy from convenient parameters.

Parameters:

  • write_type ("tiff", "zarr" or "custom", default: "tiff" ) –

    The file type to save as. Use "custom" to supply your own write_func and write_extension.

  • tiled (bool) –

    Whether the prediction will be tiled or not.

  • write_func (WriteFunc, default: None ) –

    If a known write_type is selected this argument is ignored. For a custom write_type a function to save the data must be passed. See notes below.

  • write_extension (str, default: None ) –

    If a known write_type is selected this argument is ignored. For a custom write_type an extension to save the data with must be passed.

  • write_func_kwargs (dict of {str: any}, default: None ) –

    Additional keyword arguments to be passed to the save function.

Returns:

  • WriteStrategy

    A strategy for writing predictions.

Notes

The write_func function signature must match that of the example below:

def write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...

The write_func_kwargs are passed to write_func as follows:

write_func(file_path=file_path, img=img, **write_func_kwargs)
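A custom write_func matching the required signature might look like the following. `write_npy` and the choice of NumPy's `.npy` format are illustrative assumptions; with such a function you would presumably also pass `write_type="custom"` and a matching write_extension such as `".npy"`.

```python
import tempfile
from pathlib import Path

import numpy as np
from numpy.typing import NDArray


def write_npy(file_path: Path, img: NDArray, *args, **kwargs) -> None:
    """Example custom write_func: save the array in NumPy's .npy format."""
    np.save(file_path, img, **kwargs)  # extra kwargs (e.g. allow_pickle) go to np.save


write_func_kwargs = {"allow_pickle": False}
img = np.arange(6, dtype=np.float32).reshape(2, 3)

with tempfile.TemporaryDirectory() as tmp:
    file_path = Path(tmp) / "img_01.npy"
    # mirrors the documented call: write_func(file_path=file_path, img=img, **kwargs)
    write_npy(file_path=file_path, img=img, **write_func_kwargs)
    print(np.array_equal(np.load(file_path), img))  # True
```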

decollate_image_region_data(batch)

Decollate a batch of ImageRegionData into a list of ImageRegionData.

Input batch has the following structure:

  • data: (B, C, (Z), Y, X) numpy.ndarray

  • source: sequence of str, length B

  • data_shape: sequence of tuple of int, each tuple being of length B

  • dtype: list of numpy.dtype, length B

  • axes: list of str, length B

  • region_spec: dict of {str: sequence}, each sequence being of length B

  • additional_metadata: dict of {str: Any}, each value being a sequence of length B

Parameters:

  • batch (ImageRegionData) –

    Collated batch of ImageRegionData with the structure described above.

Returns:

  • list of ImageRegionData

    List of ImageRegionData.
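The decollation described above amounts to indexing every batched field at position i. In this sketch `RegionSketch` and `decollate_sketch` are hypothetical; only a subset of fields is kept (dtype, axes and additional_metadata would be handled the same way), and plain lists stand in for the collated arrays. data_shape is treated as one sequence per dimension, each of length B, as listed above.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class RegionSketch:
    """Hypothetical, trimmed-down stand-in for ImageRegionData."""

    data: Any
    source: Any
    data_shape: Any
    region_spec: Any


def decollate_sketch(batch: RegionSketch) -> list:
    """Split a collated batch into one RegionSketch per sample."""
    samples = []
    for i in range(len(batch.data)):
        samples.append(
            RegionSketch(
                data=batch.data[i],
                source=batch.source[i],
                # data_shape is stored per dimension, so pick element i of each
                data_shape=tuple(int(dim[i]) for dim in batch.data_shape),
                # dict of length-B sequences -> one dict of scalars per sample
                region_spec={key: values[i] for key, values in batch.region_spec.items()},
            )
        )
    return samples


batch = RegionSketch(
    data=[[0.1, 0.2], [0.3, 0.4]],
    source=["a.tif", "b.tif"],
    data_shape=[[64, 32], [64, 48]],  # (Y values, X values) across the batch
    region_spec={"data_idx": [0, 1], "sample_idx": [0, 0]},
)
second = decollate_sketch(batch)[1]
print(second.source, second.data_shape, second.region_spec)
# b.tif (32, 48) {'data_idx': 1, 'sample_idx': 0}
```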

select_write_extension(write_type, write_extension=None)

Return an extension to add to file paths.

If write_type is "custom", write_extension is returned; otherwise the extension for the known write_type is selected.

Parameters:

  • write_type ("tiff" or "custom", default: "tiff" ) –

    The file type to save as. Use "custom" to supply your own write_extension.

  • write_extension (str, default: None ) –

    If a known write_type is selected this argument is ignored. For a custom write_type an extension to save the data with must be passed.

Returns:

  • str

    The extension to be added to file paths.

Raises:

  • ValueError

    If write_type="custom" but write_extension has not been given.
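The selection rule above can be sketched as follows. `KNOWN_EXTENSIONS` and `select_write_extension_sketch` are hypothetical; in particular the `".tiff"` value is an assumption, and the extension the library actually uses for TIFF output may differ.

```python
from typing import Optional

# Hypothetical mapping; the real extensions used by the library may differ.
KNOWN_EXTENSIONS = {"tiff": ".tiff"}


def select_write_extension_sketch(write_type: str, write_extension: Optional[str] = None) -> str:
    if write_type == "custom":
        if write_extension is None:
            raise ValueError("A write_extension must be given when write_type='custom'.")
        return write_extension
    # for known types any user-supplied write_extension is ignored
    return KNOWN_EXTENSIONS[write_type]
```

select_write_func follows the same pattern, returning the user's write_func for "custom" and a built-in writer otherwise.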

select_write_func(write_type, write_func=None)

Return a function to write images.

If write_type is "custom", write_func is returned; otherwise the write function for the known write_type is selected.

Parameters:

  • write_type ("tiff" or "custom", default: "tiff" ) –

    The file type to save as. Use "custom" to supply your own write_func.

  • write_func (WriteFunc, default: None ) –

    If a known write_type is selected this argument is ignored. For a custom write_type a function to save the data must be passed. See notes below.

Returns:

  • WriteFunc

    A function for writing images.

Raises:

  • ValueError

    If write_type="custom" but write_func has not been given.

Notes

The write_func function signature must match that of the example below:

def write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...