Skip to content

train_lvae

This script is meant to load data, initialize the model, and provide the logic for training it.

get_mean_std_dict_for_model(config, train_dset)

Returns copies of the training dataset's mean and standard-deviation dictionaries, which are subsequently passed to the model.

Source code in src/careamics/lvae_training/train_utils.py
def get_mean_std_dict_for_model(config, train_dset):
    """
    Return deep copies of the training dataset's mean and std dictionaries.

    The returned pair is subsequently passed to the model.

    Args:
        config: Accepted for interface compatibility; not read by this function.
        train_dset: Object exposing ``get_mean_std()``, which must return a
            two-item ``(mean_dict, std_dict)`` pair.

    Returns:
        A ``(mean_dict, std_dict)`` tuple of deep copies.
    """
    mean_stats, std_stats = train_dset.get_mean_std()
    # Deep-copy so mutating the returned dicts cannot change the dataset's own copies.
    return tuple(deepcopy(stat) for stat in (mean_stats, std_stats))

get_new_model_version(model_dir)

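A model can have multiple runs, and each run is stored under its own integer-named version subdirectory of `model_dir`. Returns the next unused version number as a string ("0" if there are no runs yet).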

Source code in src/careamics/lvae_training/train_utils.py
def get_new_model_version(model_dir: str) -> str:
    """
    A model will have multiple runs. Each run will have a different version.

    Every entry in ``model_dir`` is expected to be named by an integer version.
    The new version is one more than the largest existing one, or ``"0"`` if
    ``model_dir`` is empty.

    Args:
        model_dir: Directory holding one entry per existing run/version.

    Returns:
        The new version number as a string.

    Raises:
        SystemExit: With status 1, after printing a message, if any entry in
            ``model_dir`` is not an integer name.
    """
    import sys

    versions = []
    for version_dir in os.listdir(model_dir):
        try:
            versions.append(int(version_dir))
        except ValueError:
            # Only a non-integer name is an expected failure here; a bare
            # `except` would also swallow KeyboardInterrupt and unrelated errors.
            print(
                f"Invalid subdirectory:{model_dir}/{version_dir}. Only integer versions are allowed"
            )
            # Use sys.exit (site's `exit` may be absent) with a non-zero status so
            # callers and shells see the failure instead of a successful exit.
            sys.exit(1)
    # default=-1 makes an empty directory yield version "0".
    return str(max(versions, default=-1) + 1)