Careamist

A class to train, predict and export models in CAREamics.

CAREamist

Main CAREamics class, allowing training and prediction using various algorithms.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| source | pathlib.Path or str or Configuration | Path to a configuration file or a trained model. | required |
| work_dir | str | Path to the working directory in which to save checkpoints and logs. | None |
| callbacks | list of Callback | List of callbacks to use during training and prediction. | None |
| enable_progress_bar | bool | Whether a progress bar is displayed during training, validation and prediction. | True |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| model | CAREamicsModule | CAREamics model. |
| cfg | Configuration | CAREamics configuration. |
| trainer | Trainer | PyTorch Lightning trainer. |
| experiment_logger | TensorBoardLogger or WandbLogger | Experiment logger, "wandb" or "tensorboard". |
| work_dir | Path | Working directory. |
| train_datamodule | TrainDataModule | Training datamodule. |
| pred_datamodule | PredictDataModule | Prediction datamodule. |
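A minimal instantiation sketch; the configuration file name and working directory below are hypothetical:

```python
# Minimal sketch: creating a CAREamist from a configuration file.
# "n2v_config.yml" and "runs" are hypothetical paths.
from careamics import CAREamist

careamist = CAREamist(
    source="n2v_config.yml",  # configuration file or trained model
    work_dir="runs",          # checkpoints and logs are saved here
)
```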

export_to_bmz(path_to_archive, friendly_model_name, input_array, authors, general_description, data_description, covers=None, channel_names=None, model_version='0.1.0')

Export the model to the BioImage Model Zoo format.

This method packages the current weights into a zip file that can be uploaded to the BioImage Model Zoo. The archive consists of the model weights, the model specifications and various files (inputs, outputs, README, env.yaml etc.).

path_to_archive should point to a file with a ".zip" extension.

friendly_model_name is the name used for the model in the BMZ specs and website; it should consist of letters, numbers, dashes, underscores and parentheses only.

The input array must have the same dimensions as the axes recorded in the configuration of the CAREamist.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path_to_archive | Path or str | Path in which to save the model, including the file name, which should end with ".zip". | required |
| friendly_model_name | str | Name of the model as used in the BMZ specs; letters, numbers, dashes, underscores and parentheses only. | required |
| input_array | NDArray | Input array used to validate the model and as an example. | required |
| authors | list of dict | List of authors of the model. | required |
| general_description | str | General description of the model used in the BMZ metadata. | required |
| data_description | str | Description of the data the model was trained on. | required |
| covers | list of pathlib.Path or str | Paths to the cover images. | None |
| channel_names | list of str | Channel names. | None |
| model_version | str | Version of the model. | "0.1.0" |
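A sketch of a typical export call; `careamist` is assumed to be a trained CAREamist and `example_image` an array matching the configured axes, and all names and paths below are hypothetical:

```python
# Sketch: exporting a trained model to the BioImage Model Zoo format.
# `careamist` (a trained CAREamist) and `example_image` are assumed
# to exist from a previous training session.
careamist.export_to_bmz(
    path_to_archive="exports/my_model.zip",   # must end with ".zip"
    friendly_model_name="My-N2V-Model",       # letters, numbers, -, _, () only
    input_array=example_image,                # same axes as the configuration
    authors=[{"name": "Jane Doe", "affiliation": "Example Lab"}],
    general_description="Denoising model trained with CAREamics.",
    data_description="2D fluorescence microscopy images.",
)
```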

get_losses()

Return data that can be used to plot train and validation loss curves.

Returns:

| Type | Description |
| --- | --- |
| dict of str: list | Dictionary containing the losses for each epoch. |
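As an illustration of working with the returned dictionary, the sketch below picks the epoch with the lowest validation loss; the key names ("train_loss", "val_loss") are assumptions used for illustration, not confirmed by this page:

```python
# Hypothetical example of the dict returned by get_losses(); the key
# names below are assumptions used for illustration only.
losses = {
    "train_loss": [0.90, 0.50, 0.30, 0.25],
    "val_loss": [1.00, 0.60, 0.40, 0.35],
}

# Epoch with the lowest validation loss.
best_epoch = min(range(len(losses["val_loss"])), key=losses["val_loss"].__getitem__)
print(best_epoch)  # 3
```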

predict(source, *, batch_size=1, tile_size=None, tile_overlap=(48, 48), axes=None, data_type=None, tta_transforms=False, dataloader_params=None, read_source_func=None, extension_filter='', **kwargs)

predict(source: PredictDataModule) -> Union[list[NDArray], NDArray]
predict(source: Union[Path, str], *, batch_size: int = 1, tile_size: tuple[int, ...] | None = None, tile_overlap: tuple[int, ...] | None = (48, 48), axes: str | None = None, data_type: Literal['tiff', 'custom'] | None = None, tta_transforms: bool = False, dataloader_params: dict | None = None, read_source_func: Callable | None = None, extension_filter: str = '') -> Union[list[NDArray], NDArray]
predict(source: NDArray, *, batch_size: int = 1, tile_size: tuple[int, ...] | None = None, tile_overlap: tuple[int, ...] | None = (48, 48), axes: str | None = None, data_type: Literal['array'] | None = None, tta_transforms: bool = False, dataloader_params: dict | None = None) -> Union[list[NDArray], NDArray]

Make predictions on the provided data.

Input can be a PredictDataModule instance, a path to a data file, or a numpy array.

If data_type, axes and tile_size are not provided, the training configuration parameters will be used, with the patch_size instead of tile_size.

Test-time augmentation (TTA) can be switched on using the tta_transforms parameter. TTA applies all possible flips and 90-degree rotations to the prediction input and averages the predictions. TTA should not be used if you did not train with these augmentations.

Note that if you are using a UNet model and tiling, the tile size must be divisible in every dimension by 2**d, where d is the depth of the model. This avoids artefacts arising from the broken shift invariance induced by the pooling layers of the UNet. If your image is too small along some dimension, as may happen along Z, consider padding your image.
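The divisibility constraint can be checked up front; depth=2 below is an assumed example value, not read from any real configuration:

```python
# Check that a candidate tile size is divisible by 2**depth in every
# dimension before tiled prediction with a UNet. depth=2 is an assumed
# example value; use the real depth from your configuration.
depth = 2
tile_size = (256, 256)

remainders = [t % 2**depth for t in tile_size]
print(remainders)  # [0, 0] means the tile size is valid
```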

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| source | PredictDataModule, Path, str or NDArray | Data to predict on. | required |
| batch_size | int | Batch size for prediction. | 1 |
| tile_size | tuple of int | Size of the tiles to use for prediction. | None |
| tile_overlap | tuple of int | Overlap between tiles, can be None. | (48, 48) |
| axes | str | Axes of the input data. | None |
| data_type | "array", "tiff" or "custom" | Type of the input data. | None |
| tta_transforms | bool | Whether to apply test-time augmentation. | False |
| dataloader_params | dict | Parameters to pass to the dataloader. | None |
| read_source_func | Callable | Function to read the source data. | None |
| extension_filter | str | Filter for the file extension. | "" |
| **kwargs | Any | Unused. | {} |

Returns:

| Type | Description |
| --- | --- |
| list of NDArray or NDArray | Predictions made by the model. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If mean and std are not provided in the configuration. |
| ValueError | If the tile size is not divisible by 2**depth for UNet models. |
| ValueError | If the tile overlap is not specified. |
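A sketch of tiled prediction on a single file; `careamist` is assumed to be a trained CAREamist, and the path and axes below are hypothetical:

```python
# Sketch: tiled prediction on a TIFF file. `careamist` is assumed to be
# a trained CAREamist; the file path and axes are hypothetical.
prediction = careamist.predict(
    source="data/test/image.tiff",
    data_type="tiff",
    axes="YX",
    tile_size=(256, 256),   # divisible by 2**depth for UNet models
    tile_overlap=(48, 48),
)
```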

predict_to_disk(source, *, batch_size=1, tile_size=None, tile_overlap=(48, 48), axes=None, data_type=None, tta_transforms=False, dataloader_params=None, read_source_func=None, extension_filter='', write_type='tiff', write_extension=None, write_func=None, write_func_kwargs=None, prediction_dir='predictions', **kwargs)

Make predictions on the provided data and save outputs to files.

The predictions will be saved in a new directory 'predictions' within the set working directory. The directory structure within the 'predictions' directory will match that of the source directory.

The source must be files, not arrays. The file names of the predictions will match those of the source. If there is more than one sample within a file, the samples will be saved to separate files, named after the corresponding source file with the sample index appended. E.g. if the source file name is 'image.tiff', then the first sample's prediction will be saved as 'image_0.tiff'. Input can be a PredictDataModule instance or a path to a data file.

If data_type, axes and tile_size are not provided, the training configuration parameters will be used, with the patch_size instead of tile_size.

Test-time augmentation (TTA) can be switched on using the tta_transforms parameter. TTA applies all possible flips and 90-degree rotations to the prediction input and averages the predictions. TTA should not be used if you did not train with these augmentations.

Note that if you are using a UNet model and tiling, the tile size must be divisible in every dimension by 2**d, where d is the depth of the model. This avoids artefacts arising from the broken shift invariance induced by the pooling layers of the UNet. If your image is too small along some dimension, as may happen along Z, consider padding your image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| source | PredictDataModule, Path or str | Data to predict on. | required |
| batch_size | int | Batch size for prediction. | 1 |
| tile_size | tuple of int | Size of the tiles to use for prediction. | None |
| tile_overlap | tuple of int | Overlap between tiles. | (48, 48) |
| axes | str | Axes of the input data. | None |
| data_type | "array", "tiff" or "custom" | Type of the input data. | None |
| tta_transforms | bool | Whether to apply test-time augmentation. | False |
| dataloader_params | dict | Parameters to pass to the dataloader. | None |
| read_source_func | Callable | Function to read the source data. | None |
| extension_filter | str | Filter for the file extension. | "" |
| write_type | "tiff" or "custom" | The data type to save as, including custom. | "tiff" |
| write_extension | str | Ignored for a known write_type; for a custom write_type, an extension to save the data with must be passed. | None |
| write_func | WriteFunc | Ignored for a known write_type; for a custom write_type, a function to save the data must be passed. See notes below. | None |
| write_func_kwargs | dict of {str: Any} | Additional keyword arguments to be passed to the save function. | None |
| prediction_dir | Path or str | The path to save the prediction results to. If prediction_dir is not absolute, it is assumed to be relative to the pre-set work_dir. If the directory does not exist, it will be created. | "predictions" |
| **kwargs | Any | Unused. | {} |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If write_type is "custom" and write_extension is None. |
| ValueError | If write_type is "custom" and write_func is None. |
| ValueError | If source is not a str, Path or PredictDataModule. |
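A sketch of predicting a folder of TIFF files and writing the results to disk; `careamist` is an assumed trained CAREamist, and the source directory below is hypothetical:

```python
# Sketch: predicting a whole folder of TIFF files and saving the
# outputs. `careamist` is an assumed trained CAREamist and "data/test"
# a hypothetical source directory.
careamist.predict_to_disk(
    source="data/test",
    data_type="tiff",
    tile_size=(256, 256),
    tile_overlap=(48, 48),
    write_type="tiff",
    prediction_dir="predictions",  # relative to work_dir unless absolute
)
```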

stop_training()

Stop the training loop.

train(*, datamodule=None, train_source=None, val_source=None, train_target=None, val_target=None, use_in_memory=True, val_percentage=0.1, val_minimum_split=1)

Train the model on the provided data.

If a datamodule is provided, then training will be performed using it. Alternatively, the training data can be provided as arrays or paths.

If use_in_memory is set to True, the source provided as Path or str will be loaded in memory if it fits. Otherwise, training will be performed by loading patches from the files one by one. Training on arrays is always performed in memory.

If no validation source is provided, then the validation is extracted from the training data using val_percentage and val_minimum_split. In the case of data provided as Path or str, the percentage and minimum number are applied to the number of files. For arrays, it is the number of patches.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| datamodule | TrainDataModule | Datamodule to train on. | None |
| train_source | Path or str or NDArray | Train source, if no datamodule is provided. | None |
| val_source | Path or str or NDArray | Validation source, if no datamodule is provided. | None |
| train_target | Path or str or NDArray | Train target source, if no datamodule is provided. | None |
| val_target | Path or str or NDArray | Validation target source, if no datamodule is provided. | None |
| use_in_memory | bool | Use in-memory dataset if possible. | True |
| val_percentage | float | Percentage of validation extracted from training data. | 0.1 |
| val_minimum_split | int | Minimum number of validation patches or files extracted from training data. | 1 |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If both datamodule and train_source are provided. |
| ValueError | If sources are not of the same type (e.g. train is an array and val is a Path). |
| ValueError | If a training target is provided to N2V. |
| ValueError | If neither a datamodule nor a source is provided. |
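A sketch of training from arrays while letting CAREamics extract the validation split; `careamist` and `train_images` are assumed to exist (an instantiated CAREamist and an NDArray shaped according to the configured axes):

```python
# Sketch: training from an in-memory array with an automatic validation
# split. `careamist` (a CAREamist) and `train_images` (an NDArray) are
# assumed to exist; no datamodule is provided here.
careamist.train(
    train_source=train_images,
    val_percentage=0.1,    # 10% of patches held out for validation
    val_minimum_split=1,   # at least one validation patch
)
```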