Skip to content

plotting

Plotting utilities.

plot_noise_model_probability_distribution(noise_model, signalBinIndex, histogram, channel=None, number_of_bins=100)

Plot probability distribution P(x|s) for a certain ground truth signal.

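The histogram-based noise model (left panel, with the selected signal bin highlighted) and the GMM-based noise model's prediction for that signal (right panel) are displayed side by side for comparison.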

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| noise_model | GaussianMixtureNoiseModel | Trained GaussianMixtureNoiseModel. | required |
| signalBinIndex | int | Index of the signal bin. Values range from 0 to `number_of_bins - 1`. | required |
| histogram | NDArray | Histogram-based noise model. | required |
| channel | Optional[str] | Channel name used for plotting. Default is None. | None |
| number_of_bins | int | Number of bins used to discretise the signal range (should match the bin count of `histogram`). Default is 100. | 100 |
Source code in src/careamics/utils/plotting.py
def plot_noise_model_probability_distribution(
    noise_model: GaussianMixtureNoiseModel,
    signalBinIndex: int,
    histogram: NDArray,
    channel: Optional[str] = None,
    number_of_bins: int = 100,
) -> None:
    """Plot probability distribution P(x|s) for a certain ground truth signal.

    The left panel shows the histogram-based noise model as an image, with
    the row of the selected signal bin highlighted. The right panel shows the
    likelihood P(x|s) predicted by the GMM-based noise model at the centre of
    that signal bin, so the two noise models can be compared.

    Parameters
    ----------
    noise_model : GaussianMixtureNoiseModel
        Trained GaussianMixtureNoiseModel.
    signalBinIndex : int
        Index of signal bin. Values go from 0 to ``number_of_bins - 1``.
    histogram : NDArray
        Histogram based noise model, displayed as an image whose rows are
        signal bins. Its bin count should match ``number_of_bins`` so that the
        highlighted row corresponds to the queried signal.
    channel : Optional[str], optional
        Channel name used for plotting. Default is None.
    number_of_bins : int, optional
        Number of bins used to discretise the noise model's signal range
        ``[min_signal, max_signal)``. Default is 100.
    """
    min_signal = noise_model.min_signal.item()
    max_signal = noise_model.max_signal.item()
    bin_size = (max_signal - min_signal) / number_of_bins
    # NOTE(review): assumes the noise model's parameters live on the same
    # device as `min_signal` -- confirm against GaussianMixtureNoiseModel.
    device = noise_model.min_signal.device

    # Centre of the requested signal bin, as a plain float for labelling.
    query_signal = min_signal + (signalBinIndex + 0.5) * bin_size

    # Centres of all observation bins. Built from an integer range so exactly
    # `number_of_bins` points are produced; a float-step
    # `torch.arange(min, max, bin_size)` can gain an extra point from rounding.
    query_observations = (
        min_signal + (torch.arange(number_of_bins, device=device) + 0.5) * bin_size
    )

    # Detach and move to CPU before converting: `.numpy()` raises on tensors
    # that require grad or that live on a GPU.
    likelihoods = (
        noise_model.likelihood(
            observations=query_observations,
            signals=torch.tensor(query_signal, device=device),
        )
        .detach()
        .cpu()
        .numpy()
    )
    observations = query_observations.cpu().numpy()

    plt.figure(figsize=(12, 5))
    if channel:
        plt.suptitle(f"Noise model for channel {channel}")
    else:
        plt.suptitle("Noise model")

    plt.subplot(1, 2, 1)
    plt.xlabel("Observation Bin")
    plt.ylabel("Signal Bin")
    # Fourth root compresses the dynamic range so low-count bins stay visible.
    plt.imshow(histogram**0.25, cmap="gray")
    # imshow centres pixel row i at y=i, so this line covers the selected row.
    plt.axhline(y=signalBinIndex, linewidth=5, color="blue", alpha=0.5)

    plt.subplot(1, 2, 2)
    plt.plot(
        observations,
        likelihoods,
        label=f"GMM: signal = {query_signal:.2f}",
        marker=".",
        color="red",
        linewidth=2,
    )
    plt.xlabel(f"Observations (x) for signal s = {query_signal:.2f}")
    plt.ylabel("Probability Density")
    plt.title(f"Probability Distribution P(x|s) at signal = {query_signal:.2f}")
    plt.legend()