
Train Data Module


Training and validation Lightning data modules.

TrainDataModule

Bases: LightningDataModule

CAREamics Lightning training and validation data module.

The data module can be used with Path, str or numpy arrays. In the case of numpy arrays, it loads and computes all the patches in memory. For Path and str inputs, it calculates the total file size and estimates whether it can fit in memory. If it cannot, it iterates through the files. This behaviour can be deactivated by setting use_in_memory to False, in which case the iterating dataset is always used to train on a Path or str.

The data can be either a folder containing images or a single file.

Validation can be omitted, in which case the validation data is extracted from the training data. The percentage of the training data to use for validation, as well as the minimum number of patches or files to split from the training data can be set using val_percentage and val_minimum_split, respectively.

To read custom data types, you can set data_type to custom in data_config and provide a function that returns a numpy array from a path as the read_source_func parameter. The function will receive a Path object and an axes string as arguments, the axes being derived from the data_config.

You can also provide a fnmatch and Path.rglob compatible expression (e.g. "*.czi") to filter files by extension using extension_filter.
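A minimal custom reader might look like the sketch below. The `read_npy` name is hypothetical; only the Path/axes calling convention described above is assumed, and the filtering is shown with `Path.rglob` directly to illustrate what `extension_filter` matches:

```python
import tempfile
from pathlib import Path

import numpy as np


def read_npy(path: Path, axes: str) -> np.ndarray:
    # CAREamics calls read_source_func with a Path and the axes string
    # derived from data_config; this simple reader ignores the axes.
    return np.load(path)


# extension_filter uses fnmatch / Path.rglob syntax, e.g. "*.npy":
with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    np.save(folder / "img.npy", np.zeros((16, 16)))
    (folder / "notes.txt").write_text("not an image")

    # only the .npy file matches the pattern
    matched = sorted(p.name for p in folder.rglob("*.npy"))
    print(matched)  # ['img.npy']
```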

Parameters:

Name Type Description Default
data_config DataModel

Pydantic model for CAREamics data configuration.

required
train_data Path or str or ndarray

Training data, can be a path to a folder, a file or a numpy array.

required
val_data Path or str or ndarray

Validation data, can be a path to a folder, a file or a numpy array, by default None.

None
train_data_target Path or str or ndarray

Training target data, can be a path to a folder, a file or a numpy array, by default None.

None
val_data_target Path or str or ndarray

Validation target data, can be a path to a folder, a file or a numpy array, by default None.

None
read_source_func Callable

Function to read the source data, by default None. Only used for custom data type (see DataModel).

None
extension_filter str

Filter for file extensions, by default "". Only used for custom data types (see DataModel).

''
val_percentage float

Percentage of the training data to use for validation, by default 0.1. Only used if val_data is None.

0.1
val_minimum_split int

Minimum number of patches or files to split from the training data for validation, by default 5. Only used if val_data is None.

5
use_in_memory bool

Use in memory dataset if possible, by default True.

True

Attributes:

Name Type Description
data_config DataModel

CAREamics data configuration.

data_type SupportedData

Expected data type, one of "tiff", "array" or "custom".

batch_size int

Batch size.

use_in_memory bool

Whether to use in memory dataset if possible.

train_data Path or ndarray

Training data.

val_data Path or ndarray

Validation data.

train_data_target Path or ndarray

Training target data.

val_data_target Path or ndarray

Validation target data.

val_percentage float

Percentage of the training data to use for validation, if no validation data is provided.

val_minimum_split int

Minimum number of patches or files to split from the training data for validation, if no validation data is provided.

read_source_func Optional[Callable]

Function to read the source data, used if data_type is custom.

extension_filter str

Filter for file extensions, used if data_type is custom.

get_data_statistics()

Return training data statistics.

Returns:

Type Description
tuple of list

Means and standard deviations across channels of the training data.

prepare_data()

Hook used to prepare the data before calling setup.

Here, we only need to examine the data if it was provided as a str or a Path.

TODO: from lightning doc: prepare_data is called from the main process. It is not recommended to assign state here (e.g. self.x = y) since it is called on a single process and if you assign states here then they won't be available for other processes.

https://lightning.ai/docs/pytorch/stable/data/datamodule.html

setup(*args, **kwargs)

Hook called at the beginning of fit, validate, or predict.

Parameters:

Name Type Description Default
*args Any

Unused.

()
**kwargs Any

Unused.

{}

train_dataloader()

Create a dataloader for training.

Returns:

Type Description
Any

Training dataloader.

val_dataloader()

Create a dataloader for validation.

Returns:

Type Description
Any

Validation dataloader.

create_train_datamodule(train_data, data_type, patch_size, axes, batch_size, val_data=None, transforms=None, train_target_data=None, val_target_data=None, read_source_func=None, extension_filter='', val_percentage=0.1, val_minimum_patches=5, train_dataloader_params=None, val_dataloader_params=None, use_in_memory=True)

Create a TrainDataModule.

This function is used to explicitly pass the parameters usually contained in a GenericDataConfig to a TrainDataModule.

Since the Lightning data module has no access to the model, make sure that the parameters passed to the data module are consistent with the model's requirements.

The default augmentations are XY flip and XY rotation. To use a different set of augmentations, you can pass a list of transforms to transforms.

The data module can be used with Path, str or numpy arrays. In the case of numpy arrays, it loads and computes all the patches in memory. For Path and str inputs, it calculates the total file size and estimates whether it can fit in memory. If it cannot, it iterates through the files. This behaviour can be deactivated by setting use_in_memory to False, in which case the iterating dataset is always used to train on a Path or str.

To use array data, set data_type to array and pass a numpy array to train_data.

By default, CAREamics only supports types defined in careamics.config.support.SupportedData. To read custom data types, you can set data_type to custom and provide a function that returns a numpy array from a path. Additionally, pass a fnmatch and Path.rglob compatible expression (e.g. "*.jpeg") to filter files by extension using extension_filter.

In the absence of validation data, the validation data is extracted from the training data. The percentage of the training data to use for validation, as well as the minimum number of patches to split from the training data for validation can be set using val_percentage and val_minimum_patches, respectively.
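The interaction between the two parameters can be illustrated as follows. This is illustrative only; `n_val_patches` is a hypothetical helper, not part of the CAREamics API:

```python
def n_val_patches(
    n_train: int, val_percentage: float = 0.1, val_minimum_patches: int = 5
) -> int:
    # Take val_percentage of the training patches, but never fewer
    # than val_minimum_patches (illustrates the documented rule).
    return max(int(n_train * val_percentage), val_minimum_patches)


print(n_val_patches(200))  # 20 patches go to validation
print(n_val_patches(30))   # 5, the minimum is enforced
```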

In train_dataloader_params and val_dataloader_params, you can pass any parameter accepted by PyTorch dataloaders, except for batch_size, which is set by the batch_size parameter.

Parameters:

Name Type Description Default
train_data Path or str or ndarray

Training data.

required
data_type (array, tiff, custom)

Data type, see SupportedData for available options.

"array"
patch_size list of int

Patch size, 2D or 3D patch size.

required
axes str

Axes of the data, chosen amongst SCZYX.

required
batch_size int

Batch size.

required
val_data Path or str or ndarray

Validation data, by default None.

None
transforms list of Transforms

List of transforms to apply to training patches. If None, default transforms are applied.

None
train_target_data Path or str or ndarray

Training target data, by default None.

None
val_target_data Path or str or ndarray

Validation target data, by default None.

None
read_source_func Callable

Function to read the source data, used if data_type is custom, by default None.

None
extension_filter str

Filter for file extensions, used if data_type is custom, by default "".

''
val_percentage float

Percentage of the training data to use for validation if no validation data is given, by default 0.1.

0.1
val_minimum_patches int

Minimum number of patches to split from the training data for validation if no validation data is given, by default 5.

5
train_dataloader_params dict

Pytorch dataloader parameters for the training data, by default {}.

None
val_dataloader_params dict

Pytorch dataloader parameters for the validation data, by default {}.

None
use_in_memory bool

Use in memory dataset if possible, by default True.

True

Returns:

Type Description
TrainDataModule

CAREamics training Lightning data module.

Examples:

Create a TrainDataModule with default transforms from a numpy array:

>>> import numpy as np
>>> from careamics.lightning import create_train_datamodule
>>> my_array = np.arange(256).reshape(16, 16)
>>> data_module = create_train_datamodule(
...     train_data=my_array,
...     data_type="array",
...     patch_size=(8, 8),
...     axes='YX',
...     batch_size=2,
... )

For custom data types (those not supported by CAREamics), one can pass a read function and a filter for the file extension:

>>> import numpy as np
>>> from careamics.lightning import create_train_datamodule
>>>
>>> def read_npy(path):
...     return np.load(path)
>>>
>>> data_module = create_train_datamodule(
...     train_data="path/to/data",
...     data_type="custom",
...     patch_size=(8, 8),
...     axes='YX',
...     batch_size=2,
...     read_source_func=read_npy,
...     extension_filter="*.npy",
... )

If you want to use a different set of augmentations, you can pass a list of transforms:

>>> import numpy as np
>>> from careamics.lightning import create_train_datamodule
>>> from careamics.config.augmentations import XYFlipConfig
>>> from careamics.config.support import SupportedTransform
>>> my_array = np.arange(256).reshape(16, 16)
>>> my_transforms = [
...     XYFlipConfig(flip_y=False),
... ]
>>> data_module = create_train_datamodule(
...     train_data=my_array,
...     data_type="array",
...     patch_size=(8, 8),
...     axes='YX',
...     batch_size=2,
...     transforms=my_transforms,
... )