# Predict Data Module

Prediction Lightning data modules.

## PredictDataModule

Bases: `LightningDataModule`

CAREamics Lightning prediction data module.
The data module can be used with `Path`, `str` or numpy arrays. The data can be either a folder containing images or a single file.

To read custom data types, set `data_type` to `custom` in `data_config` and provide a function that returns a numpy array from a path as the `read_source_func` parameter. The function will receive a `Path` object and an axes string as arguments, the axes being derived from `data_config`. You can also provide an `fnmatch`- and `Path.rglob`-compatible expression (e.g. `"*.czi"`) to filter file extensions using `extension_filter`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `pred_config` | `InferenceModel` | Pydantic model for CAREamics prediction configuration. | required |
| `pred_data` | `Path` or `str` or `ndarray` | Prediction data, can be a path to a folder, a file or a numpy array. | required |
| `read_source_func` | `Callable` | Function to read custom types, by default None. | `None` |
| `extension_filter` | `str` | Filter for file extensions of custom types, by default "". | `''` |
| `dataloader_params` | `dict` | Dataloader parameters, by default {}. | `None` |
predict_dataloader()

Create a dataloader for prediction.

Returns:

| Type | Description |
|---|---|
| `DataLoader` | Prediction dataloader. |
prepare_data()
Hook used to prepare the data before calling setup.
setup(stage=None)

Hook called at the beginning of predict.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `stage` | `Optional[str]` | Stage, by default None. | `None` |
create_predict_datamodule(pred_data, data_type, axes, image_means, image_stds, tile_size=None, tile_overlap=None, batch_size=1, tta_transforms=True, read_source_func=None, extension_filter='', dataloader_params=None)

Create a CAREamics prediction Lightning datamodule.

This function is used to explicitly pass the parameters usually contained in an `inference_model` configuration.

Since the Lightning datamodule has no access to the model, make sure that the parameters passed to the datamodule are consistent with the model's requirements. This can be done by creating a `Configuration` object beforehand and passing its parameters to the different Lightning modules.

The data module can be used with `Path`, `str` or numpy arrays. To use array data, set `data_type` to `array` and pass a numpy array to `pred_data`.

By default, CAREamics only supports the types defined in `careamics.config.support.SupportedData`. To read custom data types, set `data_type` to `custom` and provide a function that returns a numpy array from a path. Additionally, pass an `fnmatch`- and `Path.rglob`-compatible expression (e.g. `"*.jpeg"`) to filter file extensions using `extension_filter`.

In `dataloader_params`, you can pass any parameter accepted by PyTorch dataloaders, except for `batch_size`, which is set by the `batch_size` parameter.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `pred_data` | `str` or `Path` or `ndarray` | Prediction data. | required |
| `data_type` | `{"array", "tiff", "custom"}` | Data type, see `SupportedData`. | `"array"` |
| `axes` | `str` | Axes of the data, chosen among `SCZYX`. | required |
| `image_means` | list of float | Mean values for normalization, only used if Normalization is defined. | required |
| `image_stds` | list of float | Std values for normalization, only used if Normalization is defined. | required |
| `tile_size` | tuple of int | 2D or 3D tile size. | `None` |
| `tile_overlap` | tuple of int | 2D or 3D tile overlap. | `None` |
| `batch_size` | int | Batch size. | `1` |
| `tta_transforms` | bool | Use test-time augmentation, by default True. | `True` |
| `read_source_func` | `Callable` | Function to read the source data, used if `data_type` is `custom`. | `None` |
| `extension_filter` | `str` | Filter for file extensions, used if `data_type` is `custom`. | `''` |
| `dataloader_params` | `dict` | PyTorch dataloader parameters, by default {}. | `None` |
Returns:

| Type | Description |
|---|---|
| `PredictDataModule` | CAREamics prediction datamodule. |
Notes

If you are using a UNet model with tiling, the tile size must be divisible in every dimension by 2**d, where d is the depth of the model. This avoids artefacts arising from the broken shift invariance induced by the pooling layers of the UNet. If your image is smaller than the tile size in a given dimension, as may happen along Z, consider padding your image.
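The divisibility constraint above can be checked with a small helper (a sketch; `depth` here means the number of 2x pooling steps in the UNet):

```python
def check_tile_size(tile_size: tuple[int, ...], depth: int) -> bool:
    """Return True if every tile dimension is divisible by 2**depth.

    Tiles that violate this constraint can produce artefacts due to the
    broken shift invariance of the UNet pooling layers.
    """
    factor = 2**depth
    return all(dim % factor == 0 for dim in tile_size)
```

For example, a (128, 128) tile is compatible with a depth-3 UNet (128 is divisible by 8), whereas a (128, 100) tile is not.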