Prediction Writer Callback

A package for the PredictionWriterCallback class and utilities.

CacheTiles

Bases: WriteStrategy

A write strategy that will cache tiles.

Tiles are cached until every tile of an image has been predicted; the stitched prediction is then saved.
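The caching idea can be illustrated with a minimal sketch. It is not the library's implementation: it assumes 1D Python lists instead of NumPy arrays, and a hypothetical TileInfo record (crop, stitch and last_tile fields) in place of TileInformation. Each cached tile has its overlap cropped away, and the kept part is placed at its position in the full image once the last tile arrives.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, List


@dataclass
class TileInfo:
    """Hypothetical 1D tile metadata; not the library's TileInformation."""

    crop: slice  # part of the tile to keep (overlap removed)
    stitch: slice  # where the kept part goes in the full image
    last_tile: bool  # True for the final tile of an image


@dataclass
class TileCacheSketch:
    """Caches tiles and writes the stitched image when the last tile arrives."""

    write_func: Callable[..., None]
    write_extension: str = ".txt"
    tile_cache: List[List[Any]] = field(default_factory=list)
    tile_info_cache: List[TileInfo] = field(default_factory=list)
    n_written: int = 0

    def add_tile(self, tile: List[Any], info: TileInfo, dirpath: Path) -> None:
        self.tile_cache.append(tile)
        self.tile_info_cache.append(info)
        if info.last_tile:
            self._stitch_and_write(dirpath)

    def _stitch_and_write(self, dirpath: Path) -> None:
        # The full image length is the furthest stitch position of any cached tile.
        length = max(info.stitch.stop for info in self.tile_info_cache)
        image: List[Any] = [None] * length
        for tile, info in zip(self.tile_cache, self.tile_info_cache):
            image[info.stitch] = tile[info.crop]
        file_path = dirpath / f"image_{self.n_written}{self.write_extension}"
        self.write_func(file_path=file_path, img=image)
        self.n_written += 1
        # Empty the cache, ready for the next image.
        self.tile_cache.clear()
        self.tile_info_cache.clear()
```

For example, the image [0, 1, 2, 3, 4, 5] split into the overlapping tiles [0, 1, 2, 3] and [2, 3, 4, 5] is rebuilt exactly when the second tile, flagged last_tile=True, is added.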

Parameters:

write_func (WriteFunc, required): Function used to save predictions.

write_extension (str, required): Extension added to prediction file paths.

write_func_kwargs (dict of {str: Any}, required): Extra keyword arguments to pass to write_func.

Attributes:

write_func (WriteFunc): Function used to save predictions.

write_extension (str): Extension added to prediction file paths.

write_func_kwargs (dict of {str: Any}): Extra keyword arguments to pass to write_func.

tile_cache (list of numpy.ndarray): Tiles cached for stitching the prediction.

tile_info_cache (list of TileInformation): Cached tile information for stitching the prediction.

last_tiles property

List of bools indicating whether each tile in the cache is the last tile.

Returns:

list of bool: Whether each tile in the tile cache is the last tile.
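A batch can contain tiles from more than one image, so only the tiles up to and including the first last tile belong to the finished image; the rest must stay cached. The sketch below shows one way to make that split from the last_tiles flags. It is an illustration under assumptions, not the library's code: TileInfo is a hypothetical stand-in that carries only a last_tile flag.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class TileInfo:
    """Hypothetical stand-in carrying only the last_tile flag."""

    last_tile: bool


def split_completed_image(
    tiles: List[Any], infos: List[TileInfo]
) -> Optional[Tuple[Tuple[List[Any], List[TileInfo]], Tuple[List[Any], List[TileInfo]]]]:
    """Split the cache at the first last tile.

    Returns ((image tiles, image infos), (remaining tiles, remaining infos)),
    or None if no image in the cache is complete yet.
    """
    last_tiles = [info.last_tile for info in infos]  # same idea as the last_tiles property
    if not any(last_tiles):
        return None
    end = last_tiles.index(True) + 1
    return (tiles[:end], infos[:end]), (tiles[end:], infos[end:])
```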

write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)

Cache tiles until the last tile is predicted; save the stitched prediction.

Parameters:

trainer (Trainer, required): PyTorch Lightning Trainer.

pl_module (LightningModule, required): PyTorch Lightning LightningModule.

prediction (Any, required): Predictions on the batch.

batch_indices (sequence of int, required): Indices identifying the samples in the batch.

batch (Any, required): Input batch.

batch_idx (int, required): Batch index.

dataloader_idx (int, required): Dataloader index.

dirpath (Path, required): Path to the directory to save predictions to.

PredictionWriterCallback

Bases: BasePredictionWriter

A PyTorch Lightning callback to save predictions.

Parameters:

write_strategy (WriteStrategy, required): A strategy for writing predictions.

dirpath (Path or str, default "predictions"): The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

Attributes:

write_strategy (WriteStrategy): A strategy for writing predictions.

dirpath (pathlib.Path, default "predictions"): The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

writing_predictions (bool): Whether writing predictions is turned on.
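The relative-path rule for dirpath amounts to anchoring it at the current working directory. A minimal sketch with pathlib; resolve_dirpath is a hypothetical helper, not part of the documented API:

```python
from pathlib import Path
from typing import Union


def resolve_dirpath(dirpath: Union[Path, str] = "predictions") -> Path:
    """Return dirpath as an absolute path, anchoring relative paths at the CWD."""
    path = Path(dirpath)
    if not path.is_absolute():
        path = Path.cwd() / path
    return path
```

Absolute paths pass through unchanged, so the same call works for both cases.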

from_write_func_params(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None, dirpath='predictions') classmethod

Initialize a PredictionWriterCallback from write function parameters.

This automatically creates a WriteStrategy, which is then passed to the PredictionWriterCallback initializer.

Parameters:

write_type ("tiff" or "custom", required): The file type to save predictions as; use "custom" to supply your own write function and extension.

tiled (bool, required): Whether the prediction will be tiled.

write_func (WriteFunc, default None): Ignored if a known write_type is selected. For a custom write_type, a function to save the data must be passed; see the notes under create_write_strategy.

write_extension (str, default None): Ignored if a known write_type is selected. For a custom write_type, the extension to save the data with must be passed.

write_func_kwargs (dict of {str: Any}, default None): Additional keyword arguments to be passed to the save function.

dirpath (Path or str, default "predictions"): The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

Returns:

PredictionWriterCallback: Callback for writing predictions.

setup(trainer, pl_module, stage)

Create the prediction output directory when predict begins.

Called when fit, validate, test, predict, or tune begins.

Parameters:

trainer (Trainer, required): PyTorch Lightning trainer.

pl_module (LightningModule, required): PyTorch Lightning module.

stage (str, required): Current stage, e.g. "predict", "fit" or "validate".
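The setup step described above can be sketched as follows. This is an illustration under assumptions, not the library's code: SetupSketch is hypothetical, and it creates the directory (including missing parents) only when stage is "predict", leaving the other stages untouched.

```python
from pathlib import Path
from typing import Any, Union


class SetupSketch:
    """Hypothetical callback fragment: creates the output directory on predict."""

    def __init__(self, dirpath: Union[Path, str] = "predictions") -> None:
        self.dirpath = Path(dirpath)

    def setup(self, trainer: Any, pl_module: Any, stage: str) -> None:
        # Only prediction needs an output directory; fit/validate/test are ignored.
        if stage == "predict":
            # parents=True creates intermediate folders; exist_ok=True makes reruns safe.
            self.dirpath.mkdir(parents=True, exist_ok=True)
```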

write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)

Write predictions at the end of a batch.

How predictions are written is determined by the write_strategy attribute.

Parameters:

trainer (Trainer, required): PyTorch Lightning trainer.

pl_module (LightningModule, required): PyTorch Lightning module.

prediction (Any, required): Prediction outputs of the batch.

batch_indices (sequence of Any, required): Batch indices.

batch (Any, required): Input batch.

batch_idx (int, required): Batch index.

dataloader_idx (int, required): Dataloader index.
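The dispatch described above reduces to forwarding every argument, plus the callback's dirpath, to the strategy's write_batch. The sketch below is hypothetical (WriterCallbackSketch is not a library class) and assumes that the writing_predictions attribute acts as an on/off switch, as its description suggests.

```python
from pathlib import Path
from typing import Any, Optional, Sequence, Union


class WriterCallbackSketch:
    """Hypothetical stand-in for the callback's batch-end dispatch."""

    def __init__(self, write_strategy: Any, dirpath: Union[Path, str] = "predictions") -> None:
        self.write_strategy = write_strategy
        self.dirpath = Path(dirpath)
        self.writing_predictions = True  # assumed on/off switch

    def write_on_batch_end(
        self,
        trainer: Any,
        pl_module: Any,
        prediction: Any,
        batch_indices: Optional[Sequence[Any]],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        if not self.writing_predictions:
            return  # writing is switched off, so the batch is skipped
        # The strategy receives everything Lightning passed, plus the output directory.
        self.write_strategy.write_batch(
            trainer=trainer,
            pl_module=pl_module,
            prediction=prediction,
            batch_indices=batch_indices,
            batch=batch,
            batch_idx=batch_idx,
            dataloader_idx=dataloader_idx,
            dirpath=self.dirpath,
        )
```

Keeping the callback this thin means the file format and tiling logic live entirely in the strategy object.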

WriteImage

Bases: WriteStrategy

A strategy for writing image predictions (i.e. un-tiled predictions).

Parameters:

write_func (WriteFunc, required): Function used to save predictions.

write_extension (str, required): Extension added to prediction file paths.

write_func_kwargs (dict of {str: Any}, required): Extra keyword arguments to pass to write_func.

Attributes:

write_func (WriteFunc): Function used to save predictions.

write_extension (str): Extension added to prediction file paths.

write_func_kwargs (dict of {str: Any}): Extra keyword arguments to pass to write_func.

write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)

Save full images.

Parameters:

trainer (Trainer, required): PyTorch Lightning Trainer.

pl_module (LightningModule, required): PyTorch Lightning LightningModule.

prediction (Any, required): Predictions on the batch.

batch_indices (sequence of int, required): Indices identifying the samples in the batch.

batch (Any, required): Input batch.

batch_idx (int, required): Batch index.

dataloader_idx (int, required): Dataloader index.

dirpath (Path, required): Path to the directory to save predictions to.

Raises:

TypeError: If the trainer's prediction dataset is not an IterablePredDataset.
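A minimal sketch of writing untiled predictions, one file per image. It assumes each prediction can be matched to the path of its source file and that write_extension includes the leading dot; the naming scheme (source stem plus extension) is an illustrative assumption, and write_images is a hypothetical helper rather than the library's WriteImage.write_batch.

```python
from pathlib import Path
from typing import Any, Callable, Dict, List, Sequence


def write_images(
    predictions: Sequence[Any],
    source_paths: Sequence[Path],
    dirpath: Path,
    write_func: Callable[..., None],
    write_extension: str,
    write_func_kwargs: Dict[str, Any],
) -> List[Path]:
    """Write each prediction to dirpath, named after its source file."""
    if len(predictions) != len(source_paths):
        raise ValueError("Each prediction needs exactly one source path.")
    written: List[Path] = []
    for img, source in zip(predictions, source_paths):
        # e.g. data/cells.tif -> <dirpath>/cells<write_extension>
        file_path = dirpath / f"{Path(source).stem}{write_extension}"
        write_func(file_path=file_path, img=img, **write_func_kwargs)
        written.append(file_path)
    return written
```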

WriteStrategy

Bases: Protocol

Protocol for write strategy classes.

write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)

Classes implementing the WriteStrategy protocol must define this method to write a batch.

Parameters:

trainer (Trainer, required): PyTorch Lightning Trainer.

pl_module (LightningModule, required): PyTorch Lightning LightningModule.

prediction (Any, required): Predictions on the batch.

batch_indices (sequence of int, required): Indices identifying the samples in the batch.

batch (Any, required): Input batch.

batch_idx (int, required): Batch index.

dataloader_idx (int, required): Dataloader index.

dirpath (Path, required): Path to the directory to save predictions to.
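Because WriteStrategy is a Protocol, a custom strategy needs no base class: any object with a matching write_batch method fits. The sketch below defines its own runtime-checkable copy of the protocol (WriteStrategySketch, hypothetical) so that it runs without the library, plus a toy strategy that only logs what it receives.

```python
from pathlib import Path
from typing import Any, List, Optional, Protocol, Sequence, Tuple, runtime_checkable


@runtime_checkable
class WriteStrategySketch(Protocol):
    """Hypothetical local copy of the write_batch protocol."""

    def write_batch(
        self,
        trainer: Any,
        pl_module: Any,
        prediction: Any,
        batch_indices: Optional[Sequence[int]],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None: ...


class LogBatches:
    """Toy strategy: records batch metadata instead of writing files."""

    def __init__(self) -> None:
        self.log: List[Tuple[int, int, Path]] = []

    def write_batch(
        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath
    ) -> None:
        self.log.append((batch_idx, len(prediction), dirpath))
```

Note that isinstance checks against a runtime-checkable Protocol only confirm that the method exists, not that its signature matches.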

WriteTilesZarr

Bases: WriteStrategy

Strategy to write tiles to a Zarr file.

write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)

Write tiles to a Zarr file.

Parameters:

trainer (Trainer, required): PyTorch Lightning Trainer.

pl_module (LightningModule, required): PyTorch Lightning LightningModule.

prediction (Any, required): Predictions on the batch.

batch_indices (sequence of int, required): Indices identifying the samples in the batch.

batch (Any, required): Input batch.

batch_idx (int, required): Batch index.

dataloader_idx (int, required): Dataloader index.

dirpath (Path, required): Path to the directory to save predictions to.

Raises:

NotImplementedError

create_write_strategy(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None)

Create a write strategy from convenient parameters.

Parameters:

write_type ("tiff" or "custom", required): The file type to save predictions as; use "custom" to supply your own write function and extension.

tiled (bool, required): Whether the prediction will be tiled.

write_func (WriteFunc, default None): Ignored if a known write_type is selected. For a custom write_type, a function to save the data must be passed. See the notes below.

write_extension (str, default None): Ignored if a known write_type is selected. For a custom write_type, the extension to save the data with must be passed.

write_func_kwargs (dict of {str: Any}, default None): Additional keyword arguments to be passed to the save function.

Returns:

WriteStrategy: A strategy for writing predictions.
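A sketch of the selection logic, under stated assumptions: UntiledStrategy and TiledStrategy are hypothetical stand-ins for WriteImage and CacheTiles, mapping tiled=True to the tile-caching strategy is inferred from the class descriptions in this module, and the "tiff" entry uses a placeholder writer with an assumed ".tiff" extension.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple

WriteFunc = Callable[..., None]


def _placeholder_tiff_writer(file_path: Path, img: Any, *args: Any, **kwargs: Any) -> None:
    """Hypothetical stand-in for a built-in TIFF writer."""
    raise NotImplementedError("sketch only: no TIFF library is used here")


# Hypothetical table: known write_type -> (write function, extension).
KNOWN_TYPES: Dict[str, Tuple[WriteFunc, str]] = {"tiff": (_placeholder_tiff_writer, ".tiff")}


@dataclass
class UntiledStrategy:  # stand-in for WriteImage
    write_func: WriteFunc
    write_extension: str
    write_func_kwargs: Dict[str, Any] = field(default_factory=dict)


@dataclass
class TiledStrategy(UntiledStrategy):  # stand-in for CacheTiles
    tile_cache: list = field(default_factory=list)


def create_strategy(
    write_type: str,
    tiled: bool,
    write_func: Optional[WriteFunc] = None,
    write_extension: Optional[str] = None,
    write_func_kwargs: Optional[Dict[str, Any]] = None,
) -> UntiledStrategy:
    if write_type == "custom":
        if write_func is None or write_extension is None:
            raise ValueError('write_type="custom" requires write_func and write_extension.')
    else:
        # Known types ignore any user-supplied function or extension.
        write_func, write_extension = KNOWN_TYPES[write_type]
    strategy_cls = TiledStrategy if tiled else UntiledStrategy
    return strategy_cls(
        write_func=write_func,
        write_extension=write_extension,
        write_func_kwargs=write_func_kwargs or {},
    )
```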

Notes

The write_func signature must match that of the example below:

write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...

The write_func_kwargs are passed to write_func as follows:

write_func(file_path=file_path, img=img, **write_func_kwargs)
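As an illustration, here is a hypothetical custom write_func that follows the required signature and saves a 2D array-like as delimited text using only the standard library. The delimiter keyword is an example option that would be supplied through write_func_kwargs; the function could be used with write_type="custom" and a write_extension such as ".csv".

```python
from pathlib import Path
from typing import Any, Sequence


def write_delimited(
    file_path: Path, img: Sequence[Sequence[Any]], *args: Any, delimiter: str = ",", **kwargs: Any
) -> None:
    """Hypothetical custom write_func: save a 2D array-like as delimited text."""
    rows = [delimiter.join(str(value) for value in row) for row in img]
    Path(file_path).write_text("\n".join(rows) + "\n")


# Extra options travel through write_func_kwargs and arrive as keyword arguments.
write_func_kwargs = {"delimiter": ";"}
```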

select_write_extension(write_type, write_extension=None)

Return an extension to add to file paths.

If write_type is "custom", write_extension is returned; otherwise the known extension for write_type is selected.

Parameters:

write_type ("tiff" or "custom", required): The file type to save predictions as; use "custom" to supply your own extension.

write_extension (str, default None): Ignored if a known write_type is selected. For a custom write_type, the extension to save the data with must be passed.

Returns:

str: The extension to be added to file paths.

Raises:

ValueError: If write_type="custom" but write_extension has not been given.
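The selection rule can be sketched with a lookup table. The ".tiff" value below is an assumption for illustration (the extension the library actually uses for "tiff" may differ), and select_extension is a hypothetical name.

```python
from typing import Dict, Optional

# Assumed mapping for illustration; the library's actual extension may differ.
KNOWN_EXTENSIONS: Dict[str, str] = {"tiff": ".tiff"}


def select_extension(write_type: str, write_extension: Optional[str] = None) -> str:
    """Hypothetical sketch of select_write_extension."""
    if write_type == "custom":
        if write_extension is None:
            raise ValueError('write_extension must be given when write_type="custom".')
        return write_extension
    # For known types, a user-supplied extension is ignored.
    return KNOWN_EXTENSIONS[write_type]
```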

select_write_func(write_type, write_func=None)

Return a function to write images.

If write_type is "custom", write_func is returned; otherwise the known write function for write_type is selected.

Parameters:

write_type ("tiff" or "custom", required): The file type to save predictions as; use "custom" to supply your own write function.

write_func (WriteFunc, default None): Ignored if a known write_type is selected. For a custom write_type, a function to save the data must be passed. See the notes below.

Returns:

WriteFunc: A function for writing images.

Raises:

ValueError: If write_type="custom" but write_func has not been given.

Notes

The write_func signature must match that of the example below:

write_func(file_path: Path, img: NDArray, *args, **kwargs) -> None: ...
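Before passing a custom function, it can be worth checking that it accepts the call described in these notes, that is file_path and img as keywords plus any write_func_kwargs. One way to do that with the standard inspect module; accepts_write_call is a hypothetical helper, not part of the library:

```python
import inspect
from typing import Any, Callable, Dict, Optional


def accepts_write_call(func: Callable[..., Any], write_func_kwargs: Optional[Dict[str, Any]] = None) -> bool:
    """Return True if func(file_path=..., img=..., **write_func_kwargs) would bind."""
    kwargs = write_func_kwargs or {}
    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):  # some builtins expose no signature
        return False
    try:
        # Placeholder values: only the parameter names matter for binding.
        signature.bind(file_path=None, img=None, **kwargs)
    except TypeError:
        return False
    return True
```

Binding only checks parameter names and required arguments; it neither calls the function nor validates types.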