Skip to content

n2v_module

Noise2Void Lightning Module.

N2VModule #

Bases: LightningModule

CAREamics PyTorch Lightning module for N2V algorithm.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `algorithm_config` | `N2VAlgorithm` or `dict` | Configuration for the N2V algorithm, either as an `N2VAlgorithm` instance or a dictionary. | required |

Source code in src/careamics/lightning/dataset_ng/lightning_modules/n2v_module.py
class N2VModule(L.LightningModule):
    """PyTorch Lightning module implementing the CAREamics N2V algorithm.

    A UNet is built from ``algorithm_config.model``. During training and
    validation, the input tensor goes through ``N2VManipulateTorch``. The network
    predicts from the manipulated tensor, and ``n2v_loss`` receives the
    prediction, the original tensor and the mask returned by the manipulator.

    Parameters
    ----------
    algorithm_config : N2VAlgorithm or dict
        Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a
        dictionary.
    """

    def __init__(self, algorithm_config: N2VAlgorithm | dict[str, Any]) -> None:
        """Set up the N2V network, masking transform, loss and metrics.

        Parameters
        ----------
        algorithm_config : N2VAlgorithm or dict
            Configuration for the N2V algorithm, either as an N2VAlgorithm instance
            or a dictionary. A dictionary is converted with ``algorithm_factory``.

        Raises
        ------
        TypeError
            If the resolved configuration is not an ``N2VAlgorithm``.
        """
        super().__init__()

        cfg = (
            algorithm_factory(algorithm_config)
            if isinstance(algorithm_config, dict)
            else algorithm_config
        )
        if not isinstance(cfg, N2VAlgorithm):
            raise TypeError("algorithm_config must be a N2VAlgorithm")

        # Record the configuration as a JSON-compatible dict in the hyperparameters.
        # This call must stay directly in __init__: Lightning inspects the caller
        # frame here.
        self.save_hyperparameters({"algorithm_config": cfg.model_dump(mode="json")})
        self.config = cfg

        unet_kwargs = self.config.model.model_dump()
        self.model: nn.Module = UNet(**unet_kwargs)
        self.n2v_manipulate = N2VManipulateTorch(
            n2v_manipulate_config=self.config.n2v_config
        )
        self.loss_func = n2v_loss
        self.metrics = MetricCollection(PeakSignalNoiseRatio())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped network to ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output of ``self.model`` for ``x``.
        """
        network = self.model
        return network(x)

    def _masked_forward(
        self, input_tensor: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Mask ``input_tensor``, predict from it and evaluate the N2V loss.

        Parameters
        ----------
        input_tensor : torch.Tensor
            Batch tensor taken from the first element of the incoming batch.

        Returns
        -------
        tuple of torch.Tensor
            ``(prediction, original, loss)``, where ``original`` is the second value
            returned by ``self.n2v_manipulate``.
        """
        masked, original, mask = self.n2v_manipulate(input_tensor)
        prediction = self.model(masked)
        loss = self.loss_func(prediction, original, mask)
        return prediction, original, loss

    def training_step(
        self,
        batch: tuple[ImageRegionData] | tuple[ImageRegionData, ImageRegionData],
        batch_idx: int,
    ) -> torch.Tensor:
        """Compute and log the N2V training loss for one batch.

        Parameters
        ----------
        batch : ImageRegionData or (ImageRegionData, ImageRegionData)
            A tuple whose first element holds the input data. A second (target)
            element, if present, is ignored.
        batch_idx : int
            The index of the current batch in the training loop.

        Returns
        -------
        torch.Tensor
            The loss value for the current training step.
        """
        inputs = batch[0]
        input_tensor = cast(torch.Tensor, inputs.data)
        _, _, loss = self._masked_forward(input_tensor)
        log_training_stats(self, loss, batch_size=input_tensor.shape[0])
        return loss

    def validation_step(
        self,
        batch: tuple[ImageRegionData] | tuple[ImageRegionData, ImageRegionData],
        batch_idx: int,
    ) -> None:
        """Compute the N2V validation loss and PSNR metrics for one batch.

        Parameters
        ----------
        batch : ImageRegionData or (ImageRegionData, ImageRegionData)
            A tuple whose first element holds the input data. A second (target)
            element, if present, is ignored.
        batch_idx : int
            The index of the current batch in the validation loop.
        """
        inputs = batch[0]
        input_tensor = cast(torch.Tensor, inputs.data)
        prediction, original, val_loss = self._masked_forward(input_tensor)
        # The metrics compare the prediction with the manipulator's original tensor.
        self.metrics(prediction, original)
        log_validation_stats(
            self, val_loss, batch_size=input_tensor.shape[0], metrics=self.metrics
        )

    def predict_step(
        self,
        batch: tuple[ImageRegionData] | tuple[ImageRegionData, ImageRegionData],
        batch_idx: int,
    ) -> ImageRegionData:
        """Predict one batch and return it denormalized, with its input metadata.

        Parameters
        ----------
        batch : ImageRegionData or (ImageRegionData, ImageRegionData)
            A tuple containing the input data and optionally the target data; only
            the input is used.
        batch_idx : int
            The index of the current batch in the prediction loop.

        Returns
        -------
        ImageRegionData
            The denormalized prediction as a NumPy array, carrying the input's
            source, shape, dtype, axes and region metadata.
        """
        inputs = batch[0]
        input_tensor = cast(torch.Tensor, inputs.data)
        # TODO: add TTA
        raw_prediction = self.model(input_tensor)

        # Undo the normalization applied by the datamodule's prediction dataset.
        normalization = self._trainer.datamodule.predict_dataset.normalization  # type: ignore[union-attr]
        restored = normalization.denormalize(raw_prediction)
        restored_array = restored.cpu().numpy()

        return ImageRegionData(
            data=restored_array,
            source=inputs.source,
            data_shape=inputs.data_shape,
            dtype=inputs.dtype,
            axes=inputs.axes,
            region_spec=inputs.region_spec,
            additional_metadata={},
        )

    def configure_optimizers(self) -> dict[str, Any]:  # type: ignore[override]
        """Build the optimizer and learning rate scheduler from the configuration.

        Returns
        -------
        dict[str, Any]
            A dictionary containing the optimizer and learning rate scheduler.
        """
        optimizer_cfg = self.config.optimizer
        scheduler_cfg = self.config.lr_scheduler
        return configure_optimizers(
            model=self.model,
            optimizer_name=optimizer_cfg.name,
            optimizer_parameters=optimizer_cfg.parameters,
            lr_scheduler_name=scheduler_cfg.name,
            lr_scheduler_parameters=scheduler_cfg.parameters,
        )

__init__(algorithm_config) #

Instantiate N2VModule.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `algorithm_config` | `N2VAlgorithm` or `dict` | Configuration for the N2V algorithm, either as an `N2VAlgorithm` instance or a dictionary. | required |

Source code in src/careamics/lightning/dataset_ng/lightning_modules/n2v_module.py
def __init__(self, algorithm_config: N2VAlgorithm | dict[str, Any]) -> None:
    """Set up the N2V network, masking transform, loss and metrics.

    Parameters
    ----------
    algorithm_config : N2VAlgorithm or dict
        Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a
        dictionary. A dictionary is converted with ``algorithm_factory``.

    Raises
    ------
    TypeError
        If the resolved configuration is not an ``N2VAlgorithm``.
    """
    super().__init__()

    cfg = (
        algorithm_factory(algorithm_config)
        if isinstance(algorithm_config, dict)
        else algorithm_config
    )
    if not isinstance(cfg, N2VAlgorithm):
        raise TypeError("algorithm_config must be a N2VAlgorithm")

    # Record the configuration as a JSON-compatible dict in the hyperparameters.
    self.save_hyperparameters({"algorithm_config": cfg.model_dump(mode="json")})
    self.config = cfg

    unet_kwargs = self.config.model.model_dump()
    self.model: nn.Module = UNet(**unet_kwargs)
    self.n2v_manipulate = N2VManipulateTorch(
        n2v_manipulate_config=self.config.n2v_config
    )
    self.loss_func = n2v_loss
    self.metrics = MetricCollection(PeakSignalNoiseRatio())

configure_optimizers() #

Configure optimizer and learning rate scheduler.

Returns:

| Type | Description |
| --- | --- |
| `dict[str, Any]` | A dictionary containing the optimizer and learning rate scheduler. |

Source code in src/careamics/lightning/dataset_ng/lightning_modules/n2v_module.py
def configure_optimizers(self) -> dict[str, Any]:  # type: ignore[override]
    """Build the optimizer and learning rate scheduler from the configuration.

    Returns
    -------
    dict[str, Any]
        A dictionary containing the optimizer and learning rate scheduler.
    """
    optimizer_cfg = self.config.optimizer
    scheduler_cfg = self.config.lr_scheduler
    return configure_optimizers(
        model=self.model,
        optimizer_name=optimizer_cfg.name,
        optimizer_parameters=optimizer_cfg.parameters,
        lr_scheduler_name=scheduler_cfg.name,
        lr_scheduler_parameters=scheduler_cfg.parameters,
    )

forward(x) #

Forward pass.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Model output tensor. |

Source code in src/careamics/lightning/dataset_ng/lightning_modules/n2v_module.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the wrapped network to ``x``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        Output of ``self.model`` for ``x``.
    """
    network = self.model
    return network(x)

predict_step(batch, batch_idx) #

Prediction step for N2V model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `ImageRegionData` or `(ImageRegionData, ImageRegionData)` | A tuple containing the input data and optionally the target data (only the input is used). | required |
| `batch_idx` | `int` | The index of the current batch in the prediction loop. | required |

Returns:

| Type | Description |
| --- | --- |
| `ImageRegionData` | The output batch containing the predictions. |

Source code in src/careamics/lightning/dataset_ng/lightning_modules/n2v_module.py
def predict_step(
    self,
    batch: tuple[ImageRegionData] | tuple[ImageRegionData, ImageRegionData],
    batch_idx: int,
) -> ImageRegionData:
    """Predict one batch and return it denormalized, with its input metadata.

    Parameters
    ----------
    batch : ImageRegionData or (ImageRegionData, ImageRegionData)
        A tuple containing the input data and optionally the target data; only the
        input is used.
    batch_idx : int
        The index of the current batch in the prediction loop.

    Returns
    -------
    ImageRegionData
        The denormalized prediction as a NumPy array, carrying the input's source,
        shape, dtype, axes and region metadata.
    """
    inputs = batch[0]
    input_tensor = cast(torch.Tensor, inputs.data)
    # TODO: add TTA
    raw_prediction = self.model(input_tensor)

    # Undo the normalization applied by the datamodule's prediction dataset.
    normalization = self._trainer.datamodule.predict_dataset.normalization  # type: ignore[union-attr]
    restored = normalization.denormalize(raw_prediction)
    restored_array = restored.cpu().numpy()

    return ImageRegionData(
        data=restored_array,
        source=inputs.source,
        data_shape=inputs.data_shape,
        dtype=inputs.dtype,
        axes=inputs.axes,
        region_spec=inputs.region_spec,
        additional_metadata={},
    )

training_step(batch, batch_idx) #

Training step for N2V model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `ImageRegionData` or `(ImageRegionData, ImageRegionData)` | A tuple containing the input data and, optionally, the target data (only the input is used). | required |
| `batch_idx` | `int` | The index of the current batch in the training loop. | required |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The loss value for the current training step. |

Source code in src/careamics/lightning/dataset_ng/lightning_modules/n2v_module.py
def training_step(
    self,
    batch: tuple[ImageRegionData] | tuple[ImageRegionData, ImageRegionData],
    batch_idx: int,
) -> torch.Tensor:
    """Compute and log the N2V training loss for one batch.

    Parameters
    ----------
    batch : ImageRegionData or (ImageRegionData, ImageRegionData)
        A tuple whose first element holds the input data. A second (target)
        element, if present, is ignored.
    batch_idx : int
        The index of the current batch in the training loop.

    Returns
    -------
    torch.Tensor
        The loss value for the current training step.
    """
    inputs = batch[0]
    input_tensor = cast(torch.Tensor, inputs.data)

    # The manipulator returns the masked input, the original input and the mask.
    masked, original, mask = self.n2v_manipulate(input_tensor)
    prediction = self.model(masked)
    loss = self.loss_func(prediction, original, mask)

    log_training_stats(self, loss, batch_size=input_tensor.shape[0])
    return loss

validation_step(batch, batch_idx) #

Validation step for N2V model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `ImageRegionData` or `(ImageRegionData, ImageRegionData)` | A tuple containing the input data and, optionally, the target data (only the input is used). | required |
| `batch_idx` | `int` | The index of the current batch in the validation loop. | required |

Source code in src/careamics/lightning/dataset_ng/lightning_modules/n2v_module.py
def validation_step(
    self,
    batch: tuple[ImageRegionData] | tuple[ImageRegionData, ImageRegionData],
    batch_idx: int,
) -> None:
    """Compute the N2V validation loss and PSNR metrics for one batch.

    Parameters
    ----------
    batch : ImageRegionData or (ImageRegionData, ImageRegionData)
        A tuple whose first element holds the input data. A second (target)
        element, if present, is ignored.
    batch_idx : int
        The index of the current batch in the validation loop.
    """
    inputs = batch[0]
    input_tensor = cast(torch.Tensor, inputs.data)

    masked, original, mask = self.n2v_manipulate(input_tensor)
    prediction = self.model(masked)
    val_loss = self.loss_func(prediction, original, mask)

    # The metrics compare the prediction with the manipulator's original tensor.
    self.metrics(prediction, original)
    log_validation_stats(
        self, val_loss, batch_size=input_tensor.shape[0], metrics=self.metrics
    )