Skip to content

loss_factory

Loss factory module.

This module contains a factory function for creating loss functions.

FCNLossParameters dataclass

Dataclass for FCN loss.

Source code in src/careamics/losses/loss_factory.py
@dataclass
class FCNLossParameters:
    """Dataclass for FCN loss.

    Plain container grouping the values an FCN loss computation works with.
    Nothing in this class validates or transforms the fields; they are
    stored exactly as passed to the generated ``__init__`` (in field order).

    Attributes
    ----------
    prediction : tensor
        Model prediction.
    targets : tensor
        Target values the prediction is compared against.
    mask : tensor
        Mask associated with the prediction/targets (its exact role is
        defined by the consuming loss, not here).
    current_epoch : int
        Current training epoch.
    loss_weight : float
        Weight applied to the loss.
    """

    # TODO check
    # NOTE(review): the fields below are annotated with ``tensor``, whose
    # import is not visible here. If it is ``torch.tensor`` (a factory
    # function) rather than ``torch.Tensor`` (the type), the annotation is
    # misleading — confirm against the module imports.
    prediction: tensor
    targets: tensor
    mask: tensor
    current_epoch: int
    loss_weight: float

loss_factory(loss)

Return loss function.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `loss` | `Union[SupportedLoss, str]` | Requested loss. | required |

Returns:

| Type | Description |
|------|-------------|
| `Callable` | Loss function. |

Raises:

| Type | Description |
|------|-------------|
| `NotImplementedError` | If the loss is unknown. |

Source code in src/careamics/losses/loss_factory.py
def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
    """Look up the loss function matching the requested loss.

    The request is compared with ``==`` against each supported
    ``SupportedLoss`` member, in a fixed order, and the function paired with
    the first matching member is returned.

    Parameters
    ----------
    loss : Union[SupportedLoss, str]
        Requested loss. A plain ``str`` matches only if ``SupportedLoss``
        members compare equal to strings (presumably a str-valued enum —
        confirm against its definition).

    Returns
    -------
    Callable
        Loss function associated with the requested loss.

    Raises
    ------
    NotImplementedError
        If the loss does not match any supported loss.
    """
    # Ordered dispatch table. Matching is done with ``==`` rather than a
    # dict lookup so that the comparison semantics do not depend on how
    # ``SupportedLoss`` members hash. PN2V has no entry: it is not wired up.
    known_losses = (
        (SupportedLoss.N2V, n2v_loss),
        (SupportedLoss.MAE, mae_loss),
        (SupportedLoss.MSE, mse_loss),
        (SupportedLoss.MUSPLIT, musplit_loss),
        (SupportedLoss.DENOISPLIT, denoisplit_loss),
        (SupportedLoss.DENOISPLIT_MUSPLIT, denoisplit_musplit_loss),
    )

    for member, loss_fn in known_losses:
        if loss == member:
            return loss_fn

    raise NotImplementedError(f"Loss {loss} is not yet supported.")