Eval Utils
This script provides methods to evaluate the performance of the LVAE model. It includes functions to:

- make predictions,
- quantify the performance of the model,
- create plots to visualize the results.
PatchLocation
Encapsulates t_idx and spatial location.
TilingMode
Enum for the tiling mode.
add_psnr_str(ax_, psnr)
Add a PSNR string to the axes.
clean_ax(ax)
Helper function to remove ticks from axes in plots.
get_eval_output_dir(saveplotsdir, patch_size, mmse_count=50)
Given the path to a root directory for saving plots, the patch size, and the MMSE count, return the specific directory in which to save the plots.
get_fractional_change(target, prediction, max_val=None)
Get the relative difference between target and prediction.
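A minimal sketch of the relative-difference computation, assuming the function normalizes the difference by the target's maximum (or an explicit max_val); the function name and normalization here are illustrative, not the library's exact implementation:

```python
import numpy as np

def fractional_change(target, prediction, max_val=None):
    # Relative difference, normalized by the target's max (or a given max_val).
    # Hypothetical re-implementation for illustration only.
    if max_val is None:
        max_val = target.max()
    return (target - prediction) / max_val

target = np.array([10.0, 20.0, 40.0])
prediction = np.array([9.0, 22.0, 40.0])
change = fractional_change(target, prediction)  # normalized by max_val = 40
```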
get_location_from_idx(dset, dset_input_idx, pred_h, pred_w)
For a given index of the dataset, return exactly where in the dataset this prediction lies: which time frame and which spatial location (h_start, h_end, w_start, w_end). Note that the prediction also contains padded pixels, so only a subset of it is used in the final prediction. Args: dset, dset_input_idx, pred_h, pred_w.
get_predictions(model, dset, batch_size, tile_size=None, grid_size=None, mmse_count=1, num_workers=4)
Get patch-wise predictions from a model for the entire dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | VAEModule | Lightning model used for prediction. | required |
| dset | Dataset | Dataset to predict on. | required |
| batch_size | int | Batch size to use for prediction. | required |
| loss_type | | Type of reconstruction loss used by the model. | required |
| mmse_count | int | Number of samples to generate for each input and average over for the MMSE estimate. | 1 |
| num_workers | int | Number of workers to use for the DataLoader. | 4 |
Returns:
| Type | Description |
|---|---|
| tuple[ndarray, ndarray, ndarray, ndarray, List[float]] | Tuple containing:<br>- predictions: Predicted images for the dataset.<br>- predictions_std: Standard deviation of the predicted images.<br>- logvar_arr: Log variance of the predicted images.<br>- losses: Reconstruction losses for the predictions.<br>- psnr: PSNR values for the predictions. |
get_psnr_str(tar_hsnr, pred, col_idx)
Compute PSNR between the ground truth (tar_hsnr) and the predicted image (pred).
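For reference, a self-contained PSNR computation, assuming the peak signal is taken as the ground truth's dynamic range (the exact normalization used by get_psnr_str may differ):

```python
import numpy as np

def psnr(gt, pred):
    # PSNR in dB, using the ground truth's dynamic range as the peak value.
    # Illustrative sketch; the library's normalization may differ.
    data_range = gt.max() - gt.min()
    mse = np.mean((gt - pred) ** 2)
    return 20 * np.log10(data_range / np.sqrt(mse))

gt = np.array([0.0, 0.25, 0.5, 1.0])
pred = gt + 0.1      # uniform offset -> MSE = 0.01
value = psnr(gt, pred)  # 20 * log10(1.0 / 0.1) = 20 dB
```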
get_single_file_mmse(model, dset, batch_size, tile_size=None, grid_size=None, mmse_count=1, num_workers=4)
Get patch-wise MMSE predictions from a model for a single-file dataset.
get_single_file_predictions(model, dset, batch_size, tile_size=None, grid_size=None, num_workers=4)
Get patch-wise predictions from a model for a single file dataset.
get_zero_centered_midval(error)
Compute a midpoint value such that the colorbar is centered at 0.
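A sketch of the zero-centering logic, assuming the midpoint is computed with the same formula documented for shiftedColorMap below (this is an assumed implementation, not the library's code):

```python
import numpy as np

def zero_centered_midval(error):
    # Pick the colormap midpoint in [0, 1] that maps the value 0
    # to the colormap's center (assumed implementation).
    vmax, vmin = error.max(), error.min()
    return 1 - vmax / (vmax + abs(vmin))

error = np.array([-15.0, 5.0])
midval = zero_centered_midval(error)  # 1 - 5 / (5 + 15) = 0.75
```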
plot_calibration(ax, calibration_stats)
Plot calibration statistics (RMV vs. RMSE).
plot_error(target, prediction, cmap=matplotlib.cm.coolwarm, ax=None, max_val=None)
Plot the relative difference between target and prediction. NOTE: The plot is overlaid on the prediction image (in grayscale). NOTE: The colorbar is centered at 0.
shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap')
Adapted from https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-matplotlib
Function to offset the "center" of a colormap. Useful for data with a negative min and a positive max when you want the middle of the colormap's dynamic range to be at zero.
Input
cmap : The matplotlib colormap to be altered
start : Offset from lowest point in the colormap's range.
Defaults to 0.0 (no lower offset). Should be between
0.0 and midpoint.
midpoint : The new center of the colormap. Defaults to
0.5 (no shift). Should be between 0.0 and 1.0. In
general, this should be 1 - vmax / (vmax + abs(vmin))
For example, if your data range from -15.0 to +5.0 and
you want the center of the colormap at 0.0, midpoint
should be set to 1 - 5/(5 + 15), or 0.75.
stop : Offset from highest point in the colormap's range.
Defaults to 1.0 (no upper offset). Should be between
midpoint and 1.0.
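The core of the remapping shiftedColorMap performs (following the linked Stack Overflow answer) can be sketched as resampling colormap positions in [0, 1] so that midpoint lands at the center of the output map:

```python
import numpy as np

# Resample colormap positions so that `midpoint` becomes the new center.
start, midpoint, stop = 0.0, 0.75, 1.0
reg_index = np.linspace(start, stop, 257)
shift_index = np.hstack([
    np.linspace(0.0, midpoint, 128, endpoint=False),
    np.linspace(midpoint, 1.0, 129),
])
# The color sampled at reg_index[i] is re-registered at shift_index[i],
# so the original center color now sits at position 0.75.
```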
show_for_one(idx, val_dset, highsnr_val_dset, model, calibration_stats, mmse_count=5, patch_size=256, num_samples=2, baseline_preds=None)
Given an index, plot the input, target, and reconstructed images along with the difference image. Note that the difference image is computed with respect to a ground truth image obtained from the high-SNR dataset.
stitch_predictions(predictions, dset, smoothening_pixelcount=0)
Args: smoothening_pixelcount: number of pixels that can be interpolated.
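The stitching idea can be sketched minimally: each patch prediction is written back into the full frame at its recorded (h_start, w_start) location. The locations and sizes below are made up for illustration; the real function also handles padding and optional boundary smoothing:

```python
import numpy as np

patch_size, H, W = 2, 4, 4
full = np.zeros((H, W))

# Hypothetical patch predictions keyed by their (h_start, w_start) offsets.
patches = {
    (0, 0): np.ones((patch_size, patch_size)),
    (2, 2): np.full((patch_size, patch_size), 2.0),
}

for (h0, w0), p in patches.items():
    full[h0:h0 + patch_size, w0:w0 + patch_size] = p
```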
stitch_predictions_general(predictions, dset)
Stitching for the dataset with multiple files of different shape.
stitch_predictions_new(predictions, dset)
Args: smoothening_pixelcount: number of pixels that can be interpolated.