Lightning API
CAREamics is built on PyTorch Lightning, so advanced users can use its underlying modules directly in their own PyTorch Lightning scripts. We refer to this as the Lightning API, and it is recommended for users who want more control over the training and prediction process.
Quick start
Here is an example of training Noise2Void using the Lightning API. For the configuration parameters, refer to the CAREamist API documentation.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from careamics_portfolio import PortfolioManager

from careamics.config.factories import create_advanced_n2v_config
from careamics.lightning import (
    CareamicsDataModule,
    ConfigSaverCallback,
    N2VModule,
    convert_prediction,
)

# download example data
portfolio_manager = PortfolioManager()
files = portfolio_manager.denoising.N2V_SEM.download()
train_image = files[0]
val_image = files[1]

# create configuration
config = create_advanced_n2v_config(  # (1)!
    experiment_name="na",  # unused in LightningAPI
    data_type="tiff",
    axes="YX",
    patch_size=(64, 64),
    batch_size=16,
    num_epochs=1,
    num_workers=0,  # (2)!
)

# create lightning modules
model = N2VModule(config.algorithm_config)  # (3)!
data_module = CareamicsDataModule(  # (4)!
    data_config=config.data_config,
    train_data=train_image,
    val_data=val_image,
)

callbacks = [
    ModelCheckpoint(  # (5)!
        dirpath="checkpoints",
        filename=f"{config.experiment_name}_{{epoch:02d}}_step_{{step}}",
        **config.training_config.checkpoint_params,
    ),
    ConfigSaverCallback(
        config.version, config.experiment_name, config.training_config
    ),  # (6)!
]

trainer = Trainer(
    enable_progress_bar=True,
    callbacks=callbacks,
    **config.training_config.trainer_params,  # (7)!
)
trainer.fit(model, datamodule=data_module)  # (8)!

# create an inference data config
pred_config = config.data_config.convert_mode(  # (9)!
    new_mode="predicting",
    new_patch_size=(256, 256),
    overlap_size=(48, 48),
    new_batch_size=1,
)
inf_data_module = CareamicsDataModule(  # (10)!
    data_config=pred_config,
    pred_data=train_image,
)

# run inference
tiled_predictions = trainer.predict(model, datamodule=inf_data_module)  # (11)!

# convert list of tile predictions to stitched data
stitched_predictions, sources = convert_prediction(  # (12)!
    tiled_predictions,
    tiled=True,
)
- (1) Creating a CAREamics configuration ensures that the parameters passed to the various PyTorch modules are coherent.
- (2) num_workers is set to 0 here because multiple workers often cause issues on local machines (e.g. Windows or macOS). Feel free to play with it!
- (3) Each class of algorithm in CAREamics has its own Lightning module (the model). Here we are setting up Noise2Void, but the same can be done for CARE and N2N using the CAREModule.
- (4) The Lightning data module takes a str, Path, numpy.ndarray, or a list of those as input. For more details, refer to the data documentation.
- (5) The CAREamics configuration has its own default set of parameters for the ModelCheckpoint callback (and EarlyStopping). You can either use those or set up your own.
- (6) The ConfigSaverCallback is used to save the configuration in the checkpoints.
- (7) Similarly, the configuration provides default training parameters for the Trainer. You can set your own rather than reusing those.
- (8) As in any Lightning script, pass the model and data module to the trainer and call fit to start training.
- (9) Our data modules require a data configuration. We reuse the training data configuration but convert it to "predicting" mode. This gives the opportunity to change some parameters, such as passing a new_patch_size and overlap_size for tiled prediction.
- (10) Create a new data module for inference, using the new data configuration.
- (11) As for training, we predict using Lightning.
- (12) Predictions are returned as a list of tiles because we set new_patch_size in the data configuration, so we need to stitch those tiles back together.
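The idea behind tiled prediction and stitching (a patch size plus an overlap, followed by reassembly) can be illustrated along a single axis. The following is a minimal pure-Python sketch and not the CAREamics implementation: the names compute_tiles and predict_tiled are hypothetical, the "model" is a stand-in function, and the rule of cropping half of the overlap at each tile border is an assumption chosen for simplicity. Overlap is commonly used because predictions near tile borders tend to be less reliable than those at the center.

```python
def compute_tiles(length, tile_size, overlap):
    """Return (start, stop, keep_start, keep_stop) for each tile along one axis.

    start/stop delimit the tile; keep_start/keep_stop delimit the part of
    that tile kept after stitching. All values are image coordinates.
    """
    if not 0 <= overlap < tile_size:
        raise ValueError("overlap must be non-negative and smaller than tile_size")
    if length <= tile_size:
        return [(0, length, 0, length)]

    step = tile_size - overlap
    # regular grid of tile starts, plus a final tile flush with the end
    starts = list(range(0, length - tile_size, step)) + [length - tile_size]

    tiles = []
    keep_start = 0
    for i, start in enumerate(starts):
        stop = start + tile_size
        is_last = i == len(starts) - 1
        # drop half of the overlap on the trailing edge, except for the last tile
        keep_stop = length if is_last else stop - overlap // 2
        tiles.append((start, stop, keep_start, keep_stop))
        keep_start = keep_stop  # the next tile takes over where this one ends
    return tiles


def predict_tiled(signal, model, tile_size, overlap):
    """Apply `model` tile by tile and stitch the cropped outputs together."""
    stitched = []
    for start, stop, keep_start, keep_stop in compute_tiles(
        len(signal), tile_size, overlap
    ):
        tile_prediction = model(signal[start:stop])
        # convert the kept region from image to tile coordinates
        stitched.extend(tile_prediction[keep_start - start : keep_stop - start])
    return stitched


# toy example: a 10-element "image" and a model that doubles every value
signal = list(range(10))
tiles = compute_tiles(len(signal), tile_size=4, overlap=2)
stitched = predict_tiled(
    signal, model=lambda tile: [2 * v for v in tile], tile_size=4, overlap=2
)
```

Because the kept regions are contiguous and cover the whole axis, the stitched output has the same length as the input. In two dimensions the same computation would be applied independently to each axis.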
Predicting directly to disk
A useful feature of CAREamics that can be leveraged in the Lightning API is writing
predictions directly to disk. This is achieved by adding a PredictionWriterCallback.
# reuses the imports, config, model, and data_module from the quick start above
from careamics.lightning.callbacks.prediction import (
    PredictionWriterCallback,
)

pred_writer = PredictionWriterCallback(
    dirpath="predictions", enable_writing=False
)  # (1)!

callbacks = [
    ModelCheckpoint(
        dirpath="checkpoints",
        filename=f"{config.experiment_name}_{{epoch:02d}}_step_{{step}}",
        **config.training_config.checkpoint_params,
    ),
    ConfigSaverCallback(config.version, config.experiment_name, config.training_config),
    pred_writer,  # (2)!
]

trainer = Trainer(
    enable_progress_bar=True,
    callbacks=callbacks,
    **config.training_config.trainer_params,
)
trainer.fit(model, datamodule=data_module)

# create an inference data config
pred_config = config.data_config.convert_mode(
    new_mode="predicting",
    new_patch_size=(256, 256),
    overlap_size=(48, 48),
    new_batch_size=1,
)
inf_data_module = CareamicsDataModule(
    data_config=pred_config,
    pred_data=train_image,
)

# run inference
pred_writer.set_writing_strategy("tiff", tiled=True)  # (3)!
pred_writer.enable_writing(True)  # (4)!

# with return_predictions=False, predict returns None, so nothing is assigned
trainer.predict(
    model, datamodule=inf_data_module, return_predictions=False  # (5)!
)
- (1) We keep the prediction writer in memory and disable writing, in case we want to perform some predictions in memory first.
- (2) Add the prediction writer callback to the list of callbacks.
- (3) Once we are ready to predict to disk, we set a writing_strategy (tiff, zarr, or custom), and whether it is tiled.
- (4) We also need to turn writing back on.
- (5) Finally, we disable returning the predictions, so that they are not also accumulated in memory.
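The pattern used above, registering a writer once and only switching it on when predictions should go to disk, can be sketched without any CAREamics or Lightning dependency. Everything in this example is hypothetical: ToggleableWriter and its on_batch_end method are illustrative stand-ins and do not reflect the PredictionWriterCallback interface, and plain text files replace TIFF output.

```python
import tempfile
from pathlib import Path


class ToggleableWriter:
    """Hypothetical stand-in for a prediction-writing callback."""

    def __init__(self, dirpath, enable_writing=True):
        self.dirpath = Path(dirpath)
        self.writing_enabled = enable_writing

    def enable_writing(self, flag):
        """Switch writing on or off without recreating the writer."""
        self.writing_enabled = flag

    def on_batch_end(self, name, prediction):
        """Write one prediction to disk if writing is enabled."""
        if not self.writing_enabled:
            return None  # prediction stays in memory only
        self.dirpath.mkdir(parents=True, exist_ok=True)
        path = self.dirpath / f"{name}.txt"
        path.write_text(repr(prediction))
        return path


with tempfile.TemporaryDirectory() as tmp:
    out_dir = Path(tmp) / "predictions"
    writer = ToggleableWriter(out_dir, enable_writing=False)

    # writing disabled: nothing reaches the disk
    skipped = writer.on_batch_end("tile_0", [0.1, 0.2])

    # writing enabled: each prediction is written as soon as it is produced
    writer.enable_writing(True)
    written = writer.on_batch_end("tile_1", [0.3, 0.4])

    files_on_disk = sorted(p.name for p in out_dir.iterdir())
```

Writing each prediction as soon as it is produced, instead of collecting all of them first, is what keeps memory use bounded when the predicted data is large.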