BSD68 N2V
The BSD68 dataset was adapted from K. Zhang et al. (TIP, 2017) and is composed of natural images. The noise was artificially added, allowing for quantitative comparison with the ground truth; it is one of the benchmarks used in many denoising publications. Here, we check the performance of Noise2Void using the Lightning API of CAREamics.
This API gives you more freedom to customize the training by using wrappers around the main elements of CAREamics: the datasets and the lightning module.
In [4]:
Copied!
# Imports necessary to execute the code
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from PIL import Image
from careamics.lightning import (
create_careamics_module,
create_predict_datamodule,
create_train_datamodule,
)
from careamics.config.support import SupportedTransform
from careamics.prediction_utils import convert_outputs
from careamics.utils.metrics import psnr
from careamics_portfolio import PortfolioManager
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
# Imports necessary to execute the code from pathlib import Path import matplotlib.pyplot as plt import numpy as np import tifffile from PIL import Image from careamics.lightning import ( create_careamics_module, create_predict_datamodule, create_train_datamodule, ) from careamics.config.support import SupportedTransform from careamics.prediction_utils import convert_outputs from careamics.utils.metrics import psnr from careamics_portfolio import PortfolioManager from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint
Import the dataset¶
The dataset can be directly downloaded using the careamics-portfolio
package, which uses pooch
to download the data.
In [ ]:
Copied!
# instantiate the data portfolio manager
portfolio = PortfolioManager()

# and download the data
root_path = Path("./data")
files = portfolio.denoising.N2V_BSD68.download(root_path)

# create paths for the data
# (`root_path / ...` is already a Path, no extra Path(...) wrapping needed)
data_path = root_path / "denoising-N2V_BSD68.unzip" / "BSD68_reproducibility_data"
train_path = data_path / "train"
val_path = data_path / "val"
test_path = data_path / "test" / "images"
gt_path = data_path / "test" / "gt"
# instantiate data portfolio manage portfolio = PortfolioManager() # and download the data root_path = Path("./data") files = portfolio.denoising.N2V_BSD68.download(root_path) # create paths for the data data_path = Path(root_path / "denoising-N2V_BSD68.unzip/BSD68_reproducibility_data") train_path = data_path / "train" val_path = data_path / "val" test_path = data_path / "test" / "images" gt_path = data_path / "test" / "gt"
Visualize data¶
In [6]:
Copied!
# Load a single training and a single validation image, then show them side by side.
# Each TIFF stack holds several frames; only the first frame ([0]) is displayed.
train_sample = tifffile.imread(next(iter(train_path.rglob("*.tiff"))))[0]
val_sample = tifffile.imread(next(iter(val_path.rglob("*.tiff"))))[0]

fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(10, 5))
ax_train.imshow(train_sample, cmap="gray")
ax_train.set_title("Training Image")
ax_val.imshow(val_sample, cmap="gray")
ax_val.set_title("Validation Image")
# load training and validation image and show them side by side single_train_image = tifffile.imread(next(iter(train_path.rglob("*.tiff"))))[0] single_val_image = tifffile.imread(next(iter(val_path.rglob("*.tiff"))))[0] fig, ax = plt.subplots(1, 2, figsize=(10, 5)) ax[0].imshow(single_train_image, cmap="gray") ax[0].set_title("Training Image") ax[1].imshow(single_val_image, cmap="gray") ax[1].set_title("Validation Image")
Out[6]:
Text(0.5, 1.0, 'Validation Image')
In [7]:
Copied!
# Create the Lightning module wrapping the Noise2Void algorithm.
model = create_careamics_module(
    algorithm="n2v",  # self-supervised Noise2Void training
    loss="n2v",  # the masked-pixel N2V loss matching the algorithm
    architecture="UNet",
    model_parameters={"n2v2": False},  # presumably toggles the N2V2 variant — False keeps vanilla N2V
)
model = create_careamics_module( algorithm="n2v", loss="n2v", architecture="UNet", model_parameters={"n2v2": False}, )
Create the data module¶
In [10]:
Copied!
# Build the training/validation datamodule: 64x64 patches are extracted from the
# TIFF stacks (axes "SYX": sample, then spatial Y/X) and augmented on the fly.
train_data_module = create_train_datamodule(
    train_data=train_path,
    val_data=val_path,
    data_type="tiff",
    patch_size=(64, 64),
    axes="SYX",
    batch_size=64,
    transforms=[
        { # you can delete a transform here to not apply it
            "name": SupportedTransform.XY_FLIP.value,
            "flip_x": True, # you can set parameters
            "flip_y": True,
        },
        {
            # random 90-degree rotations in the XY plane
            "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
        },
        {
            "name": SupportedTransform.N2V_MANIPULATE.value, # mandatory to run N2V
            # here you can modify the N2V manipulate parameters
        },
    ],
)
train_data_module = create_train_datamodule( train_data=train_path, val_data=val_path, data_type="tiff", patch_size=(64, 64), axes="SYX", batch_size=64, transforms=[ { # you can delete a transform here to not apply it "name": SupportedTransform.XY_FLIP.value, "flip_x": True, # you can set parameters "flip_y": True, }, { "name": SupportedTransform.XY_RANDOM_ROTATE90.value, }, { "name": SupportedTransform.N2V_MANIPULATE.value, # mandatory to run N2V # here you can modify the N2V manipulate parameters }, ], )
Create the trainer¶
Note that here we modify the prediction loop, but this will be changed in the near future.
In [ ]:
Copied!
# Checkpoint callback: write weights under bsd68_n2v/checkpoints and keep the
# last checkpoint in addition to the named one.
ckpt_root = Path("bsd68_n2v")
checkpoint_callback = ModelCheckpoint(
    dirpath=ckpt_root / "checkpoints",
    filename="bsd68_lightning_api",
    save_last=True,
)

# Lightning trainer: 100 epochs, logs and checkpoints rooted at ckpt_root
trainer = Trainer(
    max_epochs=100,
    default_root_dir=ckpt_root,
    callbacks=[checkpoint_callback],
)

# Run the training loop
trainer.fit(model, datamodule=train_data_module)
# Create Callbacks root = Path("bsd68_n2v") callbacks = [ ModelCheckpoint( dirpath=root / "checkpoints", filename="bsd68_lightning_api", save_last=True, ) ] # Create a Lightning Trainer trainer = Trainer(max_epochs=100, default_root_dir=root, callbacks=callbacks) # Train the model trainer.fit(model, datamodule=train_data_module)
In [ ]:
Copied!
# Reuse the normalization statistics computed on the training data so that the
# test images are standardized the same way the network saw during training.
means, stds = train_data_module.get_data_statistics()

pred_data_module = create_predict_datamodule(
    pred_data=test_path,
    data_type="tiff",
    axes="YX",  # single 2D images at prediction time (no sample axis)
    batch_size=1,
    tta_transforms=True,  # test-time augmentation
    image_means=means,
    image_stds=stds,
    # tile large images and stitch the overlapping tiles back together
    tile_size=(128, 128),
    tile_overlap=(32, 32),
)
means, stds = train_data_module.get_data_statistics() pred_data_module = create_predict_datamodule( pred_data=test_path, data_type="tiff", axes="YX", batch_size=1, tta_transforms=True, image_means=means, image_stds=stds, tile_size=(128, 128), tile_overlap=(32, 32), )
Predict¶
In [ ]:
Copied!
# Run inference over the prediction datamodule; returns per-tile outputs
prediction = trainer.predict(model, datamodule=pred_data_module)

# Convert the outputs to the original format, mostly useful if tiling is used
# (tiled=True stitches the tiles back into full images)
prediction = convert_outputs(prediction, tiled=True)
# Predict prediction = trainer.predict(model, datamodule=pred_data_module) # Convert the outputs to the original format, mostly useful if tiling is used prediction = convert_outputs(prediction, tiled=True)
Visualize the prediction¶
In [14]:
Copied!
# Show three randomly chosen test images: noisy input, prediction, ground truth
noises = [tifffile.imread(f) for f in sorted(test_path.glob("*.tiff"))]
gts = [tifffile.imread(f) for f in sorted(gt_path.glob("*.tiff"))]

# pick three distinct image indices (replace=False prevents duplicate picks)
images = np.random.choice(len(noises), 3, replace=False)

fig, ax = plt.subplots(3, 3, figsize=(15, 15))
fig.tight_layout()
for i, idx in enumerate(images):
    pred_image = prediction[idx].squeeze()

    # PSNR of the raw noisy image and of the prediction, both against ground truth
    psnr_noisy = psnr(gts[idx], noises[idx])
    psnr_result = psnr(gts[idx], pred_image)

    ax[i, 0].imshow(noises[idx], cmap="gray")
    ax[i, 0].title.set_text(f"Noisy\nPSNR: {psnr_noisy:.2f}")
    ax[i, 1].imshow(pred_image, cmap="gray")
    ax[i, 1].title.set_text(f"Prediction\nPSNR: {psnr_result:.2f}")
    ax[i, 2].imshow(gts[idx], cmap="gray")
    ax[i, 2].title.set_text("Ground-truth")
# Show two images noises = [tifffile.imread(f) for f in sorted(test_path.glob("*.tiff"))] gts = [tifffile.imread(f) for f in sorted(gt_path.glob("*.tiff"))] # images to show images = np.random.choice(range(len(noises)), 3) fig, ax = plt.subplots(3, 3, figsize=(15, 15)) fig.tight_layout() for i in range(3): pred_image = prediction[images[i]].squeeze() psnr_noisy = psnr(gts[images[i]], noises[images[i]]) psnr_result = psnr(gts[images[i]], pred_image) ax[i, 0].imshow(noises[images[i]], cmap="gray") ax[i, 0].title.set_text(f"Noisy\nPSNR: {psnr_noisy:.2f}") ax[i, 1].imshow(pred_image, cmap="gray") ax[i, 1].title.set_text(f"Prediction\nPSNR: {psnr_result:.2f}") ax[i, 2].imshow(gts[images[i]], cmap="gray") ax[i, 2].title.set_text("Ground-truth")
Compute metrics¶
In [15]:
Copied!
# Compute the PSNR of every prediction against its ground truth and report
# the mean and standard deviation over the whole test set.
scores = [psnr(gt, pred.squeeze()) for pred, gt in zip(prediction, gts)]
psnrs = np.asarray(scores)

print(f"PSNR: {psnrs.mean():.2f} +/- {psnrs.std():.2f}")
print("Reported PSNR: 27.71")
psnrs = np.zeros((len(prediction), 1)) for i, (pred, gt) in enumerate(zip(prediction, gts)): psnrs[i] = psnr(gt, pred.squeeze()) print(f"PSNR: {psnrs.mean():.2f} +/- {psnrs.std():.2f}") print("Reported PSNR: 27.71")
PSNR: 27.16 +/- 2.91 Reported PSNR: 27.71
Create cover¶
In [16]:
Copied!
# create a 256x256 cover image: left half noisy input, right half prediction
im_idx = 3
cv_image_noisy = noises[im_idx]
cv_image_pred = prediction[im_idx].squeeze()

# create image
size = 256
half = size // 2
cover = np.zeros((size, size))

(height, width) = cv_image_noisy.shape
# the centred crops below need at least `size` pixels in each dimension,
# so exactly-256-pixel images are valid too (>=, not >)
assert height >= size
assert width >= size

# min-max normalize both images to [0, 1] so the two halves share a display range
norm_noise = (cv_image_noisy - cv_image_noisy.min()) / (
    cv_image_noisy.max() - cv_image_noisy.min()
)
norm_pred = (cv_image_pred - cv_image_pred.min()) / (
    cv_image_pred.max() - cv_image_pred.min()
)

# fill in halves: centre-left crop of the noisy image on the left,
# centre-right crop of the prediction on the right
cover[:, :half] = norm_noise[
    height // 2 - half : height // 2 + half, width // 2 - half : width // 2
]
cover[:, half:] = norm_pred[
    height // 2 - half : height // 2 + half, width // 2 : width // 2 + half
]

# plot the single image
plt.imshow(cover, cmap="gray")

# save as 8-bit grayscale; cast to uint8 BEFORE handing the array to PIL —
# a float array would produce an "F"-mode image whose convert('L') truncates
im = Image.fromarray((cover * 255).astype(np.uint8), mode="L")
im.save("BSD68_Noise2Void_lightning_api.jpeg")
# create a cover image im_idx = 3 cv_image_noisy = noises[im_idx] cv_image_pred = prediction[im_idx].squeeze() # create image cover = np.zeros((256, 256)) (height, width) = cv_image_noisy.shape assert height > 256 assert width > 256 # normalize train and prediction norm_noise = (cv_image_noisy - cv_image_noisy.min()) / (cv_image_noisy.max() - cv_image_noisy.min()) norm_pred = (cv_image_pred - cv_image_pred.min()) / (cv_image_pred.max() - cv_image_pred.min()) # fill in halves cover[:, :256 // 2] = norm_noise[height // 2 - 256 // 2:height // 2 + 256 // 2, width // 2 - 256 // 2:width // 2] cover[:, 256 // 2:] = norm_pred[height // 2 - 256 // 2:height // 2 + 256 // 2, width // 2:width // 2 + 256 // 2] # plot the single image plt.imshow(cover, cmap="gray") # save the image im = Image.fromarray(cover * 255) im = im.convert('L') im.save("BSD68_Noise2Void_lightning_api.jpeg")