Skip to content

loss_model

Configuration classes for LVAE losses.

KLLossConfig #

Bases: BaseModel

KL loss configuration.

Source code in src/careamics/config/loss_model.py
class KLLossConfig(BaseModel):
    """KL loss configuration.

    Holds the settings of the KL divergence term of the LVAE loss:
    - the KL variant to use;
    - how it is rescaled and aggregated across layers;
    - the free-bits coefficient;
    - the annealing schedule.

    Both attribute assignments and default values are validated by pydantic.
    """

    # Validate on attribute assignment, and validate the defaults declared below.
    model_config = ConfigDict(validate_assignment=True, validate_default=True)

    loss_type: Literal["kl", "kl_restricted"] = "kl"
    """Type of KL divergence used as KL loss."""
    rescaling: Literal["latent_dim", "image_dim"] = "latent_dim"
    """Rescaling of the KL loss."""
    aggregation: Literal["sum", "mean"] = "mean"
    """Aggregation of the KL loss across different layers."""
    free_bits_coeff: float = 0.0
    """Free bits coefficient for the KL loss."""
    annealing: bool = False
    """Whether to apply KL loss annealing."""
    # NOTE(review): the meaning of the -1 default (e.g. "start immediately" or
    # "unset") is not visible in this module — confirm against the annealing code.
    start: int = -1
    """Epoch at which KL loss annealing starts."""
    annealtime: int = 10
    """Number of epochs for which KL loss annealing is applied."""
    # NOTE(review): presumably updated by the training loop during training;
    # here it is only a plain, assignment-validated config field.
    current_epoch: int = 0
    """Current epoch in the training loop."""

aggregation = 'mean' class-attribute instance-attribute #

Aggregation of the KL loss across different layers.

annealing = False class-attribute instance-attribute #

Whether to apply KL loss annealing.

annealtime = 10 class-attribute instance-attribute #

Number of epochs for which KL loss annealing is applied.

current_epoch = 0 class-attribute instance-attribute #

Current epoch in the training loop.

free_bits_coeff = 0.0 class-attribute instance-attribute #

Free bits coefficient for the KL loss.

loss_type = 'kl' class-attribute instance-attribute #

Type of KL divergence used as KL loss.

rescaling = 'latent_dim' class-attribute instance-attribute #

Rescaling of the KL loss.

start = -1 class-attribute instance-attribute #

Epoch at which KL loss annealing starts.

LVAELossConfig #

Bases: BaseModel

LVAE loss configuration.

Source code in src/careamics/config/loss_model.py
class LVAELossConfig(BaseModel):
    """LVAE loss configuration.

    Groups the LVAE loss type, the weights used to combine the individual
    loss terms, and the nested KL loss configuration. Attribute assignments
    (``validate_assignment=True``) and default values
    (``validate_default=True``) are both validated by pydantic.
    """

    model_config = ConfigDict(
        validate_assignment=True, validate_default=True, arbitrary_types_allowed=True
    )

    loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"]
    """Type of loss to use for LVAE."""

    reconstruction_weight: float = 1.0
    """Weight for the reconstruction loss in the total net loss
    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
    kl_weight: float = 1.0
    """Weight for the KL loss in the total net loss
    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
    musplit_weight: float = 0.1
    """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
    denoisplit_weight: float = 0.9
    """Weight for the denoiSplit loss (used in the muSplit-denoiSplit loss)."""
    # NOTE(review): a model instance is used as the default value. Verify that
    # pydantic copies it per instance, or prefer
    # `Field(default_factory=KLLossConfig)` if `Field` is imported in this module.
    kl_params: KLLossConfig = KLLossConfig()
    """KL loss configuration."""

    # TODO: remove?
    non_stochastic: bool = False
    """Whether to run the LVAE non-stochastically (presumably `True` disables
    latent sampling and the KL computation) — TODO confirm against the model."""

denoisplit_weight = 0.9 class-attribute instance-attribute #

Weight for the denoiSplit loss (used in the muSplit-denoiSplit loss).

kl_params = KLLossConfig() class-attribute instance-attribute #

KL loss configuration.

kl_weight = 1.0 class-attribute instance-attribute #

Weight for the KL loss in the total net loss (i.e., net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss).

loss_type instance-attribute #

Type of loss to use for LVAE.

musplit_weight = 0.1 class-attribute instance-attribute #

Weight for the muSplit loss (used in the muSplit-denoiSplit loss).

non_stochastic = False class-attribute instance-attribute #

Whether to sample latents and compute KL.

reconstruction_weight = 1.0 class-attribute instance-attribute #

Weight for the reconstruction loss in the total net loss (i.e., net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss).