Skip to content

loss_factory

Loss factory module.

This module contains a factory function for creating loss functions.

FCNLossParameters dataclass #

Dataclass for FCN loss.

Source code in src/careamics/losses/loss_factory.py
@dataclass
class FCNLossParameters:
    """Dataclass for FCN loss.

    Bundles the arguments passed to FCN-based loss functions
    (e.g. N2V, MAE, MSE).
    """

    # TODO check field types: lowercase `tensor` is presumably an alias for
    # torch.Tensor — confirm against the imports in loss_factory.py.
    prediction: tensor  # network output
    targets: tensor  # target values the prediction is compared against
    mask: tensor  # mask applied in the loss (e.g. N2V masked pixels) — TODO confirm
    current_epoch: int  # current epoch in the training loop
    loss_weight: float  # scalar weight applied to the loss

LVAELossParameters dataclass #

Dataclass for LVAE loss.

Source code in src/careamics/losses/loss_factory.py
@dataclass  # TODO why not pydantic?
class LVAELossParameters:
    """Dataclass for LVAE loss."""

    # TODO: refactor in more modular blocks (otherwise it gets messy very easily)
    # e.g., - weights, - kl_params, ...

    noise_model_likelihood: Optional[NoiseModelLikelihood] = None
    """Noise model likelihood instance."""
    gaussian_likelihood: Optional[GaussianLikelihood] = None
    """Gaussian likelihood instance."""
    current_epoch: int = 0
    """Current epoch in the training loop."""
    reconstruction_weight: float = 1.0
    """Weight for the reconstruction loss in the total net loss
    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
    musplit_weight: float = 0.1
    """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
    denoisplit_weight: float = 0.9
    """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
    kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
    """Type of KL divergence used as KL loss."""
    kl_weight: float = 1.0
    """Weight for the KL loss in the total net loss.
    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
    kl_annealing: bool = False
    """Whether to apply KL loss annealing."""
    kl_start: int = -1
    """Epoch at which KL loss annealing starts."""
    kl_annealtime: int = 10
    """Number of epochs for which KL loss annealing is applied."""
    non_stochastic: bool = False
    """Whether to sample latents and compute KL."""

current_epoch: int = 0 class-attribute instance-attribute #

Current epoch in the training loop.

denoisplit_weight: float = 0.9 class-attribute instance-attribute #

Weight for the denoiSplit loss (used in the muSplit-denoiSplit loss).

gaussian_likelihood: Optional[GaussianLikelihood] = None class-attribute instance-attribute #

Gaussian likelihood instance.

kl_annealing: bool = False class-attribute instance-attribute #

Whether to apply KL loss annealing.

kl_annealtime: int = 10 class-attribute instance-attribute #

Number of epochs for which KL loss annealing is applied.

kl_start: int = -1 class-attribute instance-attribute #

Epoch at which KL loss annealing starts.

kl_type: Literal['kl', 'kl_restricted', 'kl_spatial', 'kl_channelwise'] = 'kl' class-attribute instance-attribute #

Type of KL divergence used as KL loss.

kl_weight: float = 1.0 class-attribute instance-attribute #

Weight for the KL loss in the total net loss. (i.e., net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss).

musplit_weight: float = 0.1 class-attribute instance-attribute #

Weight for the muSplit loss (used in the muSplit-denoiSplit loss).

noise_model_likelihood: Optional[NoiseModelLikelihood] = None class-attribute instance-attribute #

Noise model likelihood instance.

non_stochastic: bool = False class-attribute instance-attribute #

Whether to sample latents and compute KL.

reconstruction_weight: float = 1.0 class-attribute instance-attribute #

Weight for the reconstruction loss in the total net loss (i.e., net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss).

loss_factory(loss) #

Return loss function.

Parameters:

Name Type Description Default
loss Union[SupportedLoss, str]

Requested loss.

required

Returns:

Type Description
Callable

Loss function.

Raises:

Type Description
NotImplementedError

If the loss is unknown.

Source code in src/careamics/losses/loss_factory.py
def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
    """Return loss function.

    Parameters
    ----------
    loss : Union[SupportedLoss, str]
        Requested loss.

    Returns
    -------
    Callable
        Loss function.

    Raises
    ------
    NotImplementedError
        If the loss is unknown.
    """
    # Ordered (supported loss, function) pairs. Matching is done by equality
    # rather than a dict lookup so that plain strings equal to a str-enum
    # value are still accepted (Enum members do not hash like their values).
    dispatch = (
        (SupportedLoss.N2V, n2v_loss),
        (SupportedLoss.MAE, mae_loss),
        (SupportedLoss.MSE, mse_loss),
        (SupportedLoss.MUSPLIT, musplit_loss),
        (SupportedLoss.DENOISPLIT, denoisplit_loss),
        (SupportedLoss.DENOISPLIT_MUSPLIT, denoisplit_musplit_loss),
    )
    for supported_loss, loss_fn in dispatch:
        if loss == supported_loss:
            return loss_fn

    raise NotImplementedError(f"Loss {loss} is not yet supported.")

loss_parameters_factory(type) #

Return loss parameters.

Parameters:

Name Type Description Default
type SupportedLoss

Requested loss.

required

Returns:

Type Description
Union[FCNLossParameters, LVAELossParameters]

Loss parameters.

Raises:

Type Description
NotImplementedError

If the loss is unknown.

Source code in src/careamics/losses/loss_factory.py
def loss_parameters_factory(
    type: SupportedLoss,
) -> "Union[type[FCNLossParameters], type[LVAELossParameters]]":
    """Return loss parameters.

    Note that this returns the parameters *class* itself, not an instance:
    callers are expected to instantiate it with the relevant values. The
    return annotation reflects this (the original annotation incorrectly
    promised instances).

    Parameters
    ----------
    type : SupportedLoss
        Requested loss.

    Returns
    -------
    type
        Loss parameters dataclass, either `FCNLossParameters` or
        `LVAELossParameters`.

    Raises
    ------
    NotImplementedError
        If the loss is unknown.
    """
    # FCN-based losses all share the same parameter bundle.
    if type in (SupportedLoss.N2V, SupportedLoss.MSE, SupportedLoss.MAE):
        return FCNLossParameters

    # LVAE-based losses (muSplit / denoiSplit variants) share theirs.
    if type in (
        SupportedLoss.MUSPLIT,
        SupportedLoss.DENOISPLIT,
        SupportedLoss.DENOISPLIT_MUSPLIT,
    ):
        return LVAELossParameters

    raise NotImplementedError(f"Loss {type} is not yet supported.")