
Eval Utils


This script provides methods to evaluate the performance of the LVAE model. It includes functions to:

- make predictions,
- quantify the performance of the model, and
- create plots to visualize the results.

PatchLocation

Encapsulates a time index (t_idx) and a spatial location.

TilingMode

Enum for the tiling mode.

add_psnr_str(ax_, psnr)

Add a PSNR string to the axes.

clean_ax(ax)

Helper function to remove ticks from axes in plots.
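A minimal sketch of what removing ticks usually involves in matplotlib. The name `clean_ax_sketch` is hypothetical, and the real helper may also hide spines or handle other cases; this only shows the core calls.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def clean_ax_sketch(ax):
    """Remove ticks and tick labels from one Axes or an array of Axes."""
    # plt.subplots returns an ndarray of Axes for grids, a single Axes otherwise.
    axes = ax.ravel() if isinstance(ax, np.ndarray) else [ax]
    for a in axes:
        a.set_xticks([])
        a.set_yticks([])


fig, axs = plt.subplots(1, 2)
clean_ax_sketch(axs)
```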

get_eval_output_dir(saveplotsdir, patch_size, mmse_count=50)

Given the root directory for saving plots, the patch size, and the MMSE count, returns the specific directory in which to save the plots.

get_fractional_change(target, prediction, max_val=None)

Get relative difference between target and prediction.

get_location_from_idx(dset, dset_input_idx, pred_h, pred_w)

For a given dataset index, returns where exactly in the dataset this prediction lies: which time frame and which spatial location (h_start, h_end, w_start, w_end). Note that this prediction also includes padded pixels, so only a subset of it is used in the final prediction.

Args:

dset: the dataset.
dset_input_idx: index into the dataset.
pred_h, pred_w: height and width of the prediction.
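The index-to-location mapping can be illustrated with plain integer arithmetic. This is a hypothetical sketch, not the library's logic: it assumes patches are enumerated frame by frame and then row by row on a regular grid with stride `grid_size`, and that each prediction extends `pad` pixels beyond its grid cell on every side. `PatchLocationSketch`, `location_from_idx_sketch`, `n_rows`, `n_cols`, and `pad` are all illustrative names.

```python
from dataclasses import dataclass


@dataclass
class PatchLocationSketch:
    """Hypothetical stand-in for PatchLocation: a frame index plus a window."""
    t_idx: int
    h_start: int
    h_end: int
    w_start: int
    w_end: int


def location_from_idx_sketch(idx, n_rows, n_cols, grid_size, pred_h, pred_w, pad):
    """Map a flat patch index to the frame and window the prediction covers.

    Assumes frame-major, then row-major enumeration on a regular grid.
    Negative starts (or ends past the image) correspond to padded pixels.
    """
    patches_per_frame = n_rows * n_cols
    t_idx, within_frame = divmod(idx, patches_per_frame)
    row, col = divmod(within_frame, n_cols)
    # The prediction window starts `pad` pixels before its grid cell.
    h_start = row * grid_size - pad
    w_start = col * grid_size - pad
    return PatchLocationSketch(t_idx, h_start, h_start + pred_h, w_start, w_start + pred_w)


loc = location_from_idx_sketch(7, n_rows=2, n_cols=3, grid_size=32, pred_h=64, pred_w=64, pad=16)
```

With 6 patches per frame, index 7 falls in frame 1, row 0, column 1.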

get_predictions(model, dset, batch_size, tile_size=None, grid_size=None, mmse_count=1, num_workers=4)

Get patch-wise predictions from a model for the entire dataset.

Parameters:

model (VAEModule): Lightning model used for prediction. Default: required.
dset (Dataset): Dataset to predict on. Default: required.
batch_size (int): Batch size to use for prediction. Default: required.
loss_type: Type of reconstruction loss used by the model, by default None. Default: required.
mmse_count (int): Number of samples to generate for each input and then to average over for MMSE estimation. Default: 1.
num_workers (int): Number of workers to use for the DataLoader. Default: 4.

Returns:

tuple[ndarray, ndarray, ndarray, ndarray, List[float]]: Tuple containing:

- predictions: Predicted images for the dataset.
- predictions_std: Standard deviation of the predicted images.
- logvar_arr: Log variance of the predicted images.
- losses: Reconstruction losses for the predictions.
- psnr: PSNR values for the predictions.
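The `mmse_count` parameter controls how many stochastic samples are averaged per input. A minimal sketch of MMSE averaging, assuming the model can be queried as a sampler that returns a different prediction on each call; `mmse_estimate` and `sample_fn` are hypothetical names, and treating the sample spread as the analogue of `predictions_std` is an assumption.

```python
import numpy as np


def mmse_estimate(sample_fn, x, mmse_count):
    """Average `mmse_count` stochastic predictions for input `x`.

    `sample_fn` is a hypothetical callable returning one stochastic
    prediction (a NumPy array) per call, e.g. one VAE decoder sample.
    """
    samples = np.stack([sample_fn(x) for _ in range(mmse_count)], axis=0)
    # The MMSE estimate is the sample mean; the per-pixel spread across
    # samples gives a rough uncertainty estimate.
    return samples.mean(axis=0), samples.std(axis=0)


rng = np.random.default_rng(0)
noisy_model = lambda x: x + rng.normal(scale=0.1, size=x.shape)
mean, std = mmse_estimate(noisy_model, np.ones((4, 4)), mmse_count=50)
```

With `mmse_count=1` the "average" is a single sample and the spread is zero, which is why larger counts are used when uncertainty estimates matter.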

get_psnr_str(tar_hsnr, pred, col_idx)

Compute PSNR between the ground truth (tar_hsnr) and the predicted image (pred).
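PSNR is defined as 10 · log10(data_range² / MSE) in decibels. The sketch below assumes the data range is taken from the target (max minus min), which is one common convention among several; the library may normalize differently. `psnr` and `psnr_str` here are illustrative names.

```python
import numpy as np


def psnr(target, prediction, data_range=None):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    target = np.asarray(target, dtype=np.float64)
    prediction = np.asarray(prediction, dtype=np.float64)
    if data_range is None:
        # Assumed convention: dynamic range of the ground truth.
        data_range = target.max() - target.min()
    mse = np.mean((target - prediction) ** 2)
    return 10.0 * np.log10(data_range**2 / mse)


def psnr_str(target, prediction):
    """Format the PSNR as a short label, e.g. for an axes annotation."""
    return f"{psnr(target, prediction):.1f}"


value = psnr(np.array([0.0, 1.0]), np.array([0.1, 0.9]))
```

Here the MSE is 0.01 and the range is 1, so the PSNR is 10 · log10(100) = 20 dB.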

get_single_file_mmse(model, dset, batch_size, tile_size=None, grid_size=None, mmse_count=1, num_workers=4)

Get patch-wise predictions from a model for a single-file dataset.

get_single_file_predictions(model, dset, batch_size, tile_size=None, grid_size=None, num_workers=4)

Get patch-wise predictions from a model for a single-file dataset.

get_zero_centered_midval(error)

Computes a midpoint value (midval) that ensures the colorbar is centered at 0.

plot_calibration(ax, calibration_stats)

Plot calibration statistics: RMV (root mean variance) against RMSE.
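`plot_calibration` receives precomputed `calibration_stats`. One common way such statistics are computed is to sort pixels by predicted standard deviation, split them into bins, and compare RMV with RMSE per bin; a well-calibrated model has RMV close to RMSE in every bin. The sketch below assumes equal-count bins and a per-pixel error array; `calibration_stats_sketch` is a hypothetical name, not the library's function.

```python
import numpy as np


def calibration_stats_sketch(pred_std, error, n_bins=10):
    """Bin pixels by predicted std and return per-bin (RMV, RMSE).

    pred_std: predicted per-pixel standard deviation.
    error: per-pixel prediction error (prediction - target).
    """
    pred_std = np.ravel(pred_std)
    error = np.ravel(error)
    order = np.argsort(pred_std)
    rmv, rmse = [], []
    # Equal-count bins over pixels sorted by predicted std (an assumption).
    for idx in np.array_split(order, n_bins):
        rmv.append(np.sqrt(np.mean(pred_std[idx] ** 2)))
        rmse.append(np.sqrt(np.mean(error[idx] ** 2)))
    return np.array(rmv), np.array(rmse)


# Synthetic, perfectly calibrated errors: each error is drawn with its own std.
rng = np.random.default_rng(0)
std = rng.uniform(0.1, 1.0, size=100_000)
err = rng.normal(0.0, std)
rmv, rmse = calibration_stats_sketch(std, err)
```

Plotting `rmse` against `rmv` together with the diagonal then shows calibration at a glance.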

plot_error(target, prediction, cmap=matplotlib.cm.coolwarm, ax=None, max_val=None)

Plot the relative difference between target and prediction. NOTE: The plot is overlaid on the prediction image (in grayscale). NOTE: The colorbar is centered at 0.
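A zero-centered colorbar can also be obtained with matplotlib's `TwoSlopeNorm`, which is a swapped-in technique here, not what `plot_error` is documented to use (it relies on a shifted colormap). The sketch overlays a plain difference, not the relative difference from `get_fractional_change`, on the grayscale prediction; `plot_error_sketch` is a hypothetical name.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import TwoSlopeNorm


def plot_error_sketch(target, prediction, ax, cmap="coolwarm"):
    """Overlay a zero-centered error map on the grayscale prediction."""
    error = prediction - target
    ax.imshow(prediction, cmap="gray")
    # TwoSlopeNorm maps vmin -> 0, vcenter -> 0.5, vmax -> 1, so zero error
    # always gets the colormap's middle color. It needs vmin < 0 < vmax,
    # hence the small guards for one-signed errors.
    norm = TwoSlopeNorm(vmin=min(error.min(), -1e-6), vcenter=0.0,
                        vmax=max(error.max(), 1e-6))
    im = ax.imshow(error, cmap=cmap, norm=norm, alpha=0.5)
    ax.figure.colorbar(im, ax=ax)
    return im


target = np.zeros((8, 8))
prediction = np.linspace(-1.0, 3.0, 64).reshape(8, 8)
fig, ax = plt.subplots()
im = plot_error_sketch(target, prediction, ax)
```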

shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap')

Adapted from https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-matplotlib

Function to offset the "center" of a colormap. Useful for data with a negative minimum and a positive maximum when you want the middle of the colormap's dynamic range to be at zero.

Input

cmap: The matplotlib colormap to be altered.
start: Offset from the lowest point in the colormap's range. Defaults to 0.0 (no lower offset). Should be between 0.0 and midpoint.
midpoint: The new center of the colormap. Defaults to 0.5 (no shift). Should be between 0.0 and 1.0. In general, this should be 1 - vmax / (vmax + abs(vmin)). For example, if your data range from -15.0 to +5.0 and you want the center of the colormap at 0.0, midpoint should be set to 1 - 5/(5 + 15), i.e. 0.75.
stop: Offset from the highest point in the colormap's range. Defaults to 1.0 (no upper offset). Should be between midpoint and 1.0.
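The midpoint formula is simply the fraction of the data range that lies below zero, since 1 - vmax / (vmax + |vmin|) = (0 - vmin) / (vmax - vmin) when vmin < 0 < vmax. A small sketch computing it; `colormap_midpoint` is a hypothetical helper, not part of this module.

```python
def colormap_midpoint(vmin, vmax):
    """Fraction of the colormap's range at which the value 0 falls.

    Implements midpoint = 1 - vmax / (vmax + abs(vmin)); requires vmin < 0 < vmax.
    """
    if not vmin < 0 < vmax:
        raise ValueError("expected vmin < 0 < vmax")
    return 1 - vmax / (vmax + abs(vmin))


mid = colormap_midpoint(-15.0, 5.0)  # the docstring's example
```

For symmetric data (vmin = -vmax) the result is 0.5, i.e. no shift.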

show_for_one(idx, val_dset, highsnr_val_dset, model, calibration_stats, mmse_count=5, patch_size=256, num_samples=2, baseline_preds=None)

Given an index, plots the input, target, and reconstructed images, plus the difference image. Note that the difference image is computed with respect to a ground-truth image obtained from the high-SNR dataset.

stitch_predictions(predictions, dset, smoothening_pixelcount=0)

Args:

smoothening_pixelcount: Number of pixels that can be interpolated.
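Because each patch prediction includes padded pixels, stitching amounts to keeping each tile's inner region and writing it at the right place in the output. The sketch below is a plain crop-and-place version for 2D tiles; it does not reproduce the interpolation controlled by `smoothening_pixelcount`, and `stitch_2d_sketch`, `starts`, and `crop` are illustrative names.

```python
import numpy as np


def stitch_2d_sketch(tiles, starts, out_shape, crop):
    """Stitch padded 2D tiles by keeping only each tile's inner region.

    tiles: list of 2D arrays (padded predictions).
    starts: (h, w) top-left corner of each padded tile in output coordinates;
            may be negative when the padding extends past the image border.
    crop: border pixels discarded on every side of each tile.
    """
    out = np.zeros(out_shape, dtype=np.float64)
    for tile, (h0, w0) in zip(tiles, starts):
        th, tw = tile.shape
        inner = tile[crop:th - crop, crop:tw - crop]
        hs, ws = h0 + crop, w0 + crop
        out[hs:hs + inner.shape[0], ws:ws + inner.shape[1]] = inner
    return out


# Four 6x6 tiles, each padded by 1 pixel, cover an 8x8 image exactly.
image = np.arange(64, dtype=float).reshape(8, 8)
padded = np.pad(image, 1)
starts = [(-1, -1), (-1, 3), (3, -1), (3, 3)]
tiles = [padded[h + 1:h + 7, w + 1:w + 7] for h, w in starts]
stitched = stitch_2d_sketch(tiles, starts, (8, 8), crop=1)
```

Discarding the border avoids the edge artifacts that padded predictions tend to have; interpolating across a few pixels instead of hard-cropping would smooth the seams further.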

stitch_predictions_general(predictions, dset)

Stitching for the dataset with multiple files of different shape.

stitch_predictions_new(predictions, dset)

Args:

smoothening_pixelcount: Number of pixels that can be interpolated.