
Prediction Writer Callback


Module containing PredictionWriterCallback class.

PredictionWriterCallback

Bases: BasePredictionWriter

PyTorch Lightning callback to save predictions.

A WriteStrategy must be provided at instantiation or later via set_writing_strategy. This allows the callback to be passed to the Lightning Trainer before the writing strategy (e.g. tiling or file type) is known.
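The deferred-strategy pattern described above can be illustrated with a short, self-contained sketch. It does not use CAREamics or Lightning: `DeferredWriter`, `PrintStrategy`, `set_strategy` and `write` are hypothetical names chosen for illustration. The sketch only shows that the object can be created without a strategy and refuses to write until one has been attached.

```python
from typing import Any, Optional, Protocol


class Strategy(Protocol):
    """Hypothetical minimal interface for a write strategy."""

    def write(self, data: Any) -> str: ...


class PrintStrategy:
    """Hypothetical strategy that only formats the data it receives."""

    def write(self, data: Any) -> str:
        return f"wrote {data!r}"


class DeferredWriter:
    """Sketch of a writer whose strategy may be supplied after construction."""

    def __init__(self, write_strategy: Optional[Strategy] = None) -> None:
        self.write_strategy = write_strategy

    def set_strategy(self, write_strategy: Strategy) -> None:
        # Attach (or replace) the strategy once it is known.
        self.write_strategy = write_strategy

    def write(self, data: Any) -> str:
        # Refuse to write until a strategy has been provided.
        if self.write_strategy is None:
            raise RuntimeError("A write strategy must be set before writing.")
        return self.write_strategy.write(data)


writer = DeferredWriter()  # created before the strategy is known
writer.set_strategy(PrintStrategy())
print(writer.write([1, 2, 3]))  # prints "wrote [1, 2, 3]"
```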

By default the prediction writer is enabled, but it can be disabled by passing enable_writing=False at instantiation or by calling enable_writing(False). Subsequently, writing can be re-enabled by calling enable_writing(True).
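The enable/disable toggle can be sketched as a flag checked at the start of each write. This is a hypothetical stand-in, not the CAREamics implementation: `ToggleableWriter` and `on_batch` are illustrative names, and it assumes (as a simplification) that a disabled writer silently skips batches rather than raising.

```python
from typing import List


class ToggleableWriter:
    """Hypothetical stand-in mirroring the is_enabled / enable_writing() pattern."""

    def __init__(self, enable_writing: bool = True) -> None:
        self.is_enabled = enable_writing
        self.written: List[str] = []

    def enable_writing(self, enable_writing: bool) -> None:
        # Switch writing on or off; can be called any number of times.
        self.is_enabled = enable_writing

    def on_batch(self, name: str) -> None:
        # Assumption: while disabled, batches are skipped without error.
        if not self.is_enabled:
            return
        self.written.append(name)


writer = ToggleableWriter()
writer.on_batch("batch_0")
writer.enable_writing(False)
writer.on_batch("batch_1")  # ignored while disabled
writer.enable_writing(True)
writer.on_batch("batch_2")
print(writer.written)  # ['batch_0', 'batch_2']
```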

Parameters:

  • dirpath (Path or str, default: "" ) –

    The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

  • write_strategy (WriteStrategy or None, default: None ) –

    A strategy for writing predictions.

  • enable_writing (bool, default: True ) –

    Whether writing predictions is enabled initially.
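The rule for relative dirpath values can be expressed with pathlib. The helper below, `resolve_dirpath`, is hypothetical and only illustrates the documented behaviour (a non-absolute path is taken relative to the current working directory); it is not necessarily how the callback resolves the path internally.

```python
from pathlib import Path
from typing import Union


def resolve_dirpath(dirpath: Union[Path, str]) -> Path:
    """Return an absolute output directory (hypothetical helper)."""
    path = Path(dirpath)
    # Relative paths are interpreted relative to the current working directory.
    if not path.is_absolute():
        path = Path.cwd() / path
    return path


print(resolve_dirpath("predictions"))  # <current working directory>/predictions
print(resolve_dirpath(""))             # an empty string resolves to the working directory itself
```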

Attributes:

  • is_enabled (bool) –

    Whether writing predictions is enabled.

  • dirpath (pathlib.Path, default: "" ) –

    The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

  • write_strategy (WriteStrategy or None) –

    A strategy for writing predictions.

__init__(dirpath='', write_strategy=None, enable_writing=True)

Constructor.

A WriteStrategy must be provided at instantiation or later via set_writing_strategy.

Parameters:

  • dirpath (Path or str, default: "" ) –

    The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

  • write_strategy (WriteStrategy or None, default: None ) –

    A strategy for writing predictions.

  • enable_writing (bool, default: True ) –

    Whether writing predictions is enabled initially.

enable_writing(enable_writing)

Enable or disable writing.

Parameters:

  • enable_writing (bool) –

    Whether writing predictions should be enabled.

set_writing_strategy(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None)

Set the writing strategy.

Must be called before any predictions are written, unless a write_strategy was provided at instantiation.

Parameters:

  • write_type (SupportedWriteType) –

    The type of writing to perform.

  • tiled (bool) –

    Whether to write in tiled format.

  • write_func (WriteFunc or None, default: None ) –

    A custom writing function.

  • write_extension (str or None, default: None ) –

    The file extension to use when writing files.

  • write_func_kwargs (dict of str to Any, default: None ) –

    Additional keyword arguments to pass to write_func.

setup(trainer, pl_module, stage)

Create the prediction output directory when predict begins.

Called when fit, validate, test, predict, or tune begins.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning trainer.

  • pl_module (LightningModule) –

    PyTorch Lightning module.

  • stage (str) –

    Current stage, e.g. 'predict', 'fit' or 'validate'.
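The setup behaviour (creating the output directory only when predict begins) can be sketched with a hypothetical class, `DirOnPredict`. It uses the same hook signature as setup but is not the library's code; `parents=True` and `exist_ok=True` are standard pathlib options chosen here so that nested paths and repeated runs both work.

```python
import tempfile
from pathlib import Path
from typing import Any


class DirOnPredict:
    """Hypothetical sketch of a setup hook that creates the output directory."""

    def __init__(self, dirpath: Path) -> None:
        self.dirpath = dirpath

    def setup(self, trainer: Any, pl_module: Any, stage: str) -> None:
        # Only the predict stage needs an output directory.
        if stage == "predict":
            # parents=True creates missing parents; exist_ok=True makes reruns safe.
            self.dirpath.mkdir(parents=True, exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    cb = DirOnPredict(Path(tmp) / "predictions")
    cb.setup(None, None, "fit")
    print(cb.dirpath.exists())  # False
    cb.setup(None, None, "predict")
    print(cb.dirpath.exists())  # True
```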

write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)

Write predictions at the end of a batch.

The writing method is determined by the write_strategy attribute.

Parameters:

  • trainer (Trainer) –

    PyTorch Lightning trainer.

  • pl_module (LightningModule) –

    PyTorch Lightning module.

  • prediction (ImageRegionData) –

    Prediction outputs for the batch.

  • batch_indices (sequence of Any) –

    Indices of the samples in the batch.

  • batch (ImageRegionData) –

    Input batch.

  • batch_idx (int) –

    Index of the current batch.

  • dataloader_idx (int) –

    Index of the dataloader that produced the batch.
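The delegation described for write_on_batch_end (the callback forwards the batch to whatever write_strategy holds) can be sketched as follows. `RecordingStrategy`, its `write_batch` method and `BatchEndWriter` are hypothetical; the sketch assumes the strategy receives the batch arguments together with the output directory, and that a disabled writer returns early. Neither assumption is confirmed by this page.

```python
from pathlib import Path
from typing import Any, List, Optional, Sequence, Tuple


class RecordingStrategy:
    """Hypothetical strategy that records each call instead of writing files."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Any, int, int, Path]] = []

    def write_batch(
        self,
        prediction: Any,
        batch_indices: Optional[Sequence[Any]],
        batch_idx: int,
        dataloader_idx: int,
        dirpath: Path,
    ) -> None:
        # A real strategy would save files under dirpath; here the call is only logged.
        self.calls.append((prediction, batch_idx, dataloader_idx, dirpath))


class BatchEndWriter:
    """Hypothetical sketch of write_on_batch_end delegating to a strategy."""

    def __init__(self, write_strategy: RecordingStrategy, dirpath: Path) -> None:
        self.write_strategy = write_strategy
        self.dirpath = dirpath
        self.is_enabled = True

    def write_on_batch_end(
        self, trainer: Any, pl_module: Any, prediction: Any,
        batch_indices: Optional[Sequence[Any]], batch: Any,
        batch_idx: int, dataloader_idx: int,
    ) -> None:
        if not self.is_enabled:
            return  # assumed: nothing is written while disabled
        # The callback holds no format-specific logic; the strategy decides how to write.
        self.write_strategy.write_batch(
            prediction, batch_indices, batch_idx, dataloader_idx, self.dirpath
        )


strategy = RecordingStrategy()
writer = BatchEndWriter(strategy, Path("predictions"))
writer.write_on_batch_end(None, None, "pred_0", [0, 1], "batch_0", 0, 0)
print(strategy.calls)
```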