
Lightning API

The so-called "Lightning API" is how we refer to using the Lightning modules from CAREamics in a PyTorch Lightning pipeline. In our high-level API, these modules are hidden from users, and many checks, validations, error handling and other features are provided. However, if you want more flexibility, for instance to use your own dataset, model or training loop, you can re-use many of the CAREamics modules in your own PyTorch Lightning pipeline.

Basic Usage
from pathlib import Path

import numpy as np
from careamics.lightning import (  # (1)!
    create_careamics_module,
    create_predict_datamodule,
    create_train_datamodule,
)
from careamics.prediction_utils import convert_outputs
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
)

# directory for logs and checkpoints (any writable path works)
mypath = Path("lightning_api_example")

# training data
rng = np.random.default_rng(42)
train_array = rng.integers(0, 255, (32, 32)).astype(np.float32)
val_array = rng.integers(0, 255, (32, 32)).astype(np.float32)

# create lightning module
model = create_careamics_module(  # (2)!
    algorithm="n2v",
    loss="n2v",
    architecture="UNet",
)

# create data module
data = create_train_datamodule(
    train_data=train_array,
    val_data=val_array,
    data_type="array",
    patch_size=(16, 16),
    axes="YX",
    batch_size=2,
)

# create trainer
trainer = Trainer(  # (3)!
    max_epochs=1,
    default_root_dir=mypath,
    callbacks=[
        ModelCheckpoint(  # (4)!
            dirpath=mypath / "checkpoints",
            filename="basic_usage_lightning_api",
        )
    ],
)

# train
trainer.fit(model, datamodule=data)

# create prediction data module, using the training data statistics
means, stds = data.get_data_statistics()
predict_data = create_predict_datamodule(
    pred_data=val_array,
    data_type="array",
    axes="YX",
    image_means=means,
    image_stds=stds,
    tile_size=(8, 8),  # (5)!
    tile_overlap=(2, 2),
)

# predict
predicted = trainer.predict(model, datamodule=predict_data)
predicted_stitched = convert_outputs(predicted, tiled=True)  # (6)!
  1. We provide convenience functions to create the various Lightning modules.

  2. Each convenience function has its own set of parameters. These often correspond to the parameters in the CAREamics configuration. See the next pages for more details.

  3. As for any Lightning pipeline, you need to instantiate a Trainer.

  4. This gives you full freedom to set your own callbacks.

  5. Our prediction Lightning data module can break the images into overlapping tiles.

  6. If you predicted using tiled images, you need to recombine the tiles into images. We provide a general function to take care of this.
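To make the tiling and stitching steps more concrete, here is a conceptual NumPy sketch using the same `tile_size=(8, 8)` and `tile_overlap=(2, 2)` as the example above. It is not CAREamics code: the helpers `tile_starts`, `extract_tiles` and `stitch_tiles` are hypothetical, and averaging the pixels where tiles overlap is an assumption made for illustration. `convert_outputs` may handle overlaps differently internally (for instance by cropping them).

```python
import numpy as np


def tile_starts(length: int, tile: int, overlap: int) -> list[int]:
    """Hypothetical helper: start indices of overlapping tiles covering [0, length)."""
    if tile > length:
        raise ValueError("tile must not be larger than the image")
    stride = tile - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than tile")
    starts = list(range(0, length - tile + 1, stride))
    # make sure the last tile reaches the image border
    if starts[-1] + tile < length:
        starts.append(length - tile)
    return starts


def extract_tiles(image, tile_size, tile_overlap):
    """Hypothetical helper: cut a 2D (YX) image into overlapping tiles and keep their positions."""
    tiles, coords = [], []
    for y in tile_starts(image.shape[0], tile_size[0], tile_overlap[0]):
        for x in tile_starts(image.shape[1], tile_size[1], tile_overlap[1]):
            tiles.append(image[y : y + tile_size[0], x : x + tile_size[1]])
            coords.append((y, x))
    return tiles, coords


def stitch_tiles(tiles, coords, shape):
    """Hypothetical helper: recombine tiles, averaging pixels covered by several tiles."""
    total = np.zeros(shape, dtype=np.float64)
    count = np.zeros(shape, dtype=np.float64)
    for tile, (y, x) in zip(tiles, coords):
        h, w = tile.shape
        total[y : y + h, x : x + w] += tile
        count[y : y + h, x : x + w] += 1
    return total / count


rng = np.random.default_rng(42)
image = rng.integers(0, 255, (32, 32)).astype(np.float32)
tiles, coords = extract_tiles(image, tile_size=(8, 8), tile_overlap=(2, 2))

# identity "prediction": stitching the unmodified tiles must give back the image
stitched = stitch_tiles(tiles, coords, image.shape)
print(len(tiles), np.allclose(stitched, image))
```

With a stride of `8 - 2 = 6`, each 32-pixel axis is covered by tiles starting at 0, 6, 12, 18 and 24, so the 32x32 image is split into 25 tiles. Because the "prediction" here is the identity, the stitched result matches the input, which is a quick sanity check for any tiling scheme.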

There are three types of Lightning modules in CAREamics:

  • Lightning Module
  • Training Lightning Datamodule
  • Prediction Lightning Datamodule
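The training and prediction data modules are linked through the normalization statistics. In the example above, the means and standard deviations returned by `data.get_data_statistics()` are passed to `create_predict_datamodule`, so that new images are normalized exactly like the training data. The sketch below illustrates that idea, assuming single-channel data and a simple mean/standard-deviation (z-score) normalization. The helpers `compute_statistics`, `normalize` and `denormalize` are hypothetical and do not mirror the CAREamics internals.

```python
import numpy as np


def compute_statistics(images):
    """Hypothetical helper: global mean and std over all training pixels (one channel)."""
    stacked = np.stack(images).astype(np.float64)
    return [float(stacked.mean())], [float(stacked.std())]


def normalize(image, means, stds):
    """Hypothetical helper: z-score normalization with precomputed statistics."""
    return (np.asarray(image, dtype=np.float64) - means[0]) / stds[0]


def denormalize(image, means, stds):
    """Hypothetical helper: map normalized values back to the original intensity range."""
    return np.asarray(image, dtype=np.float64) * stds[0] + means[0]


rng = np.random.default_rng(42)
train_array = rng.integers(0, 255, (32, 32)).astype(np.float32)
val_array = rng.integers(0, 255, (32, 32)).astype(np.float32)

# statistics come from the training data only
means, stds = compute_statistics([train_array])

# prediction data is normalized with the *training* statistics...
val_normalized = normalize(val_array, means, stds)

# ...and outputs are mapped back to the original intensity range
restored = denormalize(val_normalized, means, stds)
print(np.allclose(restored, val_array))
```

Reusing the training statistics, rather than recomputing them on the prediction data, keeps the network inputs on the same scale it saw during training.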

In the next pages, we give more details on the various parameters of the convenience functions. For everything else, refer to the PyTorch Lightning documentation.