Prediction Writer Callback
A package for the PredictionWriterCallback class and utilities.
CacheTiles
Bases: WriteStrategy
A write strategy that will cache tiles.
Tiles are cached until every tile of an image has been predicted on; the stitched prediction is then saved.
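The caching behaviour can be illustrated with a minimal sketch. This is not the CAREamics implementation. It makes four assumptions:

- Tiles are 1D Python lists with no overlap.
- Each tile's information records its start offset in the full image and whether it is the last tile.
- `write_batch` takes only tiles and tile information, not the full documented signature.
- `TileInfo`, `stitch` and `ToyCacheTiles` are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class TileInfo:
    """Hypothetical 1D stand-in for TileInformation (no overlap handling)."""
    start: int        # offset of the tile in the full image
    last_tile: bool   # True for the final tile of an image


def stitch(tiles: List[List[float]], infos: List[TileInfo]) -> List[float]:
    """Place each tile at its start offset in a 1D output."""
    length = max(info.start + len(tile) for tile, info in zip(tiles, infos))
    out = [0.0] * length
    for tile, info in zip(tiles, infos):
        out[info.start:info.start + len(tile)] = tile
    return out


@dataclass
class ToyCacheTiles:
    write_func: Callable[[List[float]], None]
    tile_cache: List[List[float]] = field(default_factory=list)
    tile_info_cache: List[TileInfo] = field(default_factory=list)

    @property
    def last_tiles(self) -> List[bool]:
        # One bool per cached tile, as in the documented property.
        return [info.last_tile for info in self.tile_info_cache]

    def write_batch(self, tiles: List[List[float]], infos: List[TileInfo]) -> None:
        self.tile_cache.extend(tiles)
        self.tile_info_cache.extend(infos)
        # Stitch and save every image whose last tile is now cached.
        while any(self.last_tiles):
            end = self.last_tiles.index(True) + 1
            self.write_func(stitch(self.tile_cache[:end], self.tile_info_cache[:end]))
            # Drop the tiles that belong to the image just written.
            del self.tile_cache[:end]
            del self.tile_info_cache[:end]


saved: List[List[float]] = []
strategy = ToyCacheTiles(write_func=saved.append)
strategy.write_batch([[1.0, 2.0]], [TileInfo(0, False)])
strategy.write_batch([[3.0, 4.0], [5.0]], [TileInfo(2, True), TileInfo(0, False)])
print(saved)                     # [[1.0, 2.0, 3.0, 4.0]]
print(len(strategy.tile_cache))  # 1
```

Caching across calls matters because the last tile of an image can arrive in a later batch than its first tile. A single batch therefore cannot always be stitched on its own.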
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| write_func | WriteFunc | Function used to save predictions. | required |
| write_extension | str | Extension added to prediction file paths. | required |
| write_func_kwargs | dict of {str: Any} | Extra kwargs to pass to | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| write_func | WriteFunc | Function used to save predictions. |
| write_extension | str | Extension added to prediction file paths. |
| write_func_kwargs | dict of {str: Any} | Extra kwargs to pass to |
| tile_cache | list of numpy.ndarray | Tiles cached for stitching prediction. |
| tile_info_cache | list of TileInformation | Cached tile information for stitching prediction. |
last_tiles (property)
List of bool to determine whether each tile in the cache is the last tile.
Returns:
| Type | Description |
|---|---|
| list of bool | Whether each tile in the tile cache is the last tile. |
write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)
Cache tiles until the last tile is predicted; save the stitched prediction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning Trainer. | required |
| pl_module | LightningModule | PyTorch Lightning LightningModule. | required |
| prediction | Any | Predictions on | required |
| batch_indices | sequence of int | Indices identifying the samples in the batch. | required |
| batch | Any | Input batch. | required |
| batch_idx | int | Batch index. | required |
| dataloader_idx | int | Dataloader index. | required |
| dirpath | Path | Path to directory to save predictions to. | required |
PredictionWriterCallback
Bases: BasePredictionWriter
A PyTorch Lightning callback to save predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| write_strategy | WriteStrategy | A strategy for writing predictions. | required |
| dirpath | Path or str | The path to the directory where prediction outputs will be saved. If | "predictions" |
Attributes:
| Name | Type | Description |
|---|---|---|
| write_strategy | WriteStrategy | A strategy for writing predictions. |
| dirpath | pathlib.Path, default="predictions" | The path to the directory where prediction outputs will be saved. If |
| writing_predictions | bool | If writing predictions is turned on or off. |
from_write_func_params(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None, dirpath='predictions') (classmethod)
Initialize a PredictionWriterCallback from write function parameters.
This automatically creates the WriteStrategy that is passed to the
PredictionWriterCallback initializer.
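The classmethod is an alternate-constructor pattern, sketched minimally below. It assumes a hypothetical `ToyStrategy` that only stores the parameters it was built from; the real classmethod builds a proper WriteStrategy instead. `ToyWriterCallback` is likewise hypothetical and omits the Lightning base class.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union


@dataclass
class ToyStrategy:
    """Hypothetical strategy: records the parameters it was built from."""
    write_type: str
    tiled: bool
    write_func: Optional[Callable[..., None]] = None
    write_extension: Optional[str] = None
    write_func_kwargs: Dict[str, Any] = field(default_factory=dict)


class ToyWriterCallback:
    """Hypothetical stand-in for PredictionWriterCallback."""

    def __init__(self, write_strategy: ToyStrategy,
                 dirpath: Union[Path, str] = "predictions") -> None:
        self.write_strategy = write_strategy
        self.dirpath = Path(dirpath)  # accept str or Path, store a Path

    @classmethod
    def from_write_func_params(
        cls,
        write_type: str,
        tiled: bool,
        write_func: Optional[Callable[..., None]] = None,
        write_extension: Optional[str] = None,
        write_func_kwargs: Optional[Dict[str, Any]] = None,
        dirpath: Union[Path, str] = "predictions",
    ) -> "ToyWriterCallback":
        # Build the strategy from the write parameters, then the callback.
        strategy = ToyStrategy(write_type, tiled, write_func, write_extension,
                               write_func_kwargs or {})
        return cls(strategy, dirpath=dirpath)
```

The caller never constructs a strategy by hand. Only dirpath reaches the callback initializer; everything else configures the strategy.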
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| write_type | ('tiff', 'custom') | The data type to save as, including "custom". | "tiff" |
| tiled | bool | Whether the prediction will be tiled or not. | required |
| write_func | WriteFunc | If a known | None |
| write_extension | str | If a known | None |
| write_func_kwargs | dict of {str: Any} | Additional keyword arguments to be passed to the save function. | None |
| dirpath | Path or str | The path to the directory where prediction outputs will be saved. If | "predictions" |
Returns:
| Type | Description |
|---|---|
| PredictionWriterCallback | Callback for writing predictions. |
setup(trainer, pl_module, stage)
Create the prediction output directory when predict begins.
Called when fit, validate, test, predict, or tune begins.
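A minimal sketch of the directory-creation step follows. It assumes three things:

- The hook receives the stage name as a plain string.
- Creating missing parent directories (`parents=True`) is acceptable.
- Reusing an existing directory (`exist_ok=True`) is acceptable.

`ToySetupCallback` is hypothetical and omits the Lightning base class.

```python
import tempfile
from pathlib import Path
from typing import Any, Union


class ToySetupCallback:
    """Hypothetical stand-in showing only the directory-creation step."""

    def __init__(self, dirpath: Union[Path, str] = "predictions") -> None:
        self.dirpath = Path(dirpath)

    def setup(self, trainer: Any, pl_module: Any, stage: str) -> None:
        # The hook fires for every stage; only predict needs the output directory.
        if stage == "predict":
            self.dirpath.mkdir(parents=True, exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    callback = ToySetupCallback(Path(tmp) / "run" / "predictions")
    callback.setup(None, None, "fit")
    print(callback.dirpath.exists())  # False
    callback.setup(None, None, "predict")
    print(callback.dirpath.exists())  # True
```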
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer. | required |
| pl_module | LightningModule | PyTorch Lightning module. | required |
| stage | str | Stage of training, e.g. 'predict', 'fit', 'validate'. | required |
write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)
Write predictions at the end of a batch.
The method of writing predictions is determined by the write_strategy attribute.
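The delegation can be sketched without Lightning. `RecordingStrategy` and `ToyWriter` are hypothetical. The early return when `writing_predictions` is False is an assumption based only on that attribute's description.

```python
from pathlib import Path
from typing import Any, List, Sequence, Tuple


class RecordingStrategy:
    """Hypothetical strategy that records what it was asked to write."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Any, int, Path]] = []

    def write_batch(self, trainer, pl_module, prediction, batch_indices,
                    batch, batch_idx, dataloader_idx, dirpath) -> None:
        self.calls.append((prediction, batch_idx, dirpath))


class ToyWriter:
    """Hypothetical stand-in for PredictionWriterCallback's dispatch logic."""

    def __init__(self, write_strategy, dirpath="predictions") -> None:
        self.write_strategy = write_strategy
        self.dirpath = Path(dirpath)
        self.writing_predictions = True

    def write_on_batch_end(self, trainer: Any, pl_module: Any, prediction: Any,
                           batch_indices: Sequence[Any], batch: Any,
                           batch_idx: int, dataloader_idx: int) -> None:
        # Assumption: the writing_predictions flag turns writing off entirely.
        if not self.writing_predictions:
            return
        # Delegate to the strategy, adding the callback's output directory.
        self.write_strategy.write_batch(
            trainer, pl_module, prediction, batch_indices,
            batch, batch_idx, dataloader_idx, dirpath=self.dirpath,
        )
```

Keeping the writing logic in the strategy lets one callback class serve both tiled and untiled prediction.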
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer. | required |
| pl_module | LightningModule | PyTorch Lightning module. | required |
| prediction | Any | Prediction outputs of | required |
| batch_indices | sequence of Any | Batch indices. | required |
| batch | Any | Input batch. | required |
| batch_idx | int | Batch index. | required |
| dataloader_idx | int | Dataloader index. | required |
WriteImage
Bases: WriteStrategy
A strategy for writing image predictions (i.e. un-tiled predictions).
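A sketch of writing whole images follows. `write_full_images` is a hypothetical helper, not part of the documented API. It assumes:

- Each prediction pairs with a source file whose stem names the output.
- `write_func` is called as `write_func(file_path, img, **write_func_kwargs)`.
- `write_extension` includes its leading dot (for example `.tiff`).

```python
from pathlib import Path
from typing import Any, Callable, Dict, List


def write_full_images(
    predictions: List[Any],
    source_files: List[Path],
    dirpath: Path,
    write_func: Callable[..., None],
    write_extension: str,
    write_func_kwargs: Dict[str, Any],
) -> List[Path]:
    """Hypothetical helper: save one whole-image prediction per source file."""
    written = []
    for prediction, source in zip(predictions, source_files):
        # Name the output after the input file, using the write extension.
        file_path = dirpath / f"{source.stem}{write_extension}"
        write_func(file_path, prediction, **write_func_kwargs)
        written.append(file_path)
    return written
```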
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| write_func | WriteFunc | Function used to save predictions. | required |
| write_extension | str | Extension added to prediction file paths. | required |
| write_func_kwargs | dict of {str: Any} | Extra kwargs to pass to | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| write_func | WriteFunc | Function used to save predictions. |
| write_extension | str | Extension added to prediction file paths. |
| write_func_kwargs | dict of {str: Any} | Extra kwargs to pass to |
write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)
Save full images.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning Trainer. | required |
| pl_module | LightningModule | PyTorch Lightning LightningModule. | required |
| prediction | Any | Predictions on | required |
| batch_indices | sequence of int | Indices identifying the samples in the batch. | required |
| batch | Any | Input batch. | required |
| batch_idx | int | Batch index. | required |
| dataloader_idx | int | Dataloader index. | required |
| dirpath | Path | Path to directory to save predictions to. | required |
Raises:
| Type | Description |
|---|---|
| TypeError | If trainer prediction dataset is not |
WriteStrategy
Bases: Protocol
Protocol for write strategy classes.
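Because WriteStrategy is a `typing.Protocol`, static type checkers accept any class with a matching `write_batch` method, even without explicit inheritance. The sketch uses a hypothetical `WriteStrategyLike` mirror of the protocol, marked `runtime_checkable` so `isinstance` works. Such runtime checks only confirm that the method exists, not that its signature matches.

```python
from pathlib import Path
from typing import Any, Protocol, Sequence, runtime_checkable


@runtime_checkable
class WriteStrategyLike(Protocol):
    """Hypothetical mirror of the documented write_batch signature."""

    def write_batch(self, trainer: Any, pl_module: Any, prediction: Any,
                    batch_indices: Sequence[int], batch: Any, batch_idx: int,
                    dataloader_idx: int, dirpath: Path) -> None: ...


class PrintStrategy:
    """Conforms structurally: it defines write_batch, no base class needed."""

    def write_batch(self, trainer, pl_module, prediction, batch_indices,
                    batch, batch_idx, dataloader_idx, dirpath) -> None:
        print(f"batch {batch_idx} -> {dirpath}")


class NotAStrategy:
    """Lacks write_batch, so it does not satisfy the protocol."""
```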
write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)
WriteStrategy subclasses must implement this method to write a batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning Trainer. | required |
| pl_module | LightningModule | PyTorch Lightning LightningModule. | required |
| prediction | Any | Predictions on | required |
| batch_indices | sequence of int | Indices identifying the samples in the batch. | required |
| batch | Any | Input batch. | required |
| batch_idx | int | Batch index. | required |
| dataloader_idx | int | Dataloader index. | required |
| dirpath | Path | Path to directory to save predictions to. | required |
WriteTilesZarr
Bases: WriteStrategy
Strategy to write tiles to a Zarr file.
write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)
Write tiles to a Zarr file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning Trainer. | required |
| pl_module | LightningModule | PyTorch Lightning LightningModule. | required |
| prediction | Any | Predictions on | required |
| batch_indices | sequence of int | Indices identifying the samples in the batch. | required |
| batch | Any | Input batch. | required |
| batch_idx | int | Batch index. | required |
| dataloader_idx | int | Dataloader index. | required |
| dirpath | Path | Path to directory to save predictions to. | required |
Raises:
| Type | Description |
|---|---|
| NotImplementedError | |
create_write_strategy(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None)
Create a write strategy from convenient parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| write_type | ('tiff', 'custom') | The data type to save as, including "custom". | "tiff" |
| tiled | bool | Whether the prediction will be tiled or not. | required |
| write_func | WriteFunc | If a known | None |
| write_extension | str | If a known | None |
| write_func_kwargs | dict of {str: Any} | Additional keyword arguments to be passed to the save function. | None |
Returns:
| Type | Description |
|---|---|
| WriteStrategy | A strategy for writing predictions. |
select_write_extension(write_type, write_extension=None)
Return an extension to add to file paths.
If write_type is "custom", write_extension is returned; otherwise the known
extension for write_type is selected.
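A sketch of this selection rule follows. The known-extension table is hypothetical. The documented ValueError condition is truncated, so raising ValueError for a custom write_type without a write_extension is an assumption.

```python
from typing import Dict, Optional

# Hypothetical lookup table; the real known extensions may differ.
KNOWN_EXTENSIONS: Dict[str, str] = {"tiff": ".tiff"}


def toy_select_write_extension(write_type: str,
                               write_extension: Optional[str] = None) -> str:
    if write_type == "custom":
        # Custom types must supply their own extension.
        if write_extension is None:
            raise ValueError("A custom write_type requires write_extension.")
        return write_extension
    # Known types ignore any write_extension that was passed.
    try:
        return KNOWN_EXTENSIONS[write_type]
    except KeyError:
        raise ValueError(f"Unsupported write_type: {write_type!r}") from None
```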
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| write_type | ('tiff', 'custom') | The data type to save as, including "custom". | "tiff" |
| write_extension | str | If a known | None |
Returns:
| Type | Description |
|---|---|
| str | The extension to be added to file paths. |
Raises:
| Type | Description |
|---|---|
| ValueError | If |
select_write_func(write_type, write_func=None)
Return a function to write images.
If write_type is "custom", write_func is returned; otherwise the known write
function for write_type is selected.
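A sketch of the custom branch follows, paired with a hypothetical custom writer that stores `repr(img)` as text. Only the "custom" path is modelled; the built-in writers for known types are not reproduced. Both ValueError conditions are assumptions.

```python
import tempfile
from pathlib import Path
from typing import Any, Callable, Optional

# Assumed shape of a write function: (file_path, img, **kwargs) -> None.
WriteFuncLike = Callable[..., None]


def save_as_text(file_path: Path, img: Any, **kwargs: Any) -> None:
    """Hypothetical custom writer: stores repr(img) as UTF-8 text."""
    Path(file_path).write_text(repr(img), encoding="utf-8")


def toy_select_write_func(write_type: str,
                          write_func: Optional[WriteFuncLike] = None) -> WriteFuncLike:
    if write_type == "custom":
        if write_func is None:
            raise ValueError("write_type 'custom' requires a write_func.")
        return write_func
    # A real version would return its built-in writer for known types.
    raise ValueError(f"No built-in writer in this sketch for {write_type!r}")


func = toy_select_write_func("custom", save_as_text)
with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "pred.txt"
    func(out, [1, 2, 3])
    print(out.read_text(encoding="utf-8"))  # [1, 2, 3]
```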
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| write_type | ('tiff', 'custom') | The data type to save as, including "custom". | "tiff" |
| write_func | WriteFunc | If a known | None |
Returns:
| Type | Description |
|---|---|
| WriteFunc | A function for writing images. |
Raises:
| Type | Description |
|---|---|
| ValueError | If |