
Prediction Writer Callback


A package for the PredictionWriterCallback class and utilities.

CacheTiles

Bases: WriteStrategy

A write strategy that will cache tiles.

Tiles are cached until a whole image has been predicted; the stitched prediction is then saved.

Parameters:

  • write_func (WriteFunc) –

    Function used to save predictions.

  • write_extension (str) –

    Extension added to prediction file paths.

  • write_func_kwargs (dict of {str: Any}) –

    Extra kwargs to pass to write_func.

Attributes:

  • write_func (WriteFunc) –

    Function used to save predictions.

  • write_extension (str) –

    Extension added to prediction file paths.

  • write_func_kwargs (dict of {str: Any}) –

    Extra kwargs to pass to write_func.

  • tile_cache (list of numpy.ndarray) –

    Tiles cached for stitching prediction.

  • tile_info_cache (list of TileInformation) –

    Cached tile information for stitching prediction.

last_tiles property

List of bools indicating whether each tile in the cache is the last tile of its image.

Returns:

  • list of bool

    Whether each tile in the tile cache is the last tile.
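The last_tiles property can be read as one flag per cached tile, taken from the cached tile information. A minimal sketch, assuming each TileInformation exposes a boolean last_tile attribute; TileInfo and TileCacheSketch are hypothetical stand-ins, not library classes.

```python
from dataclasses import dataclass


@dataclass
class TileInfo:
    """Hypothetical stand-in for TileInformation; only the last_tile flag is modelled."""

    last_tile: bool


class TileCacheSketch:
    """Minimal sketch of the tile-cache bookkeeping (not the library class)."""

    def __init__(self) -> None:
        self.tile_cache: list = []
        self.tile_info_cache: list[TileInfo] = []

    @property
    def last_tiles(self) -> list[bool]:
        # One flag per cached tile: True where that tile completes an image.
        return [info.last_tile for info in self.tile_info_cache]
```

A cache can then check `any(cache.last_tiles)` to decide whether at least one full image is ready to be stitched.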

__init__(write_func, write_extension, write_func_kwargs)

A write strategy that will cache tiles.

Tiles are cached until a whole image has been predicted; the stitched prediction is then saved.

Parameters:

  • write_func (WriteFunc) –

    Function used to save predictions.

  • write_extension (str) –

    Extension added to prediction file paths.

  • write_func_kwargs (dict of {str: Any}) –

    Extra kwargs to pass to write_func.

write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)

Cache tiles until the last tile is predicted; save the stitched prediction.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning Trainer.

  • pl_module (LightningModule) –

    PyTorch Lightning LightningModule.

  • prediction (Any) –

    Predictions on batch.

  • batch_indices (sequence of int) –

    Indices identifying the samples in the batch.

  • batch (Any) –

    Input batch.

  • batch_idx (int) –

    Batch index.

  • dataloader_idx (int) –

    Dataloader index.

  • dirpath (Path) –

    Path to directory to save predictions to.
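The caching behaviour of write_batch can be sketched as: append the batch's tiles to the cache, and whenever the cache contains a last tile, stitch every tile up to it into one image, write it, and keep the remainder for the next image. This is a simplified illustration under stated assumptions: tiles are hypothetical 1D lists placed at a start offset with no overlap, whereas real stitching uses TileInformation to crop overlaps; Tile, stitch, and CacheTilesSketch are hypothetical names.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Tile:
    # Hypothetical 1D tile: values plus where they go in the stitched output.
    values: list[float]
    start: int
    last_tile: bool


def stitch(tiles: list[Tile]) -> list[float]:
    """Place each tile's values at its start position (no overlap handling)."""
    length = max(t.start + len(t.values) for t in tiles)
    out = [0.0] * length
    for t in tiles:
        out[t.start : t.start + len(t.values)] = t.values
    return out


class CacheTilesSketch:
    def __init__(self, write: Callable[[list[float]], None]) -> None:
        self.write = write
        self.cache: list[Tile] = []

    def write_batch(self, batch: list[Tile]) -> None:
        self.cache.extend(batch)
        # A single batch may finish one image and start the next, so flush
        # every complete image currently held in the cache.
        while any(t.last_tile for t in self.cache):
            end = next(i for i, t in enumerate(self.cache) if t.last_tile)
            image_tiles, self.cache = self.cache[: end + 1], self.cache[end + 1 :]
            self.write(stitch(image_tiles))
```

Nothing is written until a last tile arrives; tiles that belong to the next image stay in the cache.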

PredictionWriterCallback

Bases: BasePredictionWriter

A PyTorch Lightning callback to save predictions.

Parameters:

  • write_strategy (WriteStrategy) –

    A strategy for writing predictions.

  • dirpath (Path or str, default: "predictions" ) –

    The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

Attributes:

  • write_strategy (WriteStrategy) –

    A strategy for writing predictions.

  • dirpath (pathlib.Path, default="predictions") –

    The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

  • writing_predictions (bool) –

    Whether writing predictions is turned on or off.
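The rule that a relative dirpath is interpreted against the current working directory can be expressed with pathlib. A minimal sketch; resolve_dirpath is a hypothetical helper, not part of the documented API.

```python
from pathlib import Path
from typing import Union


def resolve_dirpath(dirpath: Union[Path, str] = "predictions") -> Path:
    """Interpret a relative dirpath as relative to the current working directory."""
    path = Path(dirpath)
    if not path.is_absolute():
        path = Path.cwd() / path
    return path
```

Absolute paths pass through unchanged, so the default "predictions" resolves to a "predictions" folder inside whatever directory the script is launched from.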

__init__(write_strategy, dirpath='predictions')

A PyTorch Lightning callback to save predictions.

Parameters:

  • write_strategy (WriteStrategy) –

    A strategy for writing predictions.

  • dirpath (Path or str, default: "predictions" ) –

    The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

from_write_func_params(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None, dirpath='predictions') classmethod

Initialize a PredictionWriterCallback from write function parameters.

This will automatically create a WriteStrategy to be passed to the initialization of PredictionWriterCallback.
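The strategy selection can be pictured as a dispatch on tiled: tiled predictions need a caching strategy (CacheTiles), untiled ones can be written directly (WriteImage). A minimal sketch, assuming write_func and write_extension have already been resolved (for example by select_write_func and select_write_extension); ImageStrategy, TiledStrategy, and create_write_strategy_sketch are hypothetical stand-ins.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Optional


@dataclass
class _StrategyFields:
    write_func: Callable[..., None]
    write_extension: str
    write_func_kwargs: dict[str, Any] = field(default_factory=dict)


class ImageStrategy(_StrategyFields):
    """Hypothetical stand-in for WriteImage (untiled predictions)."""


class TiledStrategy(_StrategyFields):
    """Hypothetical stand-in for CacheTiles (tiled predictions)."""


def create_write_strategy_sketch(
    tiled: bool,
    write_func: Callable[..., None],
    write_extension: str,
    write_func_kwargs: Optional[dict[str, Any]] = None,
) -> _StrategyFields:
    # A missing kwargs dict is treated as "no extra arguments".
    kwargs = write_func_kwargs if write_func_kwargs is not None else {}
    cls = TiledStrategy if tiled else ImageStrategy
    return cls(write_func, write_extension, kwargs)
```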

Parameters:

  • write_type (('tiff', 'custom'), default: "tiff" ) –

    The data type to save as; "custom" is also supported.

  • tiled (bool) –

    Whether the prediction will be tiled or not.

  • write_func (WriteFunc, default: None ) –

    If a known write_type is selected, this argument is ignored. For a custom write_type, a function to save the data must be passed. It must have the signature write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None.

  • write_extension (str, default: None ) –

    If a known write_type is selected, this argument is ignored. For a custom write_type, an extension to save the data with must be passed.

  • write_func_kwargs (dict of {str: Any}, default: None ) –

    Additional keyword arguments to be passed to the save function.

  • dirpath (Path or str, default: "predictions" ) –

    The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

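Returns:

  • PredictionWriterCallback –

    A callback for writing predictions, using the write strategy created from the given parameters.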

setup(trainer, pl_module, stage)

Create the prediction output directory when predict begins.

Called when fit, validate, test, predict, or tune begins.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning trainer.

  • pl_module (LightningModule) –

    PyTorch Lightning module.

  • stage (str) –

    Stage being run, e.g. 'predict', 'fit', or 'validate'.
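Because setup is called for every stage but the directory is only needed for prediction, the creation can be guarded by the stage name. A minimal sketch; setup_sketch is a hypothetical function mirroring the described behaviour, not the callback's actual code.

```python
import tempfile
from pathlib import Path


def setup_sketch(dirpath: Path, stage: str) -> None:
    """Create the output directory only when the predict stage begins."""
    if stage == "predict":
        # parents=True builds missing intermediate directories;
        # exist_ok=True makes repeated predict runs harmless.
        dirpath.mkdir(parents=True, exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "predictions" / "run1"
    setup_sketch(out, "fit")
    created_on_fit = out.exists()
    setup_sketch(out, "predict")
    created_on_predict = out.exists()
```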

write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)

Write predictions at the end of a batch.

The writing method is determined by the write_strategy attribute.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning trainer.

  • pl_module (LightningModule) –

    PyTorch Lightning module.

  • prediction (Any) –

    Prediction outputs of batch.

  • batch_indices (sequence of Any) –

    Indices identifying the samples in the batch.

  • batch (Any) –

    Input batch.

  • batch_idx (int) –

    Batch index.

  • dataloader_idx (int) –

    Dataloader index.
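The callback itself does little writing: it forwards the batch to its write strategy together with dirpath. A minimal sketch of that delegation, assuming writing is skipped when writing_predictions is False; WriterSketch and RecordingStrategy are hypothetical classes that stand in for the Lightning callback and a real strategy.

```python
from pathlib import Path


class RecordingStrategy:
    """Hypothetical WriteStrategy that records calls instead of writing files."""

    def __init__(self) -> None:
        self.calls: list[tuple[int, Path]] = []

    def write_batch(self, trainer, pl_module, prediction, batch_indices,
                    batch, batch_idx, dataloader_idx, dirpath) -> None:
        self.calls.append((batch_idx, dirpath))


class WriterSketch:
    def __init__(self, write_strategy, dirpath: Path) -> None:
        self.write_strategy = write_strategy
        self.dirpath = dirpath
        self.writing_predictions = True

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices,
                           batch, batch_idx, dataloader_idx) -> None:
        if not self.writing_predictions:
            return  # writing switched off: skip this batch
        # Delegate to the strategy, adding the output directory.
        self.write_strategy.write_batch(
            trainer=trainer, pl_module=pl_module, prediction=prediction,
            batch_indices=batch_indices, batch=batch, batch_idx=batch_idx,
            dataloader_idx=dataloader_idx, dirpath=self.dirpath,
        )


strategy = RecordingStrategy()
writer = WriterSketch(strategy, Path("predictions"))
writer.write_on_batch_end(None, None, [0.0], [0], None, 0, 0)
writer.writing_predictions = False
writer.write_on_batch_end(None, None, [0.0], [1], None, 1, 0)
```

Keeping the format-specific logic in the strategy means the callback never needs to know whether predictions are tiled.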

WriteImage

Bases: WriteStrategy

A strategy for writing image predictions (i.e., untiled predictions).

Parameters:

  • write_func (WriteFunc) –

    Function used to save predictions.

  • write_extension (str) –

    Extension added to prediction file paths.

  • write_func_kwargs (dict of {str: Any}) –

    Extra kwargs to pass to write_func.

Attributes:

  • write_func (WriteFunc) –

    Function used to save predictions.

  • write_extension (str) –

    Extension added to prediction file paths.

  • write_func_kwargs (dict of {str: Any}) –

    Extra kwargs to pass to write_func.

__init__(write_func, write_extension, write_func_kwargs)

A strategy for writing image predictions (i.e., untiled predictions).

Parameters:

  • write_func (WriteFunc) –

    Function used to save predictions.

  • write_extension (str) –

    Extension added to prediction file paths.

  • write_func_kwargs (dict of {str: Any}) –

    Extra kwargs to pass to write_func.

write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)

Save full images.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning Trainer.

  • pl_module (LightningModule) –

    PyTorch Lightning LightningModule.

  • prediction (Any) –

    Predictions on batch.

  • batch_indices (sequence of int) –

    Indices identifying the samples in the batch.

  • batch (Any) –

    Input batch.

  • batch_idx (int) –

    Batch index.

  • dataloader_idx (int) –

    Dataloader index.

  • dirpath (Path) –

    Path to directory to save predictions to.

Raises:

  • TypeError

    If the trainer's prediction dataset is not an IterablePredDataset.
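One plausible naming scheme for each saved image: take the source file of the sample, keep its name, place it in dirpath, and swap its suffix for write_extension. This is an assumption for illustration only; output_path is a hypothetical helper, and the library's actual naming may differ.

```python
from pathlib import Path


def output_path(dirpath: Path, source_file: Path, write_extension: str) -> Path:
    """Hypothetical naming rule: keep the input file's name, swap in write_extension."""
    # with_suffix replaces only the last suffix and requires a leading ".".
    return (dirpath / source_file.name).with_suffix(write_extension)
```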

WriteStrategy

Bases: Protocol

Protocol for write strategy classes.

write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)

WriteStrategy implementations must define this method to write a batch.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning Trainer.

  • pl_module (LightningModule) –

    PyTorch Lightning LightningModule.

  • prediction (Any) –

    Predictions on batch.

  • batch_indices (sequence of int) –

    Indices identifying the samples in the batch.

  • batch (Any) –

    Input batch.

  • batch_idx (int) –

    Batch index.

  • dataloader_idx (int) –

    Dataloader index.

  • dirpath (Path) –

    Path to directory to save predictions to.
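Because WriteStrategy is a typing.Protocol, a custom strategy only needs a matching write_batch method; inheriting from WriteStrategy is optional. The sketch below mirrors the documented signature in a local protocol (WriteStrategyLike is a hypothetical name) and marks it runtime_checkable, which lets isinstance check for the presence of write_batch (not its full signature).

```python
from pathlib import Path
from typing import Any, Protocol, Sequence, runtime_checkable


@runtime_checkable
class WriteStrategyLike(Protocol):
    """Local mirror of the documented WriteStrategy protocol."""

    def write_batch(self, trainer: Any, pl_module: Any, prediction: Any,
                    batch_indices: Sequence[int], batch: Any, batch_idx: int,
                    dataloader_idx: int, dirpath: Path) -> None: ...


class PrintBatchInfo:
    """A custom strategy: no inheritance needed, only a matching write_batch."""

    def write_batch(self, trainer, pl_module, prediction, batch_indices,
                    batch, batch_idx, dataloader_idx, dirpath) -> None:
        print(f"batch {batch_idx}: {len(batch_indices)} samples -> {dirpath}")


class NotAStrategy:
    pass
```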

WriteTilesZarr

Bases: WriteStrategy

Strategy to write tiles to a Zarr file.

write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)

Write tiles to a Zarr file.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning Trainer.

  • pl_module (LightningModule) –

    PyTorch Lightning LightningModule.

  • prediction (Any) –

    Predictions on batch.

  • batch_indices (sequence of int) –

    Indices identifying the samples in the batch.

  • batch (Any) –

    Input batch.

  • batch_idx (int) –

    Batch index.

  • dataloader_idx (int) –

    Dataloader index.

  • dirpath (Path) –

    Path to directory to save predictions to.

Raises:

create_write_strategy(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None)

Create a write strategy from convenient parameters.

Parameters:

  • write_type (('tiff', 'custom'), default: "tiff" ) –

    The data type to save as; "custom" is also supported.

  • tiled (bool) –

    Whether the prediction will be tiled or not.

  • write_func (WriteFunc, default: None ) –

    If a known write_type is selected, this argument is ignored. For a custom write_type, a function to save the data must be passed. See the notes below.

  • write_extension (str, default: None ) –

    If a known write_type is selected, this argument is ignored. For a custom write_type, an extension to save the data with must be passed.

  • write_func_kwargs (dict of {str: Any}, default: None ) –

    Additional keyword arguments to be passed to the save function.

Returns:

  • WriteStrategy –

    A strategy for writing predictions.

Notes

The write_func function signature must match that of the example below:

write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...

The write_func_kwargs are passed to write_func as follows:

write_func(file_path=file_path, img=img, **write_func_kwargs)
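A custom write_func only has to accept file_path and img as keyword arguments and absorb any extra keyword arguments. The sketch below uses a hypothetical JSON writer (write_json is not part of the library) so that it runs with the standard library alone; with write_type="custom" it would be paired with a write_extension such as ".json". The entries of write_func_kwargs reach json.dump unchanged.

```python
import json
import tempfile
from pathlib import Path
from typing import Any


def write_json(file_path: Path, img: Any, *args: Any, **kwargs: Any) -> None:
    """Hypothetical custom write_func matching the required signature.

    Extra keyword arguments (from write_func_kwargs) are forwarded to json.dump.
    """
    data = img.tolist() if hasattr(img, "tolist") else img  # NDArray -> nested list
    with open(file_path, "w") as f:
        json.dump(data, f, **kwargs)


write_func_kwargs = {"indent": 2}
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "sample_0.json"
    # Invocation pattern described above: keyword arguments plus unpacked kwargs.
    write_json(file_path=target, img=[[0.5, 1.0], [1.5, 2.0]], **write_func_kwargs)
    round_trip = json.loads(target.read_text())
```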

get_sample_file_path(dataset, sample_id)

Get the file path for a particular sample.

Parameters:

  • dataset –

    The dataset the sample belongs to.

  • sample_id –

    The ID of the sample.

Returns:

  • Path

    The file path corresponding to the sample with the ID sample_id.
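A simple way to map a sample to its file is to index a list of source files with the sample ID. A minimal sketch, assuming the dataset exposes such a list under a data_files attribute; DatasetStub and get_sample_file_path_sketch are hypothetical, and the real dataset classes may store file paths differently.

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass
class DatasetStub:
    """Hypothetical dataset exposing one source file per sample."""

    data_files: list[Path]


def get_sample_file_path_sketch(dataset: DatasetStub, sample_id: int) -> Path:
    # Assumes sample IDs index the dataset's list of source files.
    return dataset.data_files[sample_id]
```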

select_write_extension(write_type, write_extension=None)

Return an extension to add to file paths.

If write_type is "custom", write_extension is returned; otherwise, the known extension for write_type is selected.

Parameters:

  • write_type (('tiff', 'custom'), default: "tiff" ) –

    The data type to save as; "custom" is also supported.

  • write_extension (str, default: None ) –

    If a known write_type is selected, this argument is ignored. For a custom write_type, an extension to save the data with must be passed.

Returns:

  • str

    The extension to be added to file paths.

Raises:

  • ValueError

    If write_type="custom" but write_extension has not been given.
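The selection logic reduces to a guarded lookup. A minimal sketch: the ".tiff" value in KNOWN_EXTENSIONS is an assumption for illustration (the library derives the extension from its own supported-data definitions), and select_write_extension_sketch is a hypothetical name.

```python
from typing import Literal, Optional

# Assumed mapping for the built-in type; the library's actual extension may differ.
KNOWN_EXTENSIONS = {"tiff": ".tiff"}


def select_write_extension_sketch(
    write_type: Literal["tiff", "custom"], write_extension: Optional[str] = None
) -> str:
    if write_type == "custom":
        if write_extension is None:
            raise ValueError('write_type="custom" requires write_extension.')
        return write_extension
    # Known type: any user-supplied write_extension is ignored.
    return KNOWN_EXTENSIONS[write_type]
```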

select_write_func(write_type, write_func=None)

Return a function to write images.

If write_type is "custom", write_func is returned; otherwise, the known write function for write_type is selected.

Parameters:

  • write_type (('tiff', 'custom'), default: "tiff" ) –

    The data type to save as; "custom" is also supported.

  • write_func (WriteFunc, default: None ) –

    If a known write_type is selected, this argument is ignored. For a custom write_type, a function to save the data must be passed. See the notes below.

Returns:

  • WriteFunc

    A function for writing images.

Raises:

  • ValueError

    If write_type="custom" but write_func has not been given.

Notes

The write_func function signature must match that of the example below:

write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
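Function selection follows the same pattern as extension selection: return the user's function for "custom", otherwise look up the built-in writer. A minimal sketch; write_tiff_placeholder stands in for the library's built-in TIFF writer and select_write_func_sketch is a hypothetical name.

```python
from pathlib import Path
from typing import Any, Callable, Optional

WriteFuncLike = Callable[..., None]


def write_tiff_placeholder(file_path: Path, img: Any, *args: Any, **kwargs: Any) -> None:
    """Hypothetical stand-in for the built-in TIFF writer."""
    raise NotImplementedError("placeholder for illustration only")


KNOWN_WRITERS: dict[str, WriteFuncLike] = {"tiff": write_tiff_placeholder}


def select_write_func_sketch(
    write_type: str, write_func: Optional[WriteFuncLike] = None
) -> WriteFuncLike:
    if write_type == "custom":
        if write_func is None:
            raise ValueError('write_type="custom" requires write_func.')
        return write_func
    # Known type: any user-supplied write_func is ignored.
    return KNOWN_WRITERS[write_type]
```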