
Lightning API

CAREamics relies on PyTorch Lightning, and thus advanced users can use the underlying modules in their own PyTorch Lightning scripts. This is what we refer to as the Lightning API, and it is recommended for users who want more control over the training and prediction process.

Quick start

Here is an example of training Noise2Void using the Lightning API. For the configuration parameters, refer to the CAREamist API documentation.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from careamics_portfolio import PortfolioManager

from careamics.config.factories import create_advanced_n2v_config
from careamics.lightning import (
    CareamicsDataModule,
    ConfigSaverCallback,
    N2VModule,
    convert_prediction,
)

# download example data
portfolio_manager = PortfolioManager()
files = portfolio_manager.denoising.N2V_SEM.download()
train_image = files[0]
val_image = files[1]

# create configuration
config = create_advanced_n2v_config(  # (1)!
    experiment_name="na",  # not required by the Lightning API, only used to name checkpoints here
    data_type="tiff",
    axes="YX",
    patch_size=(64, 64),
    batch_size=16,
    num_epochs=1,
    num_workers=0,  # (2)!
)

# create lightning modules
model = N2VModule(config.algorithm_config)  # (3)!

data_module = CareamicsDataModule(  # (4)!
    data_config=config.data_config,
    train_data=train_image,
    val_data=val_image,
)

callbacks = [
    ModelCheckpoint(  # (5)!
        dirpath="checkpoints",
        filename=f"{config.experiment_name}_{{epoch:02d}}_step_{{step}}",
        **config.training_config.checkpoint_params,
    ),
    ConfigSaverCallback(
        config.version, config.experiment_name, config.training_config
    ),  # (6)!
]

trainer = Trainer(
    enable_progress_bar=True,
    callbacks=callbacks,
    **config.training_config.trainer_params,  # (7)!
)

trainer.fit(model, datamodule=data_module)  # (8)!

# create an inference data config
pred_config = config.data_config.convert_mode(  # (9)!
    new_mode="predicting",
    new_patch_size=(256, 256),
    overlap_size=(48, 48),
    new_batch_size=1,
)

inf_data_module = CareamicsDataModule(  # (10)!
    data_config=pred_config,
    pred_data=train_image,
)

# run inference
tiled_predictions = trainer.predict(model, datamodule=inf_data_module)  # (11)!

# convert list of tile predictions to stitched data
stitched_predictions, sources = convert_prediction(  # (12)!
    tiled_predictions,
    tiled=True,
)
  1. Creating a CAREamics configuration ensures that the parameters passed to the various PyTorch modules are coherent.
  2. num_workers is set to 0 here because data loading in worker processes often causes issues on local machines (e.g. Windows or macOS). Feel free to experiment with it!
  3. Each class of algorithm in CAREamics has its own Lightning module (the model). Here we set up Noise2Void, but the same can be done for CARE and N2N using the CAREModule.
  4. The Lightning datamodule takes a str, Path, numpy.ndarray, or a list of those as input. For more details, refer to the data documentation.
  5. The CAREamics configuration has its own default parameters for the ModelCheckpoint callback (and EarlyStopping). You can either use those or set up your own.
  6. The ConfigSaverCallback is used to save the configuration in the checkpoints.
  7. Similarly, the configuration holds default parameters for the Lightning Trainer. You can set your own rather than reusing those.
  8. As in any Lightning script, pass the model and datamodule to the trainer and call fit to start training.
  9. Our data modules require a data configuration. We reuse part of the training data configuration but convert it to "predicting" mode. This is an opportunity to change some parameters, such as passing a new_patch_size and overlap_size for tiled prediction.
  10. Create a new datamodule for inference, using the new data configuration.
  11. As for training, we predict using Lightning.
  12. Predictions are returned as a list of tiles because we set new_patch_size in the data configuration, so we need to stitch those tiles back together.
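Stitching is handled for you by convert_prediction, so there is nothing to implement here. For intuition only, the sketch below shows the general overlapping-tile idea along a single axis, under our own assumptions: tiles are placed with a stride of patch size minus overlap, the last tile is shifted so it sits flush with the border, and each tile contributes only the region up to the middle of its overlap with the next tile. This is not CAREamics' implementation, and `tile_starts` and `stitch` are hypothetical helpers; in 2D the same placement would be applied independently along each axis.

```python
from __future__ import annotations


def tile_starts(length: int, tile: int, overlap: int) -> list[int]:
    """Hypothetical helper: start index of each tile along one axis."""
    if overlap >= tile:
        raise ValueError("overlap must be smaller than the tile size")
    if tile >= length:
        return [0]  # a single tile covers the whole axis
    stride = tile - overlap
    starts = list(range(0, length - tile, stride))
    starts.append(length - tile)  # last tile is flush with the border
    return starts


def stitch(tiles: list[list[float]], starts: list[int], tile: int, length: int) -> list[float]:
    """Hypothetical helper: keep each tile's values up to the middle of its
    overlap with the next tile, discarding the rest of the overlap."""
    out = [0.0] * length
    lo = 0
    for i, (start, values) in enumerate(zip(starts, tiles)):
        if i + 1 < len(starts):
            # boundary in the middle of the overlap with the next tile
            hi = (starts[i + 1] + start + tile) // 2
        else:
            hi = length
        # convert image coordinates [lo, hi) into tile coordinates
        out[lo:hi] = values[lo - start:hi - start]
        lo = hi
    return out


# Identity "prediction": cutting a signal into tiles and stitching them back
# should reproduce the signal exactly.
signal = [float(i) for i in range(300)]
starts = tile_starts(len(signal), tile=256, overlap=48)
tiles = [signal[s:s + 256] for s in starts]
print(starts)  # [0, 44]
print(stitch(tiles, starts, 256, len(signal)) == signal)  # True
```

With the values used above (new_patch_size=(256, 256), overlap_size=(48, 48)), an axis of 300 pixels would be covered by two tiles under these assumptions. Discarding the overlapping borders is the usual motivation for tiling with overlap: predictions near tile edges have less context and tend to be less reliable.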

Predicting directly to disk

A useful feature of CAREamics that can be leveraged in the Lightning API is writing predictions directly to disk. This is achieved by adding a PredictionWriterCallback.

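# continues the quick start: reuses config, model, data_module and train_image
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from careamics.lightning import CareamicsDataModule, ConfigSaverCallback
from careamics.lightning.callbacks.prediction import (
    PredictionWriterCallback,
)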

pred_writer = PredictionWriterCallback(
    dirpath="predictions", enable_writing=False
)  # (1)!

callbacks = [
    ModelCheckpoint(
        dirpath="checkpoints",
        filename=f"{config.experiment_name}_{{epoch:02d}}_step_{{step}}",
        **config.training_config.checkpoint_params,
    ),
    ConfigSaverCallback(config.version, config.experiment_name, config.training_config),
    pred_writer,  # (2)!
]

trainer = Trainer(
    enable_progress_bar=True,
    callbacks=callbacks,
    **config.training_config.trainer_params,
)

trainer.fit(model, datamodule=data_module)

# create an inference data config
pred_config = config.data_config.convert_mode(
    new_mode="predicting",
    new_patch_size=(256, 256),
    overlap_size=(48, 48),
    new_batch_size=1,
)

inf_data_module = CareamicsDataModule(
    data_config=pred_config,
    pred_data=train_image,
)

# run inference
pred_writer.set_writing_strategy("tiff", tiled=True)  # (3)!
pred_writer.enable_writing(True)  # (4)!

# predictions are written to disk, nothing is returned
trainer.predict(
    model, datamodule=inf_data_module, return_predictions=False  # (5)!
)
  1. We keep a reference to the prediction writer and disable writing, in case we want to run some predictions in memory first.
  2. Add the prediction writer callback to the list of callbacks.
  3. Once we are ready to predict to disk, we set a writing strategy (tiff, zarr, or custom) and specify whether the predictions are tiled.
  4. We also need to turn writing back on.
  5. Finally, we disable returning the predictions, since they are written to disk instead of being kept in memory.
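Both examples name checkpoints with f"{config.experiment_name}_{{epoch:02d}}_step_{{step}}". The doubled braces are ordinary Python f-string escaping: they come out as single braces, leaving {epoch:02d} and {step} as placeholders that ModelCheckpoint fills in when it saves a file. The snippet below reproduces only the f-string step, with a plain variable standing in for config.experiment_name. The str.format call merely illustrates what the placeholders mean; ModelCheckpoint applies its own formatting on top (for example, its auto_insert_metric_name option, enabled by default in Lightning, prefixes each value with its name, e.g. epoch=03), and the CAREamics checkpoint parameters may change that behaviour.

```python
# Stand-in for config.experiment_name from the quick start configuration.
experiment_name = "na"

# {{ and }} are escaped braces: the f-string only substitutes experiment_name
# and keeps {epoch:02d} and {step} as literal placeholders.
template = f"{experiment_name}_{{epoch:02d}}_step_{{step}}"
print(template)  # na_{epoch:02d}_step_{step}

# Plain str.format substitution, only to show what the placeholders mean.
print(template.format(epoch=3, step=120))  # na_03_step_120
```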