VAE Lightning Module
CAREamics Lightning module.
VAEModule
Bases: LightningModule
CAREamics Lightning module.
This class encapsulates a PyTorch model along with the training, validation,
and testing logic. It is configured using an AlgorithmModel Pydantic class.
Parameters:
- algorithm_config – Algorithm configuration (an AlgorithmModel Pydantic instance).
Attributes:
- model (Module) – PyTorch model.
- loss_func (Module) – Loss function.
- optimizer_name (str) – Optimizer name.
- optimizer_params (dict) – Optimizer parameters.
- lr_scheduler_name (str) – Learning rate scheduler name.
__init__(algorithm_config)
compute_val_psnr(model_output, target, psnr_func=scale_invariant_psnr)
Compute the PSNR for the current validation batch.
Parameters:
- model_output (tuple[Tensor, dict[str, Any]]) – Model output, a tuple with the predicted mean and (optionally) logvar, and the top-down data dictionary.
- target (Tensor) – Target tensor.
- psnr_func (Callable, default: scale_invariant_psnr) – PSNR function to use, by default scale_invariant_psnr.
Returns:
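For reference, here is a dependency-free sketch of one common formulation of a scale-invariant PSNR, written against plain Python lists rather than tensors. It is an assumption that this matches the exact behavior of the default scale_invariant_psnr: the sketch normalizes the target to zero mean and unit standard deviation, removes the prediction's offset, fits its scale by least squares, and uses the normalized target's range as the peak value. The name scale_invariant_psnr_sketch is hypothetical.

```python
import math
from statistics import fmean, pstdev


def scale_invariant_psnr_sketch(gt: list[float], pred: list[float]) -> float:
    """Scale-invariant PSNR between two equally sized flat sequences.

    Illustrative sketch only; not necessarily identical to the library's
    scale_invariant_psnr.
    """
    if len(gt) != len(pred) or len(gt) < 2:
        raise ValueError("gt and pred must have the same length (>= 2)")
    gt_std = pstdev(gt)
    if gt_std == 0:
        raise ValueError("gt must not be constant")
    # Normalize the target to zero mean and unit (population) standard deviation.
    gt_mean = fmean(gt)
    g = [(v - gt_mean) / gt_std for v in gt]
    # Remove the prediction's offset by subtracting its mean...
    pred_mean = fmean(pred)
    p = [v - pred_mean for v in pred]
    # ...and its scale: a is the least-squares factor minimizing sum((g - a*p)^2).
    p_energy = sum(pi * pi for pi in p)
    a = sum(gi * pi for gi, pi in zip(g, p)) / p_energy if p_energy > 0 else 0.0
    mse = fmean([(gi - a * pi) ** 2 for gi, pi in zip(g, p)])
    if mse == 0:
        return math.inf
    # Peak value: dynamic range of the normalized target.
    data_range = (max(gt) - min(gt)) / gt_std
    return 10 * math.log10(data_range**2 / mse)
```

Because offset and scale are fitted away, calling the sketch with `pred` and with `[10 * v - 7 for v in pred]` gives the same value up to floating-point error. That invariance is what makes such a metric usable when a network's output intensities live on a different scale than the target.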
configure_optimizers()
Configure optimizers and learning rate schedulers.
Returns:
- Any – Optimizer and learning rate scheduler.
forward(x)
get_reconstructed_tensor(model_outputs)
Get the reconstructed tensor from the LVAE model outputs.
Parameters:
- model_outputs (tuple[Tensor, dict[str, Any]]) – Model outputs. It is a tuple with a tensor representing the predicted mean and (optionally) logvar, and the top-down data dictionary.
Returns:
- Tensor – Reconstructed tensor, i.e., the predicted mean.
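The extraction of the predicted mean can be sketched with plain nested lists. This assumes, hypothetically, that when a logvar is predicted it is concatenated after the mean along the channel axis; with a PyTorch tensor the equivalent split would be `prediction.chunk(2, dim=1)[0]`. The function name get_reconstructed_sketch and its `predict_logvar` flag are illustrative and not part of the documented signature.

```python
from typing import Any

# Hypothetical layout for one sample: a list of channels, each a flat list of pixels.
Sample = list[list[float]]


def get_reconstructed_sketch(
    model_outputs: tuple[list[Sample], dict[str, Any]],
    predict_logvar: bool,
) -> list[Sample]:
    """Return only the predicted mean from a (prediction, td_data) tuple.

    Assumes that, when a logvar is predicted, each sample holds 2*C channels:
    C mean channels followed by C logvar channels.
    """
    prediction, _td_data = model_outputs  # top-down data is not needed here
    if not predict_logvar:
        return prediction
    reconstructed = []
    for sample in prediction:
        if len(sample) % 2:
            raise ValueError("expected an even number of channels (mean + logvar)")
        n_channels = len(sample) // 2
        reconstructed.append(sample[:n_channels])  # keep the mean channels only
    return reconstructed
```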
on_validation_epoch_end()
Validation epoch end.
predict_step(batch, batch_idx)
reduce_running_psnr()
Reduce the running PSNR statistics and reset the running PSNR.
Returns:
- Optional[float] – Running PSNR averaged over the different output channels.
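The bookkeeping behind a running, channel-averaged PSNR can be sketched as a small accumulator. This is a hypothetical illustration, not the module's internals: it assumes one PSNR value per output channel is recorded for each validation batch, that the reduction averages first per channel and then across channels, and that None is returned when nothing has been recorded since the last reset. The class RunningPSNRSketch and its methods are invented for this example.

```python
from statistics import fmean
from typing import Optional


class RunningPSNRSketch:
    """Hypothetical per-channel running-PSNR accumulator (illustrative only)."""

    def __init__(self, n_channels: int) -> None:
        self.n_channels = n_channels
        self.reset()

    def reset(self) -> None:
        # One list of batch-level PSNR values per output channel.
        self._values: list[list[float]] = [[] for _ in range(self.n_channels)]

    def update(self, batch_psnr: list[float]) -> None:
        """Record one PSNR value per channel for the current validation batch."""
        if len(batch_psnr) != self.n_channels:
            raise ValueError("expected one PSNR value per channel")
        for channel_values, value in zip(self._values, batch_psnr):
            channel_values.append(value)

    def reduce(self) -> Optional[float]:
        """Average per channel, then across channels, and reset the state."""
        per_channel = [fmean(v) for v in self._values if v]
        self.reset()
        if not per_channel:
            return None  # nothing was recorded since the last reset
        return fmean(per_channel)
```

Averaging per channel before averaging across channels keeps each output channel's weight equal regardless of its value range, which matters for multi-channel targets such as those in muSplit/denoiSplit.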
set_data_stats(data_mean, data_std)
training_step(batch, batch_idx)
Training step.
Parameters:
- batch (tuple[Tensor, Tensor]) – Input batch. It is a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
- batch_idx (Any) – Batch index.
Returns:
- Any – Loss value.
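The batch shape convention can be checked with a small helper operating on plain shape tuples (for a torch.Tensor, `tuple(x.shape)` yields such a tuple). This is a hypothetical sketch: the names describe_batch_shapes and BatchLayout are invented, and it assumes that [Z] marks an axis present only for 3D data and that input and target share the same spatial axes, as the shared [Z], Y, X notation suggests.

```python
from typing import NamedTuple


class BatchLayout(NamedTuple):
    batch_size: int          # B
    n_lateral_inputs: int    # n_LC
    n_target_channels: int   # C
    is_3d: bool              # True when the optional Z axis is present


def describe_batch_shapes(
    input_shape: tuple[int, ...], target_shape: tuple[int, ...]
) -> BatchLayout:
    """Check the (B, 1 + n_LC, [Z], Y, X) / (B, C, [Z], Y, X) convention."""
    if len(input_shape) not in (4, 5):
        raise ValueError(f"input must be 4D or 5D, got {input_shape}")
    if len(target_shape) != len(input_shape):
        raise ValueError("input and target must have the same number of axes")
    if input_shape[0] != target_shape[0]:
        raise ValueError("input and target batch sizes differ")
    if input_shape[1] < 1:
        raise ValueError("input needs at least the primary (non-lateral) channel")
    # Assumption: spatial axes ([Z], Y, X) match between input and target.
    if input_shape[2:] != target_shape[2:]:
        raise ValueError("spatial dimensions of input and target differ")
    return BatchLayout(
        batch_size=input_shape[0],
        n_lateral_inputs=input_shape[1] - 1,  # first channel is the primary input
        n_target_channels=target_shape[1],
        is_3d=len(input_shape) == 5,
    )
```

For example, a 2D batch of shape (8, 1, 64, 64) with a target of shape (8, 2, 64, 64) describes eight patches with no lateral inputs and two target channels.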
validation_step(batch, batch_idx)
Validation step.
Parameters:
- batch (tuple[Tensor, Tensor]) – Input batch. It is a tuple with the input tensor and the target tensor. The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), where C is the number of target channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
- batch_idx (Any) – Batch index.