Lightning

CAREamics PyTorch Lightning modules.

DataStatsCallback

Bases: Callback

Callback to update the model's data statistics from the datamodule.

This callback ensures that the model has access to the data statistics (mean, std) calculated by the datamodule before training starts.

setup(trainer, module, stage)

Called when trainer is setting up.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer.

required
module LightningModule

Lightning module.

required
stage str

Current stage (fit, validate, test, or predict).

required
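The mechanism can be sketched without Lightning at all: before fitting, the callback copies the (mean, std) computed by the datamodule onto the model. The toy classes below are illustrative stand-ins, not the actual CAREamics implementation:

```python
class ToyDataModule:
    """Stands in for a datamodule that has computed data statistics."""

    def get_data_stats(self):
        # Pretend these were computed from the training data.
        return {"input": 128.0}, {"input": 32.0}


class ToyModule:
    """Stands in for the Lightning module receiving the statistics."""

    data_mean = None
    data_std = None


class ToyDataStatsCallback:
    """Framework-free sketch of the DataStatsCallback idea."""

    def setup(self, trainer, module, stage):
        # Copy statistics from the datamodule onto the model before fitting.
        if stage == "fit":
            mean, std = trainer.datamodule.get_data_stats()
            module.data_mean = mean
            module.data_std = std


class ToyTrainer:
    def __init__(self, datamodule):
        self.datamodule = datamodule


trainer = ToyTrainer(ToyDataModule())
module = ToyModule()
ToyDataStatsCallback().setup(trainer, module, stage="fit")
```

In the real callback, Lightning calls setup with its own Trainer and LightningModule; the sketch only shows the direction of the data-statistics handoff.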

FCNModule

Bases: LightningModule

CAREamics Lightning module.

This class encapsulates the PyTorch model along with the training, validation, and testing logic. It is configured using an AlgorithmModel Pydantic class.

Parameters:

Name Type Description Default
algorithm_config AlgorithmModel or dict

Algorithm configuration.

required

Attributes:

Name Type Description
model Module

PyTorch model.

loss_func Module

Loss function.

optimizer_name str

Optimizer name.

optimizer_params dict

Optimizer parameters.

lr_scheduler_name str

Learning rate scheduler name.

configure_optimizers()

Configure optimizers and learning rate schedulers.

Returns:

Type Description
Any

Optimizer and learning rate scheduler.

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Any

Input tensor.

required

Returns:

Type Description
Any

Output tensor.

predict_step(batch, batch_idx)

Prediction step.

Parameters:

Name Type Description Default
batch Tensor

Input batch.

required
batch_idx Any

Batch index.

required

Returns:

Type Description
Any

Model output.

training_step(batch, batch_idx)

Training step.

Parameters:

Name Type Description Default
batch Tensor

Input batch.

required
batch_idx Any

Batch index.

required

Returns:

Type Description
Any

Loss value.

validation_step(batch, batch_idx)

Validation step.

Parameters:

Name Type Description Default
batch Tensor

Input batch.

required
batch_idx Any

Batch index.

required

HyperParametersCallback

Bases: Callback

Callback that saves the CAREamics configuration as hyperparameters in the model.

This allows saving the configuration as a dictionary in the checkpoints, and subsequently loading it in a CAREamist instance.

Parameters:

Name Type Description Default
config Configuration

CAREamics configuration to be saved as hyperparameter in the model.

required

Attributes:

Name Type Description
config Configuration

CAREamics configuration to be saved as hyperparameter in the model.

on_train_start(trainer, pl_module)

Update the hyperparameters of the model with the configuration on train start.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer, unused.

required
pl_module LightningModule

PyTorch Lightning module.

required

MicroSplitDataModule

Bases: LightningDataModule

Lightning DataModule for MicroSplit-style datasets.

Matches the interface of TrainDataModule, but internally uses the original MicroSplit dataset logic.

Parameters:

Name Type Description Default
data_config MicroSplitDataConfig

Configuration for the MicroSplit dataset.

required
train_data str

Path to training data directory.

required
val_data str

Path to validation data directory.

None
train_data_target str

Path to training target data.

None
val_data_target str

Path to validation target data.

None
read_source_func Callable

Function to read source data.

None
extension_filter str

File extension filter.

''
val_percentage float

Percentage of data to use for validation, by default 0.1.

0.1
val_minimum_split int

Minimum number of samples for validation split, by default 5.

5
use_in_memory bool

Whether to use in-memory dataset, by default True.

True

get_data_stats()

Get data statistics.

Returns:

Type Description
tuple[dict, dict]

A tuple containing two dictionaries:

- data_mean: mean values for input and target
- data_std: standard deviation values for input and target

train_dataloader()

Create a dataloader for training.

Returns:

Type Description
DataLoader

Training dataloader.

val_dataloader()

Create a dataloader for validation.

Returns:

Type Description
DataLoader

Validation dataloader.

PredictDataModule

Bases: LightningDataModule

CAREamics Lightning prediction data module.

The data module can be used with Path, str or numpy arrays. The data can be either a folder containing images or a single file.

To read custom data types, you can set data_type to custom in data_config and provide a function that returns a numpy array from a path as the read_source_func parameter. The function will receive a Path object and an axes string as arguments, the axes being derived from the data_config.

You can also provide a fnmatch and Path.rglob compatible expression (e.g. "*.czi") to filter file extensions using extension_filter.
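Such a custom reader can be sketched as follows (read_npy is a hypothetical example; pairing it with extension_filter="*.npy" would restrict the files it sees):

```python
from pathlib import Path
import tempfile

import numpy as np


def read_npy(path: Path, axes: str) -> np.ndarray:
    """Hypothetical read_source_func: load a .npy file as a numpy array.

    CAREamics calls the reader with a Path and the axes string from the
    data configuration; here 'axes' is only used as a rank check.
    """
    array = np.load(path)
    if array.ndim != len(axes):
        raise ValueError(f"Expected {len(axes)} dimensions for axes '{axes}'")
    return array


# Quick self-check with a temporary file.
with tempfile.TemporaryDirectory() as tmp:
    file = Path(tmp) / "image.npy"
    np.save(file, np.zeros((16, 16)))
    image = read_npy(file, axes="YX")
```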

Parameters:

Name Type Description Default
pred_config InferenceModel

Pydantic model for CAREamics prediction configuration.

required
pred_data Path or str or ndarray

Prediction data, can be a path to a folder, a file or a numpy array.

required
read_source_func Callable

Function to read custom types, by default None.

None
extension_filter str

Filter for file extensions for custom types, by default "".

''
dataloader_params dict

Dataloader parameters, by default {}.

None

predict_dataloader()

Create a dataloader for prediction.

Returns:

Type Description
DataLoader

Prediction dataloader.

prepare_data()

Hook used to prepare the data before calling setup.

setup(stage=None)

Hook called at the beginning of predict.

Parameters:

Name Type Description Default
stage Optional[str]

Stage, by default None.

None

PredictionStoppedException

Bases: Exception

Exception raised when prediction is stopped by external signal.

ProgressBarCallback

Bases: TQDMProgressBar

Progress bar for training and validation steps.

get_metrics(trainer, pl_module)

Override this to customize the metrics displayed in the progress bar.

Parameters:

Name Type Description Default
trainer Trainer

The trainer object.

required
pl_module LightningModule

The LightningModule object, unused.

required

Returns:

Type Description
dict

A dictionary with the metrics to display in the progress bar.

init_test_tqdm()

Override this to customize the tqdm bar for testing.

Returns:

Type Description
tqdm

A tqdm bar.

init_train_tqdm()

Override this to customize the tqdm bar for training.

Returns:

Type Description
tqdm

A tqdm bar.

init_validation_tqdm()

Override this to customize the tqdm bar for validation.

Returns:

Type Description
tqdm

A tqdm bar.

StopPredictionCallback

Bases: Callback

PyTorch Lightning callback to stop prediction based on external condition.

This callback monitors a user-provided stop condition at the start of each prediction batch. When the condition is met, the callback stops the trainer and raises PredictionStoppedException to interrupt the prediction loop.

Parameters:

Name Type Description Default
stop_condition Callable[[], bool]

A callable that returns True when prediction should stop. The callable is invoked at the start of each prediction batch.

required

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Check stop condition at the start of each prediction batch.

Parameters:

Name Type Description Default
trainer Trainer

PyTorch Lightning trainer instance.

required
pl_module LightningModule

Lightning module being used for prediction.

required
batch Any

Current batch of data.

required
batch_idx int

Index of the current batch.

required
dataloader_idx int

Index of the dataloader, by default 0.

0

Raises:

Type Description
PredictionStoppedException

If stop_condition() returns True.
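A typical stop_condition is a small closure over shared state, for instance a threading.Event set from another thread or a UI handler (a sketch; the names here are illustrative):

```python
import threading

# StopPredictionCallback accepts any zero-argument callable returning a bool.
# A threading.Event is a natural backing store when prediction runs in a
# worker thread and, e.g., a "Cancel" button sets the event from elsewhere.
stop_event = threading.Event()


def stop_condition() -> bool:
    return stop_event.is_set()


# callback = StopPredictionCallback(stop_condition=stop_condition)

running = stop_condition()   # False: nothing has requested a stop yet
stop_event.set()             # e.g. triggered by a cancel handler
stopped = stop_condition()   # True: the next prediction batch stops
```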

TrainDataModule

Bases: LightningDataModule

CAREamics Lightning training and validation data module.

The data module can be used with Path, str or numpy arrays. In the case of numpy arrays, it loads and computes all the patches in memory. For Path and str inputs, it calculates the total file size and estimates whether it can fit in memory. If it does not, it iterates through the files. This behaviour can be deactivated by setting use_in_memory to False, in which case the iterating dataset is always used to train on a Path or str.

The data can be either a folder containing images or a single file.

Validation can be omitted, in which case the validation data is extracted from the training data. The percentage of the training data to use for validation, as well as the minimum number of patches or files to split from the training data can be set using val_percentage and val_minimum_split, respectively.
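The interaction between val_percentage and val_minimum_split can be illustrated with a small sketch (this mirrors the documented behaviour, not necessarily the exact CAREamics implementation):

```python
def validation_split_size(n_patches: int, val_percentage: float = 0.1,
                          val_minimum_split: int = 5) -> int:
    """Illustrative split: a percentage of the data, never below the minimum."""
    return max(int(n_patches * val_percentage), val_minimum_split)
```

With 200 training patches the 10% default yields 20 validation patches; with only 20 patches, 10% would give 2, so the minimum of 5 takes over.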

To read custom data types, you can set data_type to custom in data_config and provide a function that returns a numpy array from a path as the read_source_func parameter. The function will receive a Path object and an axes string as arguments, the axes being derived from the data_config.

You can also provide a fnmatch and Path.rglob compatible expression (e.g. "*.czi") to filter file extensions using extension_filter.

Parameters:

Name Type Description Default
data_config DataModel

Pydantic model for CAREamics data configuration.

required
train_data Path or str or ndarray

Training data, can be a path to a folder, a file or a numpy array.

required
val_data Path or str or ndarray

Validation data, can be a path to a folder, a file or a numpy array, by default None.

None
train_data_target Path or str or ndarray

Training target data, can be a path to a folder, a file or a numpy array, by default None.

None
val_data_target Path or str or ndarray

Validation target data, can be a path to a folder, a file or a numpy array, by default None.

None
read_source_func Callable

Function to read the source data, by default None. Only used for custom data type (see DataModel).

None
extension_filter str

Filter for file extensions, by default "". Only used for custom data types (see DataModel).

''
val_percentage float

Percentage of the training data to use for validation, by default 0.1. Only used if val_data is None.

0.1
val_minimum_split int

Minimum number of patches or files to split from the training data for validation, by default 5. Only used if val_data is None.

5
use_in_memory bool

Use in memory dataset if possible, by default True.

True

Attributes:

Name Type Description
data_config DataModel

CAREamics data configuration.

data_type SupportedData

Expected data type, one of "tiff", "array" or "custom".

batch_size int

Batch size.

use_in_memory bool

Whether to use in memory dataset if possible.

train_data Path or ndarray

Training data.

val_data Path or ndarray

Validation data.

train_data_target Path or ndarray

Training target data.

val_data_target Path or ndarray

Validation target data.

val_percentage float

Percentage of the training data to use for validation, if no validation data is provided.

val_minimum_split int

Minimum number of patches or files to split from the training data for validation, if no validation data is provided.

read_source_func Optional[Callable]

Function to read the source data, used if data_type is custom.

extension_filter str

Filter for file extensions, used if data_type is custom.

get_data_statistics()

Return training data statistics.

Returns:

Type Description
tuple of list

Means and standard deviations across channels of the training data.

prepare_data()

Hook used to prepare the data before calling setup.

Here, we only need to examine the data if it was provided as a str or a Path.

TODO: from lightning doc: prepare_data is called from the main process. It is not recommended to assign state here (e.g. self.x = y) since it is called on a single process and if you assign states here then they won't be available for other processes.

https://lightning.ai/docs/pytorch/stable/data/datamodule.html

setup(*args, **kwargs)

Hook called at the beginning of fit, validate, or predict.

Parameters:

Name Type Description Default
*args Any

Unused.

()
**kwargs Any

Unused.

{}

train_dataloader()

Create a dataloader for training.

Returns:

Type Description
Any

Training dataloader.

val_dataloader()

Create a dataloader for validation.

Returns:

Type Description
Any

Validation dataloader.

VAEModule

Bases: LightningModule

CAREamics Lightning module.

This class encapsulates a PyTorch model along with the training, validation, and testing logic. It is configured using an AlgorithmModel Pydantic class.

Parameters:

Name Type Description Default
algorithm_config Union[VAEAlgorithmConfig, dict]

Algorithm configuration.

required

Attributes:

Name Type Description
model Module

PyTorch model.

loss_func Module

Loss function.

optimizer_name str

Optimizer name.

optimizer_params dict

Optimizer parameters.

lr_scheduler_name str

Learning rate scheduler name.

compute_val_psnr(model_output, target, psnr_func=scale_invariant_psnr)

Compute the PSNR for the current validation batch.

Parameters:

Name Type Description Default
model_output tuple[Tensor, dict[str, Any]]

Model output, a tuple with the predicted mean and (optionally) logvar, and the top-down data dictionary.

required
target Tensor

Target tensor.

required
psnr_func Callable

PSNR function to use, by default scale_invariant_psnr.

scale_invariant_psnr

Returns:

Type Description
list[float]

PSNR for each channel in the current batch.

configure_optimizers()

Configure optimizers and learning rate schedulers.

Returns:

Type Description
Any

Optimizer and learning rate scheduler.

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs.

required

Returns:

Type Description
tuple[Tensor, dict[str, Any]]

A tuple with the output tensor and additional data from the top-down pass.

get_reconstructed_tensor(model_outputs)

Get the reconstructed tensor from the LVAE model outputs.

Parameters:

Name Type Description Default
model_outputs tuple[Tensor, dict[str, Any]]

Model outputs. It is a tuple with a tensor representing the predicted mean and (optionally) logvar, and the top-down data dictionary.

required

Returns:

Type Description
Tensor

Reconstructed tensor, i.e., the predicted mean.

on_validation_epoch_end()

Validation epoch end.

predict_step(batch, batch_idx)

Prediction step.

Parameters:

Name Type Description Default
batch Tensor

Input batch.

required
batch_idx Any

Batch index.

required

Returns:

Type Description
Any

Model output.

reduce_running_psnr()

Reduce the running PSNR statistics and reset the running PSNR.

Returns:

Type Description
Optional[float]

Running PSNR averaged over the different output channels.

set_data_stats(data_mean, data_std)

Set data mean and std for the noise model likelihood.

Parameters:

Name Type Description Default
data_mean float

Mean of the data.

required
data_std float

Standard deviation of the data.

required

training_step(batch, batch_idx)

Training step.

Parameters:

Name Type Description Default
batch tuple[Tensor, Tensor]

Input batch. It is a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).

required
batch_idx Any

Batch index.

required

Returns:

Type Description
Any

Loss value.

validation_step(batch, batch_idx)

Validation step.

Parameters:

Name Type Description Default
batch tuple[Tensor, Tensor]

Input batch. It is a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).

required
batch_idx Any

Batch index.

required

create_careamics_module(algorithm, loss, architecture, use_n2v2=False, struct_n2v_axis='none', struct_n2v_span=5, model_parameters=None, optimizer='Adam', optimizer_parameters=None, lr_scheduler='ReduceLROnPlateau', lr_scheduler_parameters=None)

Create a CAREamics Lightning module.

This function exposes the parameters used to create an AlgorithmModel instance, triggering parameter validation.

Parameters:

Name Type Description Default
algorithm SupportedAlgorithm or str

Algorithm to use for training (see SupportedAlgorithm).

required
loss SupportedLoss or str

Loss function to use for training (see SupportedLoss).

required
architecture SupportedArchitecture or str

Model architecture to use for training (see SupportedArchitecture).

required
use_n2v2 bool

Whether to use N2V2 or Noise2Void.

False
struct_n2v_axis "horizontal", "vertical", or "none"

Axis of the StructN2V mask.

"none"
struct_n2v_span int

Span of the StructN2V mask.

5
model_parameters dict

Model parameters to use for training, by default {}. Model parameters are defined in the relevant torch.nn.Module class, or Pydantic model (see careamics.config.architectures).

None
optimizer SupportedOptimizer or str

Optimizer to use for training, by default "Adam" (see SupportedOptimizer).

'Adam'
optimizer_parameters dict

Optimizer parameters to use for training, as defined in torch.optim, by default {}.

None
lr_scheduler SupportedScheduler or str

Learning rate scheduler to use for training, by default "ReduceLROnPlateau" (see SupportedScheduler).

'ReduceLROnPlateau'
lr_scheduler_parameters dict

Learning rate scheduler parameters to use for training, as defined in torch.optim, by default {}.

None

Returns:

Type Description
CAREamicsModule

CAREamics Lightning module.

create_microsplit_predict_datamodule(pred_data, tile_size, batch_size=1, num_channels=2, depth3D=1, grid_size=None, multiscale_count=None, data_stats=None, tiling_mode=TilingMode.ShiftBoundary, read_source_func=None, extension_filter='', dataloader_params=None, **dataset_kwargs)

Create a MicroSplitPredictDataModule for microSplit-style prediction datasets.

Parameters:

Name Type Description Default
pred_data str or Path or ndarray

Prediction data, can be a path to a folder, a file or a numpy array.

required
tile_size tuple

Size of one tile of data.

required
batch_size int

Batch size for prediction dataloader.

1
num_channels int

Number of channels in the input.

2
depth3D int

Number of slices in 3D.

1
grid_size tuple

Grid size for patch extraction.

None
multiscale_count int

Number of LC scales.

None
data_stats tuple

Data statistics, by default None.

None
tiling_mode TilingMode

Tiling mode for patch extraction.

ShiftBoundary
read_source_func Callable

Function to read the source data.

None
extension_filter str

File extension filter.

''
dataloader_params dict

Parameters for prediction dataloader.

None
**dataset_kwargs

Additional arguments passed to MicroSplitDataConfig.

{}

Returns:

Type Description
MicroSplitPredictDataModule

Configured MicroSplitPredictDataModule instance.

create_microsplit_train_datamodule(train_data, patch_size, batch_size, val_data=None, num_channels=2, depth3D=1, grid_size=None, multiscale_count=None, tiling_mode=TilingMode.ShiftBoundary, extension_filter='', val_percentage=0.1, val_minimum_split=5, use_in_memory=True, transforms=None, train_dataloader_params=None, val_dataloader_params=None, **dataset_kwargs)

Create a MicroSplitDataModule for MicroSplit-style datasets.

Parameters:

Name Type Description Default
train_data str

Path to training data.

required
patch_size tuple

Size of one patch of data.

required
batch_size int

Batch size for dataloaders.

required
val_data str

Path to validation data.

None
num_channels int

Number of channels in the input.

2
depth3D int

Number of slices in 3D.

1
grid_size tuple

Grid size for patch extraction.

None
multiscale_count int

Number of LC scales.

None
tiling_mode TilingMode

Tiling mode for patch extraction.

ShiftBoundary
extension_filter str

File extension filter.

''
val_percentage float

Percentage of training data to use for validation.

0.1
val_minimum_split int

Minimum number of patches/files for validation split.

5
use_in_memory bool

Use in-memory dataset if possible.

True
transforms list

List of transforms to apply.

None
train_dataloader_params dict

Parameters for training dataloader.

None
val_dataloader_params dict

Parameters for validation dataloader.

None
**dataset_kwargs

Additional arguments passed to DatasetConfig.

{}

Returns:

Type Description
MicroSplitDataModule

Configured MicroSplitDataModule instance.

create_predict_datamodule(pred_data, data_type, axes, image_means, image_stds, tile_size=None, tile_overlap=None, batch_size=1, tta_transforms=True, read_source_func=None, extension_filter='', dataloader_params=None)

Create a CAREamics prediction Lightning datamodule.

This function is used to explicitly pass the parameters usually contained in an inference_model configuration.

Since the Lightning datamodule has no access to the model, make sure that the parameters passed to the datamodule are consistent with the model's requirements. This can be done by creating a Configuration object beforehand and passing its parameters to the different Lightning modules.

The data module can be used with Path, str or numpy arrays. To use array data, set data_type to array and pass a numpy array to train_data.

By default, CAREamics only supports types defined in careamics.config.support.SupportedData. To read custom data types, you can set data_type to custom and provide a function that returns a numpy array from a path. Additionally, pass a fnmatch and Path.rglob compatible expression (e.g. "*.jpeg") to filter file extensions using extension_filter.

In dataloader_params, you can pass any parameter accepted by PyTorch dataloaders, except for batch_size, which is set by the batch_size parameter.

Parameters:

Name Type Description Default
pred_data str or Path or ndarray

Prediction data.

required
data_type (array, tiff, custom)

Data type, see SupportedData for available options.

"array"
axes str

Axes of the data, chosen among SCZYX.

required
image_means list of float

Mean values for normalization, only used if Normalization is defined.

required
image_stds list of float

Std values for normalization, only used if Normalization is defined.

required
tile_size tuple of int

Tile size, 2D or 3D tile size.

None
tile_overlap tuple of int

Tile overlap, 2D or 3D tile overlap.

None
batch_size int

Batch size.

1
tta_transforms bool

Use test time augmentation, by default True.

True
read_source_func Callable

Function to read the source data, used if data_type is custom, by default None.

None
extension_filter str

Filter for file extensions, used if data_type is custom, by default "".

''
dataloader_params dict

PyTorch dataloader parameters, by default {}.

None

Returns:

Type Description
PredictDataModule

CAREamics prediction datamodule.

Notes

If you are using a UNet model and tiling, the tile size must be divisible in every dimension by 2**d, where d is the depth of the model. This avoids artefacts arising from the broken shift invariance induced by the pooling layers of the UNet. If your image is smaller than the tile in some dimension, as may happen along Z, consider padding your image.
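A quick way to check this constraint before building the datamodule (the helper is illustrative, not part of CAREamics):

```python
def tile_size_is_valid(tile_size, depth: int) -> bool:
    """Check that every tile dimension is divisible by 2**depth.

    Illustrates the UNet tiling note; the helper name is not a CAREamics API.
    """
    factor = 2 ** depth
    return all(dim % factor == 0 for dim in tile_size)
```

For a depth-3 UNet, a (128, 128) tile passes (2**3 = 8 divides 128), while (100, 100) does not (100 % 8 != 0).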

create_train_datamodule(train_data, data_type, patch_size, axes, batch_size, val_data=None, transforms=None, train_target_data=None, val_target_data=None, read_source_func=None, extension_filter='', val_percentage=0.1, val_minimum_patches=5, train_dataloader_params=None, val_dataloader_params=None, use_in_memory=True)

Create a TrainDataModule.

This function is used to explicitly pass the parameters usually contained in a GenericDataConfig to a TrainDataModule.

Since the Lightning datamodule has no access to the model, make sure that the parameters passed to the datamodule are consistent with the model's requirements.

The default augmentations are XY flip and XY rotation. To use a different set of augmentations, you can pass a list of transforms to transforms.

The data module can be used with Path, str or numpy arrays. In the case of numpy arrays, it loads and computes all the patches in memory. For Path and str inputs, it calculates the total file size and estimates whether it can fit in memory. If it does not, it iterates through the files. This behaviour can be deactivated by setting use_in_memory to False, in which case the iterating dataset is always used to train on a Path or str.

To use array data, set data_type to array and pass a numpy array to train_data.

By default, CAREamics only supports types defined in careamics.config.support.SupportedData. To read custom data types, you can set data_type to custom and provide a function that returns a numpy array from a path. Additionally, pass a fnmatch and Path.rglob compatible expression (e.g. "*.jpeg") to filter file extensions using extension_filter.

In the absence of validation data, the validation data is extracted from the training data. The percentage of the training data to use for validation, as well as the minimum number of patches to split from the training data for validation can be set using val_percentage and val_minimum_patches, respectively.

In dataloader_params, you can pass any parameter accepted by PyTorch dataloaders, except for batch_size, which is set by the batch_size parameter.
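For example (the values are illustrative, and any torch.utils.data.DataLoader argument except batch_size is accepted):

```python
# Example dataloader parameter dictionaries for create_train_datamodule.
# batch_size is deliberately absent: CAREamics sets it from its own
# batch_size parameter.
train_dataloader_params = {"num_workers": 4, "pin_memory": True}
val_dataloader_params = {"num_workers": 2, "pin_memory": True}
```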

Parameters:

Name Type Description Default
train_data Path or str or ndarray

Training data.

required
data_type (array, tiff, custom)

Data type, see SupportedData for available options.

"array"
patch_size list of int

Patch size, 2D or 3D patch size.

required
axes str

Axes of the data, chosen amongst SCZYX.

required
batch_size int

Batch size.

required
val_data Path or str or ndarray

Validation data, by default None.

None
transforms list of Transforms

List of transforms to apply to training patches. If None, default transforms are applied.

None
train_target_data Path or str or ndarray

Training target data, by default None.

None
val_target_data Path or str or ndarray

Validation target data, by default None.

None
read_source_func Callable

Function to read the source data, used if data_type is custom, by default None.

None
extension_filter str

Filter for file extensions, used if data_type is custom, by default "".

''
val_percentage float

Percentage of the training data to use for validation if no validation data is given, by default 0.1.

0.1
val_minimum_patches int

Minimum number of patches to split from the training data for validation if no validation data is given, by default 5.

5
train_dataloader_params dict

PyTorch dataloader parameters for the training data, by default {}.

None
val_dataloader_params dict

PyTorch dataloader parameters for the validation data, by default {}.

None
use_in_memory bool

Use in memory dataset if possible, by default True.

True

Returns:

Type Description
TrainDataModule

CAREamics training Lightning data module.

Examples:

Create a TrainDataModule with default transforms from a numpy array:

>>> import numpy as np
>>> from careamics.lightning import create_train_datamodule
>>> my_array = np.arange(256).reshape(16, 16)
>>> data_module = create_train_datamodule(
...     train_data=my_array,
...     data_type="array",
...     patch_size=(8, 8),
...     axes='YX',
...     batch_size=2,
... )

For custom data types (those not supported by CAREamics), one can pass a read function and a file extension filter:

>>> import numpy as np
>>> from careamics.lightning import create_train_datamodule
>>>
>>> def read_npy(path):
...     return np.load(path)
>>>
>>> data_module = create_train_datamodule(
...     train_data="path/to/data",
...     data_type="custom",
...     patch_size=(8, 8),
...     axes='YX',
...     batch_size=2,
...     read_source_func=read_npy,
...     extension_filter="*.npy",
... )

If you want to use a different set of augmentations, you can pass a list of transforms:

>>> import numpy as np
>>> from careamics.lightning import create_train_datamodule
>>> from careamics.config.augmentations import XYFlipConfig
>>> from careamics.config.support import SupportedTransform
>>> my_array = np.arange(256).reshape(16, 16)
>>> my_transforms = [
...     XYFlipConfig(flip_y=False),
... ]
>>> data_module = create_train_datamodule(
...     train_data=my_array,
...     data_type="array",
...     patch_size=(8, 8),
...     axes='YX',
...     batch_size=2,
...     transforms=my_transforms,
... )