Lightning
CAREamics PyTorch Lightning modules.
DataStatsCallback
Bases: Callback
Callback to update model's data statistics from datamodule.
This callback ensures that the model has access to the data statistics (mean, std) calculated by the datamodule before training starts.
setup(trainer, module, stage)
Called when trainer is setting up.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer. | required |
| module | LightningModule | Lightning module. | required |
| stage | str | Current stage (fit, validate, test, or predict). | required |
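At its core, the setup hook copies statistics from the datamodule to the model. The sketch below shows that flow with plain-Python stand-ins so it runs without Lightning installed. ToyDataModule, ToyModule, ToyTrainer and DataStatsSketch are all hypothetical. The sketch assumes the datamodule exposes get_data_stats() and the module exposes set_data_stats(data_mean, data_std), as listed for MicroSplitDataModule and VAEModule on this page. Acting only on the "fit" stage is also an assumption and is not necessarily what DataStatsCallback does.

```python
class ToyDataModule:
    """Hypothetical stand-in for a datamodule exposing per-channel statistics."""

    def get_data_stats(self):
        # (means, stds), one value per channel
        return [10.0, 20.0], [2.0, 4.0]


class ToyModule:
    """Hypothetical stand-in for a LightningModule that stores statistics."""

    def __init__(self):
        self.data_mean = None
        self.data_std = None

    def set_data_stats(self, data_mean, data_std):
        self.data_mean = data_mean
        self.data_std = data_std


class ToyTrainer:
    """Hypothetical stand-in exposing the datamodule, like Trainer.datamodule."""

    def __init__(self, datamodule):
        self.datamodule = datamodule


class DataStatsSketch:
    """Mimics DataStatsCallback.setup: copy statistics before training."""

    def setup(self, trainer, module, stage):
        # Assumption: only act when fitting; other stages are left untouched.
        if stage == "fit":
            means, stds = trainer.datamodule.get_data_stats()
            module.set_data_stats(means, stds)


trainer = ToyTrainer(ToyDataModule())
module = ToyModule()
DataStatsSketch().setup(trainer, module, stage="fit")
print(module.data_mean, module.data_std)  # [10.0, 20.0] [2.0, 4.0]
```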
FCNModule
Bases: LightningModule
CAREamics Lightning module.
This class encapsulates the PyTorch model along with the training, validation,
and testing logic. It is configured using an AlgorithmModel Pydantic class.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| algorithm_config | AlgorithmModel or dict | Algorithm configuration. | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| model | Module | PyTorch model. |
| loss_func | Module | Loss function. |
| optimizer_name | str | Optimizer name. |
| optimizer_params | dict | Optimizer parameters. |
| lr_scheduler_name | str | Learning rate scheduler name. |
configure_optimizers()
Configure optimizers and learning rate schedulers.
Returns:
| Type | Description |
|---|---|
| Any | Optimizer and learning rate scheduler. |
forward(x)
predict_step(batch, batch_idx)
training_step(batch, batch_idx)
validation_step(batch, batch_idx)
Validation step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | Tensor | Input batch. | required |
| batch_idx | Any | Batch index. | required |
HyperParametersCallback
Bases: Callback
Callback that saves the CAREamics configuration as hyperparameters in the model.
This allows saving the configuration as a dictionary in the checkpoints and later loading it in a CAREamist instance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | Configuration | CAREamics configuration to be saved as hyperparameter in the model. | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| config | Configuration | CAREamics configuration to be saved as hyperparameter in the model. |
on_train_start(trainer, pl_module)
Update the hyperparameters of the model with the configuration on train start.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer, unused. | required |
| pl_module | LightningModule | PyTorch Lightning module. | required |
MicroSplitDataModule
Bases: LightningDataModule
Lightning DataModule for MicroSplit-style datasets.
Matches the interface of TrainDataModule, but internally uses the original MicroSplit dataset logic.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_config | MicroSplitDataConfig | Configuration for the MicroSplit dataset. | required |
| train_data | str | Path to training data directory. | required |
| val_data | str | Path to validation data directory. | None |
| train_data_target | str | Path to training target data. | None |
| val_data_target | str | Path to validation target data. | None |
| read_source_func | Callable | Function to read source data. | None |
| extension_filter | str | File extension filter. | '' |
| val_percentage | float | Percentage of data to use for validation, by default 0.1. | 0.1 |
| val_minimum_split | int | Minimum number of samples for validation split, by default 5. | 5 |
| use_in_memory | bool | Whether to use in-memory dataset, by default True. | True |
get_data_stats()
train_dataloader()
Create a dataloader for training.
Returns:
| Type | Description |
|---|---|
| DataLoader | Training dataloader. |
val_dataloader()
Create a dataloader for validation.
Returns:
| Type | Description |
|---|---|
| DataLoader | Validation dataloader. |
PredictDataModule
Bases: LightningDataModule
CAREamics Lightning prediction data module.
The data module can be used with Path, str or numpy arrays. The data can be either a folder containing images or a single file.
To read custom data types, you can set data_type to custom in data_config
and provide a function that returns a numpy array from a path as the
read_source_func parameter. The function will receive a Path object and
an axes string as arguments, the axes being derived from the data_config.
You can also provide a fnmatch and Path.rglob compatible expression (e.g.
"*.czi") to filter files by extension using extension_filter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pred_config | InferenceModel | Pydantic model for CAREamics prediction configuration. | required |
| pred_data | Path or str or ndarray | Prediction data, can be a path to a folder, a file or a numpy array. | required |
| read_source_func | Callable | Function to read custom types, by default None. | None |
| extension_filter | str | Filter for the file extensions of custom types, by default "". | '' |
| dataloader_params | dict | Dataloader parameters, by default {}. | None |
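Here is a minimal sketch of a custom reader and file filter. It assumes the reader is called with a Path and an axes string, as described above. read_npy is hypothetical. The filtering mimics what an extension_filter such as "*.npy" would select through Path.rglob, and the sketch does not call CAREamics.

```python
import tempfile
from pathlib import Path

import numpy as np


def read_npy(path: Path, axes: str) -> np.ndarray:
    """Hypothetical custom reader called with a Path and an axes string."""
    array = np.load(path)
    # Sanity check: one axis letter per array dimension.
    if array.ndim != len(axes):
        raise ValueError(f"Expected {len(axes)} dims for axes {axes!r}, got {array.ndim}.")
    return array


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    np.save(root / "a.npy", np.zeros((16, 16)))
    np.save(root / "b.npy", np.ones((16, 16)))
    (root / "notes.txt").write_text("not an image")

    # extension_filter="*.npy" selects files the way Path.rglob does.
    files = sorted(root.rglob("*.npy"))
    images = [read_npy(f, axes="YX") for f in files]

print([f.name for f in files], [img.shape for img in images])
# ['a.npy', 'b.npy'] [(16, 16), (16, 16)]
```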
predict_dataloader()
Create a dataloader for prediction.
Returns:
| Type | Description |
|---|---|
| DataLoader | Prediction dataloader. |
prepare_data()
Hook used to prepare the data before calling setup.
setup(stage=None)
Hook called at the beginning of predict.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stage | Optional[str] | Stage, by default None. | None |
PredictionStoppedException
Bases: Exception
Exception raised when prediction is stopped by external signal.
ProgressBarCallback
Bases: TQDMProgressBar
Progress bar for training and validation steps.
get_metrics(trainer, pl_module)
Override this to customize the metrics displayed in the progress bar.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | The trainer object. | required |
| pl_module | LightningModule | The LightningModule object, unused. | required |
Returns:
| Type | Description |
|---|---|
| dict | A dictionary with the metrics to display in the progress bar. |
init_test_tqdm()
Override this to customize the tqdm bar for testing.
Returns:
| Type | Description |
|---|---|
| tqdm | A tqdm bar. |
init_train_tqdm()
Override this to customize the tqdm bar for training.
Returns:
| Type | Description |
|---|---|
| tqdm | A tqdm bar. |
init_validation_tqdm()
Override this to customize the tqdm bar for validation.
Returns:
| Type | Description |
|---|---|
| tqdm | A tqdm bar. |
StopPredictionCallback
Bases: Callback
PyTorch Lightning callback to stop prediction based on external condition.
This callback monitors a user-provided stop condition at the start of each prediction batch. When the condition is met, the callback stops the trainer and raises PredictionStoppedException to interrupt the prediction loop.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stop_condition | Callable[[], bool] | A callable that returns True when prediction should stop. The callable is invoked at the start of each prediction batch. | required |
on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Check stop condition at the start of each prediction batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer instance. | required |
| pl_module | LightningModule | Lightning module being used for prediction. | required |
| batch | Any | Current batch of data. | required |
| batch_idx | int | Index of the current batch. | required |
| dataloader_idx | int | Index of the dataloader, by default 0. | 0 |
Raises:
| Type | Description |
|---|---|
| PredictionStoppedException | If stop_condition() returns True. |
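The stop mechanism can be illustrated without Lightning. In this sketch, run_prediction is a hypothetical stand-in for the prediction loop, and PredictionStopped stands in for PredictionStoppedException. threading.Event.is_set serves as a zero-argument stop_condition that returns a bool, which matches the Callable[[], bool] type above.

```python
import threading


class PredictionStopped(Exception):
    """Hypothetical stand-in for PredictionStoppedException."""


def run_prediction(batches, stop_condition, on_batch_end=None):
    """Toy prediction loop that checks stop_condition before every batch."""
    outputs = []
    for batch_idx, batch in enumerate(batches):
        # Mirrors on_predict_batch_start: check before the batch is processed.
        if stop_condition():
            raise PredictionStopped(f"stopped before batch {batch_idx}")
        outputs.append(batch * 2)  # placeholder for the model's predict_step
        if on_batch_end is not None:
            on_batch_end(batch_idx)
    return outputs


stop_event = threading.Event()


def request_stop_after_second_batch(batch_idx):
    # Simulates an external signal, e.g. a "Stop" button in a GUI thread.
    if batch_idx == 1:
        stop_event.set()


try:
    run_prediction([1, 2, 3, 4], stop_event.is_set, request_stop_after_second_batch)
except PredictionStopped as err:
    print(err)  # stopped before batch 2
```

The condition is checked at the start of a batch. As a result, a batch that is already being processed when the signal arrives still completes, and the stop takes effect before the next one.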
TrainDataModule
Bases: LightningDataModule
CAREamics Lightning training and validation data module.
The data module can be used with Path, str or numpy arrays. In the case of
numpy arrays, it loads and computes all the patches in memory. For Path and str
inputs, it calculates the total file size and estimates whether it can fit in
memory. If it does not, it iterates through the files. This behaviour can be
deactivated by setting use_in_memory to False, in which case it will
always use the iterating dataset to train on a Path or str.
The data can be either a folder containing images or a single file.
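The in-memory decision can be sketched as a comparison between the total size of the matching files and a memory budget. The threshold CAREamics actually uses is not stated here. For that reason, budget_bytes and both helper functions are hypothetical, and only the pathlib calls are real.

```python
import tempfile
from pathlib import Path


def total_size_bytes(root: Path, extension_filter: str) -> int:
    """Sum the on-disk size of all files under root matching the filter."""
    return sum(f.stat().st_size for f in root.rglob(extension_filter) if f.is_file())


def choose_dataset(root: Path, budget_bytes: int, use_in_memory: bool = True,
                   extension_filter: str = "*.tif") -> str:
    """Return which kind of dataset this sketch would build."""
    if use_in_memory and total_size_bytes(root, extension_filter) <= budget_bytes:
        return "in_memory"
    # Too large (or in-memory disabled): stream files one by one instead.
    return "iterable"


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for name in ("a.tif", "b.tif"):
        (root / name).write_bytes(b"\x00" * 1000)  # 2000 bytes in total
    decisions = (
        choose_dataset(root, budget_bytes=1_500),
        choose_dataset(root, budget_bytes=10_000),
        choose_dataset(root, budget_bytes=10_000, use_in_memory=False),
    )

print(decisions)  # ('iterable', 'in_memory', 'iterable')
```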
Validation can be omitted, in which case the validation data is extracted from
the training data. The percentage of the training data to use for validation,
as well as the minimum number of patches or files to split from the training
data can be set using val_percentage and val_minimum_split, respectively.
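One plausible reading of this split rule is to move max(round(n * val_percentage), val_minimum_split) items to validation, where val_percentage is a fraction (0.1 means 10%). This rule is an assumption and is not taken from the CAREamics source. The helpers below are hypothetical.

```python
import random


def n_validation_items(n_total: int, val_percentage: float = 0.1,
                       val_minimum_split: int = 5) -> int:
    """Number of patches or files moved to validation (assumed rule)."""
    n_val = max(round(n_total * val_percentage), val_minimum_split)
    if n_val >= n_total:
        raise ValueError(
            f"Cannot split {n_val} validation items from only {n_total} items."
        )
    return n_val


def split_train_val(items, val_percentage=0.1, val_minimum_split=5, seed=0):
    """Shuffle items and return (train, val) lists."""
    n_val = n_validation_items(len(items), val_percentage, val_minimum_split)
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    return shuffled[n_val:], shuffled[:n_val]


train, val = split_train_val(list(range(100)))
print(len(train), len(val))  # 90 10
print(n_validation_items(20))  # 5: the minimum wins over round(20 * 0.1) = 2
```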
To read custom data types, you can set data_type to custom in data_config
and provide a function that returns a numpy array from a path as the
read_source_func parameter. The function will receive a Path object and
an axes string as arguments, the axes being derived from the data_config.
You can also provide a fnmatch and Path.rglob compatible expression (e.g.
"*.czi") to filter files by extension using extension_filter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_config | DataModel | Pydantic model for CAREamics data configuration. | required |
| train_data | Path or str or ndarray | Training data, can be a path to a folder, a file or a numpy array. | required |
| val_data | Path or str or ndarray | Validation data, can be a path to a folder, a file or a numpy array, by default None. | None |
| train_data_target | Path or str or ndarray | Training target data, can be a path to a folder, a file or a numpy array, by default None. | None |
| val_data_target | Path or str or ndarray | Validation target data, can be a path to a folder, a file or a numpy array, by default None. | None |
| read_source_func | Callable | Function to read the source data, by default None. Only used for the custom data type. | None |
| extension_filter | str | Filter for file extensions, by default "". Only used for the custom data type. | '' |
| val_percentage | float | Percentage of the training data to use for validation, by default 0.1. Only used if val_data is None. | 0.1 |
| val_minimum_split | int | Minimum number of patches or files to split from the training data for validation, by default 5. Only used if val_data is None. | 5 |
| use_in_memory | bool | Use in memory dataset if possible, by default True. | True |
Attributes:
| Name | Type | Description |
|---|---|---|
| data_config | DataModel | CAREamics data configuration. |
| data_type | SupportedData | Expected data type, one of "tiff", "array" or "custom". |
| batch_size | int | Batch size. |
| use_in_memory | bool | Whether to use in memory dataset if possible. |
| train_data | Path or ndarray | Training data. |
| val_data | Path or ndarray | Validation data. |
| train_data_target | Path or ndarray | Training target data. |
| val_data_target | Path or ndarray | Validation target data. |
| val_percentage | float | Percentage of the training data to use for validation, if no validation data is provided. |
| val_minimum_split | int | Minimum number of patches or files to split from the training data for validation, if no validation data is provided. |
| read_source_func | Optional[Callable] | Function to read the source data, used if data_type is custom. |
| extension_filter | str | Filter for file extensions, used if data_type is custom. |
get_data_statistics()
Return training data statistics.
Returns:
| Type | Description |
|---|---|
| tuple of list | Means and standard deviations across channels of the training data. |
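Per-channel statistics are obtained by reducing over every axis except C. The sketch assumes an axes string, such as "SCYX", that matches the array dimensions. The page does not specify whether CAREamics computes the statistics over whole images, over patches, or by averaging per-image values, so treat channel_statistics as a hypothetical helper.

```python
import numpy as np


def channel_statistics(data: np.ndarray, axes: str):
    """Return (means, stds) as lists with one entry per channel."""
    if "C" not in axes:
        # Single-channel data: reduce over everything.
        return [float(data.mean())], [float(data.std())]
    c = axes.index("C")
    reduce_over = tuple(i for i in range(data.ndim) if i != c)
    return data.mean(axis=reduce_over).tolist(), data.std(axis=reduce_over).tolist()


data = np.zeros((2, 2, 4, 4))  # axes "SCYX"
data[:, 0] = 1.0  # constant channel
data[:, 1] = (np.arange(16).reshape(4, 4) % 2) * 2.0  # alternating 0 and 2
means, stds = channel_statistics(data, "SCYX")
print(means, stds)  # [1.0, 1.0] [0.0, 1.0]
```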
prepare_data()
Hook used to prepare the data before calling setup.
Here, we only need to examine the data if it was provided as a str or a Path.
TODO: from lightning doc: prepare_data is called from the main process. It is not recommended to assign state here (e.g. self.x = y) since it is called on a single process and if you assign states here then they won't be available for other processes.
https://lightning.ai/docs/pytorch/stable/data/datamodule.html
setup(*args, **kwargs)
train_dataloader()
val_dataloader()
VAEModule
Bases: LightningModule
CAREamics Lightning module.
This class encapsulates the PyTorch model along with the training, validation,
and testing logic. It is configured using a VAEAlgorithmConfig Pydantic class.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| algorithm_config | Union[VAEAlgorithmConfig, dict] | Algorithm configuration. | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| model | Module | PyTorch model. |
| loss_func | Module | Loss function. |
| optimizer_name | str | Optimizer name. |
| optimizer_params | dict | Optimizer parameters. |
| lr_scheduler_name | str | Learning rate scheduler name. |
compute_val_psnr(model_output, target, psnr_func=scale_invariant_psnr)
Compute the PSNR for the current validation batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_output | tuple[Tensor, dict[str, Any]] | Model output, a tuple with the predicted mean and (optionally) logvar, and the top-down data dictionary. | required |
| target | Tensor | Target tensor. | required |
| psnr_func | Callable | PSNR function to use, by default scale_invariant_psnr. | scale_invariant_psnr |
Returns:
| Type | Description |
|---|---|
| list[float] | PSNR for each channel in the current batch. |
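The sketch below implements one common definition of scale-invariant PSNR:

1. The target is normalized to zero mean and unit variance.
2. The zero-mean prediction is rescaled by the least-squares factor that best matches the normalized target.
3. The PSNR is computed using the normalized target's range.

It is written in NumPy and is not the CAREamics scale_invariant_psnr itself. per_channel_psnr, which pools each channel over the whole batch, is likewise an assumption about how compute_val_psnr aggregates.

```python
import numpy as np


def scale_invariant_psnr(gt: np.ndarray, pred: np.ndarray) -> float:
    gt = np.asarray(gt, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    gt_std = gt.std()
    # Dynamic range of the normalized target
    data_range = (gt.max() - gt.min()) / gt_std
    gt_n = (gt - gt.mean()) / gt_std
    pred_c = pred - pred.mean()
    # Least-squares scale that best aligns the prediction with the target
    alpha = np.sum(gt_n * pred_c) / np.sum(pred_c ** 2)
    mse = np.mean((gt_n - alpha * pred_c) ** 2)
    return float(10 * np.log10(data_range ** 2 / mse))


def per_channel_psnr(pred: np.ndarray, target: np.ndarray) -> list:
    """pred and target have shape (B, C, Y, X); one PSNR per channel."""
    return [scale_invariant_psnr(target[:, c], pred[:, c]) for c in range(target.shape[1])]


rng = np.random.default_rng(0)
target = rng.normal(size=(2, 2, 32, 32))
pred = target + 0.1 * rng.normal(size=target.shape)
psnrs = per_channel_psnr(pred, target)
# Rescaling and shifting the prediction leaves the score unchanged
psnrs_affine = per_channel_psnr(3.0 * pred + 7.0, target)
```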
configure_optimizers()
Configure optimizers and learning rate schedulers.
Returns:
| Type | Description |
|---|---|
| Any | Optimizer and learning rate scheduler. |
forward(x)
get_reconstructed_tensor(model_outputs)
Get the reconstructed tensor from the LVAE model outputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_outputs | tuple[Tensor, dict[str, Any]] | Model outputs. It is a tuple with a tensor representing the predicted mean and (optionally) logvar, and the top-down data dictionary. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Reconstructed tensor, i.e., the predicted mean. |
on_validation_epoch_end()
Validation epoch end.
predict_step(batch, batch_idx)
reduce_running_psnr()
Reduce the running PSNR statistics and reset the running PSNR.
Returns:
| Type | Description |
|---|---|
| Optional[float] | Running PSNR averaged over the different output channels. |
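A running PSNR of this kind can be kept as one list of values per channel. At epoch end, the values are averaged per channel and then across channels. RunningPSNR is a hypothetical sketch. Returning None when nothing was accumulated is an assumption that matches the Optional[float] return type.

```python
from statistics import mean
from typing import List, Optional


class RunningPSNR:
    """Hypothetical accumulator of per-channel PSNR values over validation batches."""

    def __init__(self, n_channels: int):
        self.n_channels = n_channels
        self.values: List[List[float]] = [[] for _ in range(n_channels)]

    def update(self, batch_psnr: List[float]) -> None:
        # batch_psnr holds one value per channel, as returned by compute_val_psnr
        for channel, value in enumerate(batch_psnr):
            self.values[channel].append(value)

    def reduce_and_reset(self) -> Optional[float]:
        per_channel = [mean(v) for v in self.values if v]
        self.values = [[] for _ in range(self.n_channels)]  # reset for next epoch
        if not per_channel:
            return None
        return mean(per_channel)


running = RunningPSNR(n_channels=2)
running.update([30.0, 20.0])
running.update([32.0, 22.0])
print(running.reduce_and_reset())  # 26.0
print(running.reduce_and_reset())  # None
```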
set_data_stats(data_mean, data_std)
training_step(batch, batch_idx)
Training step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | tuple[Tensor, Tensor] | Input batch. It is a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit). | required |
| batch_idx | Any | Batch index. | required |
Returns:
| Type | Description |
|---|---|
| Any | Loss value. |
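The batch layout described above can be checked with dummy arrays. NumPy stands in for torch tensors, and the sizes are arbitrary. Treating channel 0 of the input as the full-resolution image, with the lateral contexts after it, is an assumption.

```python
import numpy as np

# Hypothetical sizes for illustration.
batch_size, n_lc, n_targets, size = 4, 2, 3, 64

# 2D case: input (B, 1 + n_LC, Y, X), target (B, C, Y, X)
inputs = np.zeros((batch_size, 1 + n_lc, size, size), dtype=np.float32)
targets = np.zeros((batch_size, n_targets, size, size), dtype=np.float32)
batch = (inputs, targets)

x, y = batch
primary = x[:, :1]  # assumed: the full-resolution input comes first
lateral = x[:, 1:]  # the n_LC lateral-context inputs
print(primary.shape, lateral.shape, y.shape)
# (4, 1, 64, 64) (4, 2, 64, 64) (4, 3, 64, 64)
```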
validation_step(batch, batch_idx)
Validation step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | tuple[Tensor, Tensor] | Input batch. It is a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit). | required |
| batch_idx | Any | Batch index. | required |
create_careamics_module(algorithm, loss, architecture, use_n2v2=False, struct_n2v_axis='none', struct_n2v_span=5, model_parameters=None, optimizer='Adam', optimizer_parameters=None, lr_scheduler='ReduceLROnPlateau', lr_scheduler_parameters=None)
Create a CAREamics Lightning module.
This function exposes the parameters used to create an AlgorithmModel instance, triggering parameter validation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| algorithm | SupportedAlgorithm or str | Algorithm to use for training (see SupportedAlgorithm). | required |
| loss | SupportedLoss or str | Loss function to use for training (see SupportedLoss). | required |
| architecture | SupportedArchitecture or str | Model architecture to use for training (see SupportedArchitecture). | required |
| use_n2v2 | bool | Whether to use N2V2 instead of Noise2Void. | False |
| struct_n2v_axis | "horizontal", "vertical", or "none" | Axis of the StructN2V mask. | "none" |
| struct_n2v_span | int | Span of the StructN2V mask. | 5 |
| model_parameters | dict | Model parameters to use for training, by default {}. Model parameters are defined in the relevant | None |
| optimizer | SupportedOptimizer or str | Optimizer to use for training, by default "Adam" (see SupportedOptimizer). | 'Adam' |
| optimizer_parameters | dict | Optimizer parameters to use for training, as defined in | None |
| lr_scheduler | SupportedScheduler or str | Learning rate scheduler to use for training, by default "ReduceLROnPlateau" (see SupportedScheduler). | 'ReduceLROnPlateau' |
| lr_scheduler_parameters | dict | Learning rate scheduler parameters to use for training, as defined in | None |
Returns:
| Type | Description |
|---|---|
| CAREamicsModule | CAREamics Lightning module. |
create_microsplit_predict_datamodule(pred_data, tile_size, batch_size=1, num_channels=2, depth3D=1, grid_size=None, multiscale_count=None, data_stats=None, tiling_mode=TilingMode.ShiftBoundary, read_source_func=None, extension_filter='', dataloader_params=None, **dataset_kwargs)
Create a MicroSplitPredictDataModule for MicroSplit-style prediction datasets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pred_data | str or Path or ndarray | Prediction data, can be a path to a folder, a file or a numpy array. | required |
| tile_size | tuple | Size of one tile of data. | required |
| batch_size | int | Batch size for prediction dataloader. | 1 |
| num_channels | int | Number of channels in the input. | 2 |
| depth3D | int | Number of slices in 3D. | 1 |
| grid_size | tuple | Grid size for patch extraction. | None |
| multiscale_count | int | Number of LC scales. | None |
| data_stats | tuple | Data statistics, by default None. | None |
| tiling_mode | TilingMode | Tiling mode for patch extraction. | TilingMode.ShiftBoundary |
| read_source_func | Callable | Function to read the source data. | None |
| extension_filter | str | File extension filter. | '' |
| dataloader_params | dict | Parameters for prediction dataloader. | None |
| **dataset_kwargs | | Additional arguments passed to MicroSplitDataConfig. | {} |
Returns:
| Type | Description |
|---|---|
| MicroSplitPredictDataModule | Configured MicroSplitPredictDataModule instance. |
create_microsplit_train_datamodule(train_data, patch_size, batch_size, val_data=None, num_channels=2, depth3D=1, grid_size=None, multiscale_count=None, tiling_mode=TilingMode.ShiftBoundary, extension_filter='', val_percentage=0.1, val_minimum_split=5, use_in_memory=True, transforms=None, train_dataloader_params=None, val_dataloader_params=None, **dataset_kwargs)
Create a MicroSplitDataModule for MicroSplit-style datasets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| train_data | str | Path to training data. | required |
| patch_size | tuple | Size of one patch of data. | required |
| batch_size | int | Batch size for dataloaders. | required |
| val_data | str | Path to validation data. | None |
| num_channels | int | Number of channels in the input. | 2 |
| depth3D | int | Number of slices in 3D. | 1 |
| grid_size | tuple | Grid size for patch extraction. | None |
| multiscale_count | int | Number of LC scales. | None |
| tiling_mode | TilingMode | Tiling mode for patch extraction. | TilingMode.ShiftBoundary |
| extension_filter | str | File extension filter. | '' |
| val_percentage | float | Percentage of training data to use for validation. | 0.1 |
| val_minimum_split | int | Minimum number of patches/files for validation split. | 5 |
| use_in_memory | bool | Use in-memory dataset if possible. | True |
| transforms | list | List of transforms to apply. | None |
| train_dataloader_params | dict | Parameters for training dataloader. | None |
| val_dataloader_params | dict | Parameters for validation dataloader. | None |
| **dataset_kwargs | | Additional arguments passed to DatasetConfig. | {} |
Returns:
| Type | Description |
|---|---|
| MicroSplitDataModule | Configured MicroSplitDataModule instance. |
create_predict_datamodule(pred_data, data_type, axes, image_means, image_stds, tile_size=None, tile_overlap=None, batch_size=1, tta_transforms=True, read_source_func=None, extension_filter='', dataloader_params=None)
Create a CAREamics prediction Lightning datamodule.
This function is used to explicitly pass the parameters usually contained in an
inference_model configuration.
Since the lightning datamodule has no access to the model, make sure that the
parameters passed to the datamodule are consistent with the model's requirements
and are coherent. This can be done by creating a Configuration object beforehand
and passing its parameters to the different Lightning modules.
The data module can be used with Path, str or numpy arrays. To use array data, set
data_type to array and pass a numpy array to pred_data.
By default, CAREamics only supports types defined in
careamics.config.support.SupportedData. To read custom data types, you can set
data_type to custom and provide a function that returns a numpy array from a
path. Additionally, pass a fnmatch and Path.rglob compatible expression
(e.g. "*.jpeg") to filter files by extension using extension_filter.
In dataloader_params, you can pass any parameter accepted by PyTorch
dataloaders, except for batch_size, which is set by the batch_size
parameter.
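Merging dataloader_params with the dedicated batch_size argument can be sketched as follows. build_dataloader_kwargs is hypothetical. Raising an error on a duplicate batch_size is a choice of this sketch, and CAREamics may instead ignore or overwrite the key.

```python
def build_dataloader_kwargs(batch_size, dataloader_params=None):
    """Merge user dataloader parameters with the dedicated batch_size argument."""
    params = dict(dataloader_params or {})  # copy so the caller's dict is untouched
    if "batch_size" in params:
        # Design choice of this sketch: refuse the duplicate instead of overriding.
        raise ValueError("Pass batch_size via the batch_size argument, not dataloader_params.")
    params["batch_size"] = batch_size
    return params


user_params = {"num_workers": 2, "shuffle": False}
kwargs = build_dataloader_kwargs(4, user_params)
print(kwargs)  # {'num_workers': 2, 'shuffle': False, 'batch_size': 4}
# kwargs could then be unpacked into torch.utils.data.DataLoader(dataset, **kwargs)
```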
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pred_data | str or Path or ndarray | Prediction data. | required |
| data_type | "array", "tiff", or "custom" | Data type, see careamics.config.support.SupportedData. | required |
| axes | str | Axes of the data, chosen among SCZYX. | required |
| image_means | list of float | Mean values for normalization, only used if Normalization is defined. | required |
| image_stds | list of float | Std values for normalization, only used if Normalization is defined. | required |
| tile_size | tuple of int | Tile size, 2D or 3D tile size. | None |
| tile_overlap | tuple of int | Tile overlap, 2D or 3D tile overlap. | None |
| batch_size | int | Batch size. | 1 |
| tta_transforms | bool | Use test time augmentation, by default True. | True |
| read_source_func | Callable | Function to read the source data, used if data_type is custom. | None |
| extension_filter | str | Filter for file extensions, used if data_type is custom. | '' |
| dataloader_params | dict | PyTorch dataloader parameters, by default {}. | None |
Returns:
| Type | Description |
|---|---|
| PredictDataModule | CAREamics prediction datamodule. |
Notes
If you are using a UNet model and tiling, the tile size must be divisible in every dimension by 2**d, where d is the depth of the model. This avoids artefacts arising from the broken shift invariance induced by the pooling layers of the UNet. If your image is too small along a dimension to accommodate such a tile size, as may happen in the Z dimension, consider padding your image.
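The divisibility rule and the padding workaround can be checked with a few lines of NumPy. tile_size_is_valid and pad_axis_to_multiple are hypothetical helpers. Reflect padding is one reasonable choice, and the padded region should be cropped away after prediction.

```python
import numpy as np


def tile_size_is_valid(tile_size, depth):
    """True if every tile dimension is divisible by 2**depth."""
    factor = 2 ** depth
    return all(size % factor == 0 for size in tile_size)


def pad_axis_to_multiple(image, axis, depth):
    """Pad one axis (at the end) up to the next multiple of 2**depth."""
    factor = 2 ** depth
    size = image.shape[axis]
    target = -(-size // factor) * factor  # ceiling to the next multiple
    pad_width = [(0, 0)] * image.ndim
    pad_width[axis] = (0, target - size)
    return np.pad(image, pad_width, mode="reflect")


print(tile_size_is_valid((64, 64), depth=2))  # True
print(tile_size_is_valid((64, 66), depth=2))  # False: 66 is not divisible by 4

stack = np.random.default_rng(0).random((5, 64, 64))  # ZYX with only 5 slices
padded = pad_axis_to_multiple(stack, axis=0, depth=2)
print(padded.shape)  # (8, 64, 64)
# After prediction, crop back to the original extent, e.g. prediction[:5].
```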
create_train_datamodule(train_data, data_type, patch_size, axes, batch_size, val_data=None, transforms=None, train_target_data=None, val_target_data=None, read_source_func=None, extension_filter='', val_percentage=0.1, val_minimum_patches=5, train_dataloader_params=None, val_dataloader_params=None, use_in_memory=True)
Create a TrainDataModule.
This function is used to explicitly pass the parameters usually contained in a
GenericDataConfig to a TrainDataModule.
Since the lightning datamodule has no access to the model, make sure that the parameters passed to the datamodule are consistent with the model's requirements and are coherent.
The default augmentations are XY flip and XY rotation. To use a different set of
augmentations, you can pass a list of transforms to transforms.
The data module can be used with Path, str or numpy arrays. In the case of
numpy arrays, it loads and computes all the patches in memory. For Path and str
inputs, it calculates the total file size and estimates whether it can fit in
memory. If it does not, it iterates through the files. This behaviour can be
deactivated by setting use_in_memory to False, in which case it will
always use the iterating dataset to train on a Path or str.
To use array data, set data_type to array and pass a numpy array to
train_data.
By default, CAREamics only supports types defined in
careamics.config.support.SupportedData. To read custom data types, you can set
data_type to custom and provide a function that returns a numpy array from a
path. Additionally, pass a fnmatch and Path.rglob compatible expression (e.g.
"*.jpeg") to filter files by extension using extension_filter.
In the absence of validation data, the validation data is extracted from the
training data. The percentage of the training data to use for validation, as well as
the minimum number of patches to split from the training data for validation can be
set using val_percentage and val_minimum_patches, respectively.
In dataloader_params, you can pass any parameter accepted by PyTorch dataloaders,
except for batch_size, which is set by the batch_size parameter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| train_data | Path or str or ndarray | Training data. | required |
| data_type | "array", "tiff", or "custom" | Data type, see careamics.config.support.SupportedData. | required |
| patch_size | list of int | Patch size, 2D or 3D patch size. | required |
| axes | str | Axes of the data, chosen amongst SCZYX. | required |
| batch_size | int | Batch size. | required |
| val_data | Path or str or ndarray | Validation data, by default None. | None |
| transforms | list of Transforms | List of transforms to apply to training patches. If None, default transforms are applied. | None |
| train_target_data | Path or str or ndarray | Training target data, by default None. | None |
| val_target_data | Path or str or ndarray | Validation target data, by default None. | None |
| read_source_func | Callable | Function to read the source data, used if data_type is custom. | None |
| extension_filter | str | Filter for file extensions, used if data_type is custom. | '' |
| val_percentage | float | Percentage of the training data to use for validation if no validation data is given, by default 0.1. | 0.1 |
| val_minimum_patches | int | Minimum number of patches to split from the training data for validation if no validation data is given, by default 5. | 5 |
| train_dataloader_params | dict | PyTorch dataloader parameters for the training data, by default {}. | None |
| val_dataloader_params | dict | PyTorch dataloader parameters for the validation data, by default {}. | None |
| use_in_memory | bool | Use in memory dataset if possible, by default True. | True |
Returns:
| Type | Description |
|---|---|
| TrainDataModule | CAREamics training Lightning data module. |
Examples:
Create a TrainDataModule with default transforms from a numpy array:
>>> import numpy as np
>>> from careamics.lightning import create_train_datamodule
>>> my_array = np.arange(256).reshape(16, 16)
>>> data_module = create_train_datamodule(
... train_data=my_array,
... data_type="array",
... patch_size=(8, 8),
... axes='YX',
... batch_size=2,
... )
For custom data types (those not supported by CAREamics), you can pass a read function and a filter for the file extension:
>>> import numpy as np
>>> from careamics.lightning import create_train_datamodule
>>>
>>> def read_npy(path, *args):
...     return np.load(path)  # extra arguments, such as the axes string, are ignored
>>>
>>> data_module = create_train_datamodule(
... train_data="path/to/data",
... data_type="custom",
... patch_size=(8, 8),
... axes='YX',
... batch_size=2,
... read_source_func=read_npy,
... extension_filter="*.npy",
... )
If you want to use a different set of augmentations, you can pass a list of transforms:
>>> import numpy as np
>>> from careamics.lightning import create_train_datamodule
>>> from careamics.config.augmentations import XYFlipConfig
>>> from careamics.config.support import SupportedTransform
>>> my_array = np.arange(256).reshape(16, 16)
>>> my_transforms = [
... XYFlipConfig(flip_y=False),
... ]
>>> data_module = create_train_datamodule(
... train_data=my_array,
... data_type="array",
... patch_size=(8, 8),
... axes='YX',
... batch_size=2,
... transforms=my_transforms,
... )