Prediction Writer Callback
A package for the PredictionWriterCallback class and utilities.
CacheTiles
Bases: WriteStrategy
A write strategy that will cache tiles.
Tiles are cached until a whole image is predicted on. Then the stitched prediction is saved.
Parameters:
- write_func (WriteFunc) – Function used to save predictions.
- write_extension (str) – Extension added to prediction file paths.
- write_func_kwargs (dict of {str: Any}) – Extra kwargs to pass to write_func.
Attributes:
- write_func (WriteFunc) – Function used to save predictions.
- write_extension (str) – Extension added to prediction file paths.
- write_func_kwargs (dict of {str: Any}) – Extra kwargs to pass to write_func.
- tile_cache (list of numpy.ndarray) – Tiles cached for stitching prediction.
- tile_info_cache (list of TileInformation) – Cached tile information for stitching prediction.
last_tiles
property
List of bool indicating whether each tile in the cache is the last tile.
Returns:
- list of bool – Whether each tile in the tile cache is the last tile.
__init__(write_func, write_extension, write_func_kwargs)
A write strategy that will cache tiles.
Tiles are cached until a whole image is predicted on. Then the stitched prediction is saved.
write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)
Cache tiles until the last tile is predicted; save the stitched prediction.
Parameters:
- trainer (Trainer) – PyTorch Lightning Trainer.
- pl_module (LightningModule) – PyTorch Lightning LightningModule.
- prediction (Any) – Predictions on batch.
- batch_indices (sequence of int) – Indices identifying the samples in the batch.
- batch (Any) – Input batch.
- batch_idx (int) – Batch index.
- dataloader_idx (int) – Dataloader index.
- dirpath (Path) – Path to directory to save predictions to.
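The caching behaviour described for CacheTiles can be illustrated with a minimal, self-contained sketch. It does not use the library's classes: Tile and TileCacheSketch are hypothetical stand-ins, tiles are plain lists rather than numpy arrays, and "stitching" is reduced to concatenation. It only shows the idea of holding tiles until one is flagged as the last tile of an image, then writing the stitched result.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Tile:
    """Hypothetical stand-in for a predicted tile plus its tile information."""

    values: List[int]
    last_tile: bool  # True when this tile completes the current image


@dataclass
class TileCacheSketch:
    """Caches tiles and writes one stitched result per completed image."""

    write_func: Callable[[List[int]], None]
    tile_cache: List[Tile] = field(default_factory=list)

    @property
    def last_tiles(self) -> List[bool]:
        """Whether each cached tile is the last tile of its image."""
        return [tile.last_tile for tile in self.tile_cache]

    def write_batch(self, tiles: List[Tile]) -> None:
        """Add a batch of tiles; flush every image that is now complete."""
        self.tile_cache.extend(tiles)
        while any(self.last_tiles):
            end = self.last_tiles.index(True) + 1
            image_tiles = self.tile_cache[:end]
            self.tile_cache = self.tile_cache[end:]
            # Toy "stitching": concatenate tile values in order.
            stitched = [v for tile in image_tiles for v in tile.values]
            self.write_func(stitched)


written: List[List[int]] = []
cache = TileCacheSketch(write_func=written.append)
cache.write_batch([Tile([1, 2], False), Tile([3, 4], False)])  # nothing written yet
cache.write_batch([Tile([5, 6], True), Tile([7, 8], False)])   # first image flushed
```

After the second call, the tile belonging to the next image stays in the cache until its own last tile arrives.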
PredictionWriterCallback
Bases: BasePredictionWriter
A PyTorch Lightning callback to save predictions.
Parameters:
- write_strategy (WriteStrategy) – A strategy for writing predictions.
- dirpath (Path or str, default: "predictions") – The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.
Attributes:
- write_strategy (WriteStrategy) – A strategy for writing predictions.
- dirpath (pathlib.Path, default: "predictions") – The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.
- writing_predictions (bool) – Whether writing predictions is turned on or off.
__init__(write_strategy, dirpath='predictions')
A PyTorch Lightning callback to save predictions.
Parameters:
- write_strategy (WriteStrategy) – A strategy for writing predictions.
- dirpath (Path or str, default: "predictions") – The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.
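The dirpath rule can be sketched with pathlib alone. The helper name resolve_dirpath is hypothetical and is not part of the callback; it only illustrates how a relative path would be interpreted against the current working directory.

```python
from pathlib import Path
from typing import Union


def resolve_dirpath(dirpath: Union[Path, str] = "predictions") -> Path:
    """Return an absolute path, treating relative paths as relative to the CWD."""
    dirpath = Path(dirpath)
    if not dirpath.is_absolute():
        dirpath = Path.cwd() / dirpath
    return dirpath


relative = resolve_dirpath("predictions")        # becomes <cwd>/predictions
absolute = resolve_dirpath(Path.cwd() / "outputs")  # already absolute, unchanged
```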
from_write_func_params(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None, dirpath='predictions')
classmethod
Initialize a PredictionWriterCallback from write function parameters.
This will automatically create a WriteStrategy to be passed to the
initialization of PredictionWriterCallback.
Parameters:
- write_type (('tiff', 'custom'), default: "tiff") – The data type to save as, includes custom.
- tiled (bool) – Whether the prediction will be tiled or not.
- write_func (WriteFunc, default: None) – If a known write_type is selected this argument is ignored. For a custom write_type a function to save the data must be passed. See notes below.
- write_extension (str, default: None) – If a known write_type is selected this argument is ignored. For a custom write_type an extension to save the data with must be passed.
- write_func_kwargs (dict of {str: Any}, default: None) – Additional keyword arguments to be passed to the save function.
- dirpath (Path or str, default: "predictions") – The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.
Returns:
- PredictionWriterCallback – Callback for writing predictions.
setup(trainer, pl_module, stage)
Create the prediction output directory when predict begins.
Called when fit, validate, test, predict, or tune begins.
Parameters:
- trainer (Trainer) – PyTorch Lightning trainer.
- pl_module (LightningModule) – PyTorch Lightning module.
- stage (str) – Stage of training e.g. 'predict', 'fit', 'validate'.
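Because setup is called for every stage but the directory is only needed for prediction, the behaviour can be pictured as a guard on the stage argument. The function setup_sketch below is a hypothetical illustration, not the callback's actual code; it assumes the directory is created with parents and without failing if it already exists.

```python
import tempfile
from pathlib import Path


def setup_sketch(dirpath: Path, stage: str) -> None:
    """Create the output directory only when the stage is 'predict'."""
    if stage == "predict":
        dirpath.mkdir(parents=True, exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    out_dir = Path(tmp) / "predictions"
    setup_sketch(out_dir, stage="fit")
    exists_after_fit = out_dir.exists()      # other stages leave the disk untouched
    setup_sketch(out_dir, stage="predict")
    exists_after_predict = out_dir.exists()  # directory now exists
```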
write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)
Write predictions at the end of a batch.
The method of writing is determined by the attribute write_strategy.
WriteImage
Bases: WriteStrategy
A strategy for writing image predictions (i.e. un-tiled predictions).
Parameters:
- write_func (WriteFunc) – Function used to save predictions.
- write_extension (str) – Extension added to prediction file paths.
- write_func_kwargs (dict of {str: Any}) – Extra kwargs to pass to write_func.
Attributes:
- write_func (WriteFunc) – Function used to save predictions.
- write_extension (str) – Extension added to prediction file paths.
- write_func_kwargs (dict of {str: Any}) – Extra kwargs to pass to write_func.
__init__(write_func, write_extension, write_func_kwargs)
write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)
Save full images.
Parameters:
- trainer (Trainer) – PyTorch Lightning Trainer.
- pl_module (LightningModule) – PyTorch Lightning LightningModule.
- prediction (Any) – Predictions on batch.
- batch_indices (sequence of int) – Indices identifying the samples in the batch.
- batch (Any) – Input batch.
- batch_idx (int) – Batch index.
- dataloader_idx (int) – Dataloader index.
- dirpath (Path) – Path to directory to save predictions to.
Raises:
- TypeError – If the trainer's prediction dataset is not an IterablePredDataset.
WriteStrategy
Bases: Protocol
Protocol for write strategy classes.
write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)
WriteStrategy subclasses must contain this function to write a batch.
Parameters:
- trainer (Trainer) – PyTorch Lightning Trainer.
- pl_module (LightningModule) – PyTorch Lightning LightningModule.
- prediction (Any) – Predictions on batch.
- batch_indices (sequence of int) – Indices identifying the samples in the batch.
- batch (Any) – Input batch.
- batch_idx (int) – Batch index.
- dataloader_idx (int) – Dataloader index.
- dirpath (Path) – Path to directory to save predictions to.
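Since WriteStrategy is a Protocol, any class with a compatible write_batch method can act as a strategy without inheriting from it. The sketch below redefines a look-alike protocol (WriteStrategySketch) using only the standard library, so it is independent of the package; RecordingStrategy and the file naming it uses are hypothetical and exist only to show the structural typing.

```python
from pathlib import Path
from typing import Any, List, Protocol, Sequence, runtime_checkable


@runtime_checkable
class WriteStrategySketch(Protocol):
    """Look-alike of the write strategy protocol, for illustration only."""

    def write_batch(
        self,
        trainer: Any,
        pl_module: Any,
        prediction: Any,
        batch_indices: Sequence[int],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None: ...


class RecordingStrategy:
    """Hypothetical strategy that records target paths instead of writing files."""

    def __init__(self) -> None:
        self.recorded: List[Path] = []

    def write_batch(
        self, trainer, pl_module, prediction, batch_indices,
        batch, batch_idx, dataloader_idx, dirpath,
    ) -> None:
        for index in batch_indices:
            self.recorded.append(dirpath / f"sample_{index}.txt")


strategy = RecordingStrategy()
strategy.write_batch(
    None, None, prediction=[0.1, 0.2], batch_indices=[0, 1],
    batch=None, batch_idx=0, dataloader_idx=0, dirpath=Path("predictions"),
)
```

Note that RecordingStrategy never subclasses the protocol; matching the method name is enough for the runtime isinstance check.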
WriteTilesZarr
Bases: WriteStrategy
Strategy to write tiles to Zarr file.
write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)
Write tiles to zarr file.
Parameters:
- trainer (Trainer) – PyTorch Lightning Trainer.
- pl_module (LightningModule) – PyTorch Lightning LightningModule.
- prediction (Any) – Predictions on batch.
- batch_indices (sequence of int) – Indices identifying the samples in the batch.
- batch (Any) – Input batch.
- batch_idx (int) – Batch index.
- dataloader_idx (int) – Dataloader index.
- dirpath (Path) – Path to directory to save predictions to.
create_write_strategy(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None)
Create a write strategy from convenient parameters.
Parameters:
- write_type (('tiff', 'custom'), default: "tiff") – The data type to save as, includes custom.
- tiled (bool) – Whether the prediction will be tiled or not.
- write_func (WriteFunc, default: None) – If a known write_type is selected this argument is ignored. For a custom write_type a function to save the data must be passed. See notes below.
- write_extension (str, default: None) – If a known write_type is selected this argument is ignored. For a custom write_type an extension to save the data with must be passed.
- write_func_kwargs (dict of {str: Any}, default: None) – Additional keyword arguments to be passed to the save function.
Returns:
- WriteStrategy – A strategy for writing predictions.
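A rough sketch of how such a factory might choose a strategy is given below. It is an assumption, not the library's implementation: it supposes that tiled predictions map to a tile-caching strategy and un-tiled predictions to an image-writing strategy, and it omits the write_type handling. CacheTilesSketch, WriteImageSketch, create_write_strategy_sketch and noop_writer are hypothetical names.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Union


@dataclass
class CacheTilesSketch:
    """Stand-in for a strategy that caches tiles before writing."""

    write_func: Callable[..., None]
    write_extension: str
    write_func_kwargs: Dict[str, Any]


@dataclass
class WriteImageSketch:
    """Stand-in for a strategy that writes whole (un-tiled) images."""

    write_func: Callable[..., None]
    write_extension: str
    write_func_kwargs: Dict[str, Any]


def create_write_strategy_sketch(
    tiled: bool,
    write_func: Callable[..., None],
    write_extension: str,
    write_func_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[CacheTilesSketch, WriteImageSketch]:
    """Pick a strategy class from `tiled` (assumed mapping) and build it."""
    kwargs = {} if write_func_kwargs is None else dict(write_func_kwargs)
    strategy_cls = CacheTilesSketch if tiled else WriteImageSketch
    return strategy_cls(
        write_func=write_func,
        write_extension=write_extension,
        write_func_kwargs=kwargs,
    )


def noop_writer(file_path: Any, img: Any, **kwargs: Any) -> None:
    """Placeholder write function that does nothing."""


tiled_strategy = create_write_strategy_sketch(True, noop_writer, ".tiff")
image_strategy = create_write_strategy_sketch(
    False, noop_writer, ".tiff", {"compression": "zlib"}
)
```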
get_sample_file_path(dataset, sample_id)
Get the file path for a particular sample.
Parameters:
- dataset (IterableTiledPredDataset or IterablePredDataset) – Dataset.
- sample_id (int) – Sample ID, the index of the file in dataset.
Returns:
- Path – The file path corresponding to the sample with the ID sample_id.
select_write_extension(write_type, write_extension=None)
Return an extension to add to file paths.
If write_type is "custom" then write_extension is returned, otherwise the known
write extension is selected.
Parameters:
- write_type (('tiff', 'custom'), default: "tiff") – The data type to save as, includes custom.
- write_extension (str, default: None) – If a known write_type is selected this argument is ignored. For a custom write_type an extension to save the data with must be passed.
Returns:
- str – The extension to be added to file paths.
Raises:
- ValueError – If write_type="custom" but write_extension has not been given.
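The selection rule can be sketched as a small lookup with a guard for the custom case. This is an illustration only: the function name, the KNOWN_EXTENSIONS table, and the ".tiff" value are assumptions, not taken from the library.

```python
from typing import Dict, Optional

# Assumed mapping from known write types to their extensions.
KNOWN_EXTENSIONS: Dict[str, str] = {"tiff": ".tiff"}


def select_write_extension_sketch(
    write_type: str, write_extension: Optional[str] = None
) -> str:
    """Return the user's extension for 'custom', else the known extension."""
    if write_type == "custom":
        if write_extension is None:
            raise ValueError(
                "A write extension must be provided for a custom write type."
            )
        return write_extension
    # For a known write type, any supplied write_extension is ignored.
    return KNOWN_EXTENSIONS[write_type]


tiff_ext = select_write_extension_sketch("tiff", write_extension=".png")
custom_ext = select_write_extension_sketch("custom", write_extension=".npy")
```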
select_write_func(write_type, write_func=None)
Return a function to write images.
If write_type is "custom" then write_func is returned, otherwise the known write function
is selected.
Parameters:
- write_type (('tiff', 'custom'), default: "tiff") – The data type to save as, includes custom.
- write_func (WriteFunc, default: None) – If a known write_type is selected this argument is ignored. For a custom write_type a function to save the data must be passed. See notes below.
Returns:
- WriteFunc – A function for writing images.
Raises:
- ValueError – If write_type="custom" but write_func has not been given.
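To make the custom case concrete, the sketch below pairs a hypothetical write function with a selection helper. The (file_path, img, **kwargs) call shape is an assumption about what a write function looks like, and placeholder_tiff_writer merely stands in for whatever known TIFF writer the library provides; neither is taken from the package.

```python
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Sequence


def write_text(file_path: Path, img: Sequence[float], **kwargs: Any) -> None:
    """Hypothetical custom writer: save values as delimited text."""
    delimiter = kwargs.get("delimiter", ",")
    file_path.write_text(delimiter.join(str(value) for value in img))


def placeholder_tiff_writer(file_path: Path, img: Any, **kwargs: Any) -> None:
    """Stand-in for a known TIFF writer; intentionally not implemented."""
    raise NotImplementedError("Placeholder for a real TIFF writer.")


# Assumed table of known write functions.
KNOWN_WRITE_FUNCS: Dict[str, Callable[..., None]] = {"tiff": placeholder_tiff_writer}


def select_write_func_sketch(
    write_type: str, write_func: Optional[Callable[..., None]] = None
) -> Callable[..., None]:
    """Return the user's function for 'custom', else the known function."""
    if write_type == "custom":
        if write_func is None:
            raise ValueError(
                "A write function must be provided for a custom write type."
            )
        return write_func
    return KNOWN_WRITE_FUNCS[write_type]


selected = select_write_func_sketch("custom", write_func=write_text)
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "prediction.txt"
    selected(target, [1.0, 2.5], delimiter=";")  # extra kwargs reach the writer
    saved = target.read_text()
```

Here the delimiter keyword plays the role that write_func_kwargs plays in the documented API: extra options forwarded to the save function.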