
Write Strategy


Module containing different strategies for writing predictions.

CacheTiles

Bases: WriteStrategy

A write strategy that caches tiles.

Tiles are cached until every tile of an image has been predicted; the tiles are then stitched together and the full prediction is saved.
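The caching behavior can be illustrated with a small, dependency-free sketch. It assumes non-overlapping 2D tiles stored as nested lists, a known output shape, and a hypothetical `TileInfo` record holding each tile's top-left position and a `last_tile` flag. The real `TileInformation` class and stitching code (which usually also has to crop tile overlaps) may differ.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, List, Tuple

Tile = List[List[float]]  # a 2D tile as nested lists


@dataclass
class TileInfo:
    """Hypothetical stand-in for TileInformation."""

    row: int  # top row of the tile within the full image
    col: int  # left column of the tile within the full image
    last_tile: bool  # True for the final tile of an image


def stitch(tiles: List[Tile], infos: List[TileInfo], shape: Tuple[int, int]) -> Tile:
    """Copy non-overlapping tiles into a blank image of the given shape."""
    height, width = shape
    image = [[0.0] * width for _ in range(height)]
    for tile, info in zip(tiles, infos):
        for r, tile_row in enumerate(tile):
            for c, value in enumerate(tile_row):
                image[info.row + r][info.col + c] = value
    return image


def cache_and_stitch(
    stream: Iterable[Tuple[Tile, TileInfo]], shape: Tuple[int, int]
) -> Iterator[Tile]:
    """Cache tiles as they arrive; yield a stitched image after each last tile."""
    tile_cache: List[Tile] = []
    tile_info_cache: List[TileInfo] = []
    for tile, info in stream:
        tile_cache.append(tile)
        tile_info_cache.append(info)
        if info.last_tile:
            yield stitch(tile_cache, tile_info_cache, shape)
            tile_cache, tile_info_cache = [], []  # start caching the next image
```

Each yielded image would then be handed to the write function. Because nothing is written until the last tile arrives, memory use grows with the number of tiles per image.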

Parameters:

write_func (WriteFunc, required): Function used to save predictions.
write_extension (str, required): Extension added to prediction file paths.
write_func_kwargs (dict of {str: Any}, required): Extra kwargs to pass to write_func.
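How these three parameters fit together can be sketched as follows. The sketch assumes a hypothetical `write_func` that accepts `file_path` and `img` keyword arguments, that `write_extension` includes the leading dot (e.g. `".tiff"`), and that the output name is derived from the source file's stem; none of these details are confirmed by this page.

```python
from pathlib import Path
from typing import Any, Callable, Dict

# Hypothetical signature: write_func(file_path=..., img=..., **kwargs) -> None
WriteFunc = Callable[..., None]


def save_prediction(
    prediction: Any,
    source_name: str,
    dirpath: Path,
    write_func: WriteFunc,
    write_extension: str,
    write_func_kwargs: Dict[str, Any],
) -> Path:
    """Derive the output path from the source name, then delegate to write_func."""
    # Keep the source stem and swap in the prediction extension (assumed to start with ".").
    file_path = (dirpath / Path(source_name).stem).with_suffix(write_extension)
    # Extra keyword arguments are forwarded unchanged.
    write_func(file_path=file_path, img=prediction, **write_func_kwargs)
    return file_path
```

Keeping `write_func_kwargs` separate lets the same strategy wrap different writers (for example, a TIFF writer that takes a compression option) without changing its own signature.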

Attributes:

write_func (WriteFunc): Function used to save predictions.
write_extension (str): Extension added to prediction file paths.
write_func_kwargs (dict of {str: Any}): Extra kwargs to pass to write_func.
tile_cache (list of numpy.ndarray): Tiles cached for stitching the prediction.
tile_info_cache (list of TileInformation): Cached tile information for stitching the prediction.

last_tiles property

List of bools indicating whether each tile in the cache is the last tile of its image.

Returns:

list of bool: Whether each tile in the tile cache is the last tile.
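A minimal sketch of such a property, assuming each cached tile-information object exposes a boolean `last_tile` attribute. `TileInfo` and `TileCacheSketch` below are hypothetical stand-ins, not the library's classes.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class TileInfo:
    """Hypothetical minimal stand-in for TileInformation."""

    last_tile: bool


@dataclass
class TileCacheSketch:
    """Holds cached tiles and their information, like CacheTiles' attributes."""

    tile_cache: List[Any] = field(default_factory=list)
    tile_info_cache: List[TileInfo] = field(default_factory=list)

    @property
    def last_tiles(self) -> List[bool]:
        """Whether each cached tile is the last tile of its image."""
        return [info.last_tile for info in self.tile_info_cache]
```

With this shape, `any(cache.last_tiles)` is a cheap test for whether at least one image in the cache is complete.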

write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)

Cache tiles until the last tile is predicted; save the stitched prediction.

Parameters:

trainer (Trainer, required): PyTorch Lightning Trainer.
pl_module (LightningModule, required): PyTorch Lightning LightningModule.
prediction (Any, required): Predictions on the batch.
batch_indices (sequence of int, required): Indices identifying the samples in the batch.
batch (Any, required): Input batch.
batch_idx (int, required): Batch index.
dataloader_idx (int, required): Dataloader index.
dirpath (Path, required): Directory in which to save predictions.
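One subtlety is that a single batch can finish one image and start the next, so only the tiles up to and including the first last tile should be stitched while the rest stay cached. The sketch below shows that cache bookkeeping under stated assumptions: tile information is represented by hypothetical dicts with a `"last_tile"` key, and `write_image` is a hypothetical callback standing in for stitching plus `write_func`. It is an illustration, not the library's implementation.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

TileInfo = Dict[str, Any]  # hypothetical: {"last_tile": bool, ...}


def pop_completed_image(
    tile_cache: List[Any], tile_info_cache: List[TileInfo]
) -> Optional[Tuple[List[Any], List[TileInfo]]]:
    """Remove and return the tiles of the first complete image, if there is one."""
    last_tiles = [info["last_tile"] for info in tile_info_cache]
    if True not in last_tiles:
        return None  # no image is complete yet: keep caching
    end = last_tiles.index(True) + 1
    tiles, infos = tile_cache[:end], tile_info_cache[:end]
    # Tiles after the last tile belong to the next image and remain cached.
    del tile_cache[:end], tile_info_cache[:end]
    return tiles, infos


def write_batch_sketch(
    tile_cache: List[Any],
    tile_info_cache: List[TileInfo],
    batch_tiles: List[Any],
    batch_infos: List[TileInfo],
    write_image: Callable[[List[Any], List[TileInfo]], None],
) -> None:
    """Extend the cache with a batch, then write every image it completes."""
    tile_cache.extend(batch_tiles)
    tile_info_cache.extend(batch_infos)
    while True:
        completed = pop_completed_image(tile_cache, tile_info_cache)
        if completed is None:
            break
        write_image(*completed)  # stitch and save, e.g. via write_func
```

The `while` loop covers the case where one batch completes more than one image, which can happen when images are small relative to the batch size.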

WriteImage

Bases: WriteStrategy

A strategy for writing whole-image predictions (i.e., untiled predictions).

Parameters:

write_func (WriteFunc, required): Function used to save predictions.
write_extension (str, required): Extension added to prediction file paths.
write_func_kwargs (dict of {str: Any}, required): Extra kwargs to pass to write_func.

Attributes:

write_func (WriteFunc): Function used to save predictions.
write_extension (str): Extension added to prediction file paths.
write_func_kwargs (dict of {str: Any}): Extra kwargs to pass to write_func.

write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)

Save full images.

Parameters:

trainer (Trainer, required): PyTorch Lightning Trainer.
pl_module (LightningModule, required): PyTorch Lightning LightningModule.
prediction (Any, required): Predictions on the batch.
batch_indices (sequence of int, required): Indices identifying the samples in the batch.
batch (Any, required): Input batch.
batch_idx (int, required): Batch index.
dataloader_idx (int, required): Dataloader index.
dirpath (Path, required): Directory in which to save predictions.

Raises:

TypeError: If the trainer's prediction dataset is not an IterablePredDataset.
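A simplified sketch of what this method does, under stated assumptions: `IterablePredDatasetStandIn` is a hypothetical stand-in whose `data_files` list maps sample indices to source files, `batch_indices` index into that list, and `write_func` accepts `file_path` and `img` keyword arguments. It shows where the `TypeError` comes from and that each untiled prediction produces one output file.

```python
from pathlib import Path
from typing import Any, Callable, Dict, List, Sequence


class IterablePredDatasetStandIn:
    """Hypothetical stand-in for IterablePredDataset: only lists source files."""

    def __init__(self, data_files: List[Path]) -> None:
        self.data_files = data_files


def write_images(
    dataset: Any,
    batch_indices: Sequence[int],
    predictions: Sequence[Any],
    dirpath: Path,
    write_func: Callable[..., None],
    write_extension: str,
    write_func_kwargs: Dict[str, Any],
) -> List[Path]:
    """Write one file per untiled prediction, named after its source file."""
    # In this sketch the dataset maps samples to source files, so other types are rejected.
    if not isinstance(dataset, IterablePredDatasetStandIn):
        raise TypeError(
            f"Expected an IterablePredDataset, got {type(dataset).__name__}."
        )
    written: List[Path] = []
    for sample_idx, prediction in zip(batch_indices, predictions):
        source = dataset.data_files[sample_idx]
        file_path = (dirpath / source.stem).with_suffix(write_extension)
        write_func(file_path=file_path, img=prediction, **write_func_kwargs)
        written.append(file_path)
    return written
```

Unlike the tile-caching strategy, no state is kept between batches: every prediction is already a complete image and can be written immediately.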

WriteStrategy

Bases: Protocol

Protocol for write strategy classes.

write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)

Classes implementing WriteStrategy must define this method to write a batch of predictions.

Parameters:

trainer (Trainer, required): PyTorch Lightning Trainer.
pl_module (LightningModule, required): PyTorch Lightning LightningModule.
prediction (Any, required): Predictions on the batch.
batch_indices (sequence of int, required): Indices identifying the samples in the batch.
batch (Any, required): Input batch.
batch_idx (int, required): Batch index.
dataloader_idx (int, required): Dataloader index.
dirpath (Path, required): Directory in which to save predictions.
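Because WriteStrategy is a `typing.Protocol`, a class satisfies it structurally by defining a compatible `write_batch` method; inheriting from it is optional. The sketch below mirrors the documented signature in a hypothetical `WriteStrategyLike` protocol (with `Any` standing in for the Lightning types) and marks it `runtime_checkable` so `isinstance` can be used. Such runtime checks only confirm that the method exists, not that its signature matches.

```python
from pathlib import Path
from typing import Any, List, Protocol, Sequence, Tuple, runtime_checkable


@runtime_checkable
class WriteStrategyLike(Protocol):
    """Hypothetical mirror of the documented WriteStrategy protocol."""

    def write_batch(
        self,
        trainer: Any,  # stands in for Lightning's Trainer
        pl_module: Any,  # stands in for LightningModule
        prediction: Any,
        batch_indices: Sequence[int],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None: ...


class RecordingStrategy:
    """Toy strategy that records calls instead of writing files (no base class)."""

    def __init__(self) -> None:
        self.calls: List[Tuple[int, List[int], Path]] = []

    def write_batch(
        self,
        trainer: Any,
        pl_module: Any,
        prediction: Any,
        batch_indices: Sequence[int],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        self.calls.append((batch_idx, list(batch_indices), dirpath))
```

A static type checker would additionally verify the parameter types wherever a `WriteStrategyLike` is expected.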

WriteTilesZarr

Bases: WriteStrategy

Strategy to write tiles to a Zarr file.

write_batch(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx, dirpath)

Write tiles to a Zarr file.

Parameters:

trainer (Trainer, required): PyTorch Lightning Trainer.
pl_module (LightningModule, required): PyTorch Lightning LightningModule.
prediction (Any, required): Predictions on the batch.
batch_indices (sequence of int, required): Indices identifying the samples in the batch.
batch (Any, required): Input batch.
batch_idx (int, required): Batch index.
dataloader_idx (int, required): Dataloader index.
dirpath (Path, required): Directory in which to save predictions.

Raises:

NotImplementedError