Loss factory module.
This module contains a factory function for creating loss functions.
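FCNLossParameters (dataclass)
Dataclass grouping the inputs of an FCN loss: the prediction, targets and mask tensors, the current epoch, and the loss weight.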
Source code in src/careamics/losses/loss_factory.py
```python
from dataclasses import dataclass

from torch import Tensor as tensor


@dataclass
class FCNLossParameters:
    """Dataclass for FCN loss."""

    # TODO check
    prediction: tensor
    targets: tensor
    mask: tensor
    current_epoch: int
    loss_weight: float
```
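To illustrate how these grouped fields might be consumed, here is a minimal sketch of a masked mean squared error written against a hypothetical `ToyLossParameters` class with the same field names. Plain Python lists stand in for torch tensors so the sketch runs without torch. The masked-MSE formula and the use of `loss_weight` as a multiplicative factor are assumptions, not the careamics implementation, and `current_epoch` is carried along only to mirror the original fields.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyLossParameters:
    """Hypothetical stand-in for FCNLossParameters, using lists instead of tensors."""

    prediction: List[float]
    targets: List[float]
    mask: List[float]  # 1.0 where a pixel contributes to the loss, else 0.0
    current_epoch: int  # mirrors the original field; unused in this sketch
    loss_weight: float  # assumed to scale the final loss value


def masked_mse(params: ToyLossParameters) -> float:
    """Weighted mean squared error computed over masked positions only (assumed formula)."""
    squared = [
        m * (p - t) ** 2
        for p, t, m in zip(params.prediction, params.targets, params.mask)
    ]
    n_masked = sum(params.mask)
    if n_masked == 0:
        # No masked pixels: nothing to average over.
        return 0.0
    return params.loss_weight * sum(squared) / n_masked


params = ToyLossParameters(
    prediction=[1.0, 2.0, 3.0, 4.0],
    targets=[1.0, 0.0, 3.0, 0.0],
    mask=[0.0, 1.0, 0.0, 1.0],
    current_epoch=0,
    loss_weight=0.5,
)
print(masked_mse(params))  # prints 5.0
```

Only the second and fourth positions are masked, so the loss averages their squared errors (4 and 16) and then applies the weight of 0.5.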
loss_factory(loss)
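Return the loss function matching the requested loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss | Union[SupportedLoss, str] | Requested loss. | required |

Returns:

| Type | Description |
| --- | --- |
| Callable | Loss function. |

Raises:

| Type | Description |
| --- | --- |
| NotImplementedError | If the loss is unknown. |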
Source code in src/careamics/losses/loss_factory.py
```python
from typing import Callable, Union

# SupportedLoss and the individual loss functions (n2v_loss, mae_loss, ...)
# come from other careamics modules; their import paths are omitted here.


def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
    """Return loss function.

    Parameters
    ----------
    loss : Union[SupportedLoss, str]
        Requested loss.

    Returns
    -------
    Callable
        Loss function.

    Raises
    ------
    NotImplementedError
        If the loss is unknown.
    """
    if loss == SupportedLoss.N2V:
        return n2v_loss
    # elif loss == SupportedLoss.PN2V:
    #     return pn2v_loss
    elif loss == SupportedLoss.MAE:
        return mae_loss
    elif loss == SupportedLoss.MSE:
        return mse_loss
    elif loss == SupportedLoss.MUSPLIT:
        return musplit_loss
    elif loss == SupportedLoss.DENOISPLIT:
        return denoisplit_loss
    elif loss == SupportedLoss.DENOISPLIT_MUSPLIT:
        return denoisplit_musplit_loss
    else:
        raise NotImplementedError(f"Loss {loss} is not yet supported.")
```
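Because `loss` is annotated `Union[SupportedLoss, str]`, a plain string such as `"mse"` is expected to compare equal to an enum member; that works if `SupportedLoss` is a str-valued enum, which is an assumption here. The sketch below shows that pattern together with a dictionary-based alternative to the `if`/`elif` chain. `ToyLoss`, `toy_loss_factory`, and the pure-Python `mae`/`mse` helpers are hypothetical stand-ins, not careamics APIs.

```python
from enum import Enum
from typing import Callable, Dict, List, Union


class ToyLoss(str, Enum):
    """Hypothetical str-valued enum standing in for SupportedLoss."""

    MAE = "mae"
    MSE = "mse"


def mae(pred: List[float], target: List[float]) -> float:
    """Mean absolute error over two equal-length lists."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def mse(pred: List[float], target: List[float]) -> float:
    """Mean squared error over two equal-length lists."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


# One entry per supported loss replaces one if/elif branch each.
_LOSSES: Dict[ToyLoss, Callable[[List[float], List[float]], float]] = {
    ToyLoss.MAE: mae,
    ToyLoss.MSE: mse,
}


def toy_loss_factory(loss: Union[ToyLoss, str]) -> Callable:
    """Dictionary-based equivalent of an if/elif loss factory."""
    try:
        # ToyLoss("mse") converts a plain string; an enum member passes through unchanged.
        return _LOSSES[ToyLoss(loss)]
    except ValueError:
        # Unknown names surface as the same exception type the original raises.
        raise NotImplementedError(f"Loss {loss} is not yet supported.") from None


print(toy_loss_factory("mse")([1.0, 2.0], [1.0, 0.0]))  # prints 2.0
print(toy_loss_factory(ToyLoss.MAE)([1.0, 2.0], [1.0, 0.0]))  # prints 1.0
```

With this layout, registering a new loss needs only a new enum member and one dictionary entry, while the explicit chain in the original keeps each branch visible at the cost of repeating the comparison.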