
prediction_outputs

Module containing functions to convert prediction outputs to desired form.

combine_batches(predictions, tiled)

combine_batches(predictions: list[Any], tiled: Literal[True]) -> tuple[list[NDArray], list[TileInformation]]
combine_batches(predictions: list[Any], tiled: Literal[False]) -> list[NDArray]
combine_batches(predictions: list[Any], tiled: Union[bool, Literal[True], Literal[False]]) -> Union[list[NDArray], tuple[list[NDArray], list[TileInformation]]]

If predictions are in batches, they will be combined.

TODO: improve description!

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| predictions | list | Predictions that are output from `Trainer.predict`. | required |
| tiled | bool | Whether the predictions are tiled. | required |

Returns:

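| Type | Description |
|------|-------------|
| (list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation) | Combined batches: a list of arrays or, when `tiled` is True, a tuple of the arrays and their tile information. |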

Source code in src/careamics/prediction_utils/prediction_outputs.py
def combine_batches(
    predictions: list[Any], tiled: bool
) -> Union[list[NDArray], tuple[list[NDArray], list[TileInformation]]]:
    """
    If predictions are in batches, they will be combined.

    # TODO improve description!

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    (list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation)
        Combined batches.
    """
    if tiled:
        return _combine_tiled_batches(predictions)
    else:
        return _combine_array_batches(predictions)
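
The private helpers `_combine_tiled_batches` and `_combine_array_batches` are not shown on this page. As a rough sketch, assuming each batch is an array whose first axis is the batch axis, and that tiled batches arrive as `(tile_batch, tile_info_list)` pairs, combining batches could look like the following. Both function names are hypothetical, and the tile information is treated as opaque objects rather than CAREamics `TileInformation` instances.

```python
from typing import Any

import numpy as np
from numpy.typing import NDArray


def combine_array_batches_sketch(batches: list[NDArray]) -> list[NDArray]:
    """Hypothetical: concatenate batches along axis 0, then split per sample."""
    # (B1, C, Y, X), (B2, C, Y, X), ... -> (B1 + B2 + ..., C, Y, X)
    stacked = np.concatenate(batches, axis=0)
    # Split back into single-sample arrays, each keeping a leading axis of size 1
    return np.split(stacked, stacked.shape[0], axis=0)


def combine_tiled_batches_sketch(
    batches: list[tuple[NDArray, list[Any]]],
) -> tuple[list[NDArray], list[Any]]:
    """Hypothetical: flatten (tile_batch, tile_info_list) pairs into two lists."""
    # Flatten the per-batch tile info lists into one list, preserving batch order
    tile_infos = [info for _, info_list in batches for info in info_list]
    tiles = combine_array_batches_sketch([tile_batch for tile_batch, _ in batches])
    return tiles, tile_infos


# Two batches of sizes 2 and 1, with axes SCYX
batches = [np.zeros((2, 1, 4, 4)), np.ones((1, 1, 4, 4))]
samples = combine_array_batches_sketch(batches)
print(len(samples), samples[0].shape)  # 3 (1, 1, 4, 4)
```

Keeping a leading axis of size 1 on each split array preserves the S axis of the SC(Z)YX layout; whether the CAREamics helpers return exactly this shape is an assumption of the sketch.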

convert_outputs(predictions, tiled)

Convert the Lightning trainer outputs to the desired form.

This method allows stitching back together tiled predictions.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| predictions | list | Predictions that are output from `Trainer.predict`. | required |
| tiled | bool | Whether the predictions are tiled. | required |

Returns:

| Type | Description |
|------|-------------|
| list of numpy.ndarray | List of arrays with the axes SC(Z)YX. A list is returned even when there is only one output. |

Source code in src/careamics/prediction_utils/prediction_outputs.py
def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:
    """
    Convert the Lightning trainer outputs to the desired form.

    This method allows stitching back together tiled predictions.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    list of numpy.ndarray or numpy.ndarray
        list of arrays with the axes SC(Z)YX. If there is only 1 output it will not
        be in a list.
    """
    if len(predictions) == 0:
        return predictions

    # this layout is to stop mypy complaining
    if tiled:
        predictions_comb = combine_batches(predictions, tiled)
        predictions_output = stitch_prediction(*predictions_comb)
    else:
        predictions_output = combine_batches(predictions, tiled)

    return predictions_output
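
In the tiled branch, `convert_outputs` delegates to `stitch_prediction`, which is not shown here. The sketch below illustrates the general idea of stitching overlapping tiles under these assumptions: each tile carries hypothetical `crop` and `target` slices (which part of the tile to keep, and where it goes in the full image), the tiles of one image are consecutive, and the last one is flagged. `TileInfoSketch` is a made-up stand-in, not the CAREamics `TileInformation` class, and only 2D (YX) tiles are handled.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np
from numpy.typing import NDArray


@dataclass
class TileInfoSketch:
    """Hypothetical stand-in for tile metadata (not the CAREamics TileInformation)."""

    image_shape: tuple[int, int]  # full YX shape of the image the tile belongs to
    crop: tuple[slice, slice]  # part of the tile to keep (drops overlap margins)
    target: tuple[slice, slice]  # where the kept part goes in the full image
    last_tile: bool  # True for the final tile of an image


def stitch_sketch(tiles: list[NDArray], infos: list[TileInfoSketch]) -> list[NDArray]:
    """Paste cropped SCYX tiles (S=1) into full images, returning one SCYX array each."""
    images: list[NDArray] = []
    current: Optional[NDArray] = None
    for tile, info in zip(tiles, infos):
        if current is None:
            # Allocate a (C, Y, X) canvas using the tile's channel count
            current = np.zeros((tile.shape[1], *info.image_shape), dtype=tile.dtype)
        # Drop the S axis, crop away the overlap, and paste into place
        current[(slice(None), *info.target)] = tile[0][(slice(None), *info.crop)]
        if info.last_tile:
            images.append(current[np.newaxis])  # restore the S axis
            current = None
    return images


# A 4x6 image cut into two 4x4 tiles that overlap by 2 columns
image = np.arange(24, dtype=float).reshape(4, 6)
tiles = [image[np.newaxis, np.newaxis, :, 0:4], image[np.newaxis, np.newaxis, :, 2:6]]
infos = [
    TileInfoSketch((4, 6), (slice(0, 4), slice(0, 3)), (slice(0, 4), slice(0, 3)), False),
    TileInfoSketch((4, 6), (slice(0, 4), slice(1, 4)), (slice(0, 4), slice(3, 6)), True),
]
stitched = stitch_sketch(tiles, infos)
print(stitched[0].shape, np.array_equal(stitched[0][0, 0], image))  # (1, 1, 4, 6) True
```

Cropping each tile before pasting discards the overlapping margins, so with consistent coordinates each output pixel comes from a single tile; averaging the overlaps would be an alternative design.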

convert_outputs_microsplit(predictions, dataset)

Convert microsplit Lightning trainer outputs using eval_utils stitching functions.

This function processes microsplit predictions that return (tile_prediction, tile_std) tuples and stitches them back together using the same logic as get_single_file_mmse.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| predictions | list of tuple[NDArray, NDArray] | Predictions from the Lightning trainer for microsplit. Each element is a `(tile_prediction, tile_std)` tuple of NumPy arrays returned by `predict_step`. | required |
| dataset | Dataset | The dataset object used for prediction, needed to select and run the stitching function. | required |

Returns:

| Type | Description |
|------|-------------|
| tuple[NDArray, NDArray] | A tuple of `(stitched_predictions, stitched_stds)` representing the full stitched predictions and standard deviations. |

Source code in src/careamics/prediction_utils/prediction_outputs.py
def convert_outputs_microsplit(
    predictions: list[tuple[NDArray, NDArray]], dataset
) -> tuple[NDArray, NDArray]:
    """
    Convert microsplit Lightning trainer outputs using eval_utils stitching functions.

    This function processes microsplit predictions that return
    (tile_prediction, tile_std) tuples and stitches them back together using the same
    logic as get_single_file_mmse.

    Parameters
    ----------
    predictions : list of tuple[NDArray, NDArray]
        Predictions from Lightning trainer for microsplit. Each element is a tuple of
        (tile_prediction, tile_std) where both are numpy arrays from predict_step.
    dataset : Dataset
        The dataset object used for prediction, needed for stitching function selection
        and stitching process.

    Returns
    -------
    tuple[NDArray, NDArray]
        A tuple of (stitched_predictions, stitched_stds) representing the full
        stitched predictions and standard deviations.
    """
    if len(predictions) == 0:
        raise ValueError("No predictions provided")

    # Separate predictions and stds from the list of tuples
    tile_predictions = [pred for pred, _ in predictions]
    tile_stds = [std for _, std in predictions]

    # Concatenate all tiles exactly like get_single_file_mmse
    tiles_arr = np.concatenate(tile_predictions, axis=0)
    tile_stds_arr = np.concatenate(tile_stds, axis=0)

    # Apply stitching using stitch_prediction_vae
    stitched_predictions = stitch_prediction_vae(tiles_arr, dataset)
    stitched_stds = stitch_prediction_vae(tile_stds_arr, dataset)

    return stitched_predictions, stitched_stds
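
The unzip-and-concatenate step of `convert_outputs_microsplit` can be tried in isolation. In this sketch, `stitch_prediction_vae` and the dataset are replaced by a hypothetical `stitch_fn` callable so that the example runs without a dataset; `zip(*predictions)` stands in for the two list comprehensions and yields the same grouping.

```python
from typing import Callable

import numpy as np
from numpy.typing import NDArray


def split_and_stitch_sketch(
    predictions: list[tuple[NDArray, NDArray]],
    stitch_fn: Callable[[NDArray], NDArray],  # hypothetical placeholder for stitching
) -> tuple[NDArray, NDArray]:
    """Unzip (tile_prediction, tile_std) batches, concatenate, and stitch each."""
    if len(predictions) == 0:
        raise ValueError("No predictions provided")
    # zip(*...) turns [(p1, s1), (p2, s2)] into (p1, p2) and (s1, s2)
    tile_preds, tile_stds = zip(*predictions)
    tiles_arr = np.concatenate(tile_preds, axis=0)  # (N_tiles, C, Y, X)
    stds_arr = np.concatenate(tile_stds, axis=0)
    # The same stitching is applied to the predictions and the standard deviations
    return stitch_fn(tiles_arr), stitch_fn(stds_arr)


# Two batches of 3 and 1 tiles; an identity "stitcher" just reveals the shapes
batches = [
    (np.zeros((3, 2, 8, 8)), np.ones((3, 2, 8, 8))),
    (np.zeros((1, 2, 8, 8)), np.ones((1, 2, 8, 8))),
]
pred, std = split_and_stitch_sketch(batches, stitch_fn=lambda tiles: tiles)
print(pred.shape, std.shape)  # (4, 2, 8, 8) (4, 2, 8, 8)
```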

convert_outputs_pn2v(predictions, tiled)

Convert the Lightning trainer outputs to the desired form.

This method allows stitching back together tiled predictions.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| predictions | list | Predictions that are output from `Trainer.predict`, with one element per batch (the total number of tiles divided by the batch size). When tiled, each element is a tuple `((prediction, mse), tile_info_list)`: the first dimension of `prediction` and `mse` is the batch size, and the length of `tile_info_list` equals the batch size. When not tiled, each element is a `(prediction, mse)` tuple. | required |
| tiled | bool | Whether the predictions are tiled. | required |

Returns:

| Type | Description |
|------|-------------|
| tuple[list[NDArray], list[NDArray]] | Tuple of `(predictions, mmse)` where each is a list of arrays with axes SC(Z)YX. |

Source code in src/careamics/prediction_utils/prediction_outputs.py
def convert_outputs_pn2v(
    predictions: list[Any], tiled: bool
) -> tuple[list[NDArray], list[NDArray]]:
    """
    Convert the Lightning trainer outputs to the desired form.

    This method allows stitching back together tiled predictions.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`. Length of list the total
        number of tiles divided by the batch size. Each element consists of a tuple of
        ((prediction, mse), tile_info_list). 1st dimension of each tensor is the bs.
        Length of tile info list is the batch size.

    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    tuple[list[NDArray], list[NDArray]]
        Tuple of (predictions, mmse) where each is a list of arrays with axes SC(Z)YX.
    """
    if len(predictions) == 0:
        return [], []
    # TODO test with multi_channel predictions
    if tiled:
        # Separate predictions and mmse, keeping tile info for each
        pred_with_tiles = [
            (pred, tile_info_list) for (pred, _), tile_info_list in predictions
        ]
        mse_with_tiles = [
            (mse, tile_info_list) for (_, mse), tile_info_list in predictions
        ]

        # Process predictions
        pred_comb = combine_batches(pred_with_tiles, tiled)
        predictions_output = stitch_prediction(*pred_comb)

        # Process mmse
        mse_comb = combine_batches(mse_with_tiles, tiled)
        mse_output = stitch_prediction(*mse_comb)

        return predictions_output, mse_output
    else:
        # Separate predictions and mmse for non-tiled case
        pred_only_tuple, mse_only_tuple = zip(*predictions, strict=False)
        pred_only_list: list[NDArray] = list(pred_only_tuple)
        mse_only_list: list[NDArray] = list(mse_only_tuple)

        predictions_output = combine_batches(pred_only_list, tiled=False)
        mse_output = combine_batches(mse_only_list, tiled=False)

        return predictions_output, mse_output
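
The nested structure of tiled PN2V outputs is easier to see with concrete values. This sketch reproduces only the separation step of `convert_outputs_pn2v`, using placeholder strings instead of `TileInformation` objects; the function name `separate_pn2v_batches_sketch` is hypothetical.

```python
from typing import Any

import numpy as np
from numpy.typing import NDArray

# One element per batch: ((prediction, mse), tile_info_list)
Batch = tuple[tuple[NDArray, NDArray], list[Any]]


def separate_pn2v_batches_sketch(
    predictions: list[Batch],
) -> tuple[list[tuple[NDArray, list[Any]]], list[tuple[NDArray, list[Any]]]]:
    """Split each batch into (prediction, tile_infos) and (mse, tile_infos) pairs."""
    pred_with_tiles = [(pred, infos) for (pred, _), infos in predictions]
    mse_with_tiles = [(mse, infos) for (_, mse), infos in predictions]
    return pred_with_tiles, mse_with_tiles


# Two batches of 2 tiles each; tile infos are placeholder strings here
batches: list[Batch] = [
    ((np.zeros((2, 1, 4, 4)), np.ones((2, 1, 4, 4))), ["t0", "t1"]),
    ((np.zeros((2, 1, 4, 4)), np.ones((2, 1, 4, 4))), ["t2", "t3"]),
]
preds, mses = separate_pn2v_batches_sketch(batches)
# Both streams carry the same tile infos, so they can be stitched identically
print(preds[1][1], mses[1][1])  # ['t2', 't3'] ['t2', 't3']
```

Pairing each array with its own copy of the tile info list lets both streams go through the same combine-and-stitch path as ordinary tiled predictions.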