Skip to content

progress_bar_callback

Progress bar callback.

ProgressBarCallback

Bases: TQDMProgressBar

Progress bar for training and validation steps.

Source code in src/careamics/lightning/callbacks/progress_bar_callback.py
class ProgressBarCallback(TQDMProgressBar):
    """Progress bar for training, validation and testing steps.

    Customizes Lightning's `TQDMProgressBar` so that every bar writes to
    `sys.stdout` and uses its own description and layout.
    """

    def init_train_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for training.

        The bar adapts to the terminal width (`dynamic_ncols=True`), stays on
        screen once finished (`leave=True`) and uses `smoothing=0`. The
        displayed speed is therefore the average over the whole run rather
        than a moving average.

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        bar = tqdm(
            desc="Training",
            # Even slots per process; the odd slot below is left free for the
            # validation bar (see `init_validation_tqdm`).
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )
        return bar

    def init_validation_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for validation.

        When a training bar exists, the validation bar is placed on the line
        directly below it. Otherwise it takes the training bar's slot. The
        bar is removed from the screen once finished (`leave=False`).

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        # The main progress bar doesn't exist in `trainer.validate()`. In that
        # case Lightning's `train_progress_bar` property raises a `TypeError`
        # instead of returning `None`, so treat the error as "no main bar"
        # rather than letting it abort validation.
        try:
            has_main_bar = self.train_progress_bar is not None
        except TypeError:
            has_main_bar = False
        bar = tqdm(
            desc="Validating",
            # `has_main_bar` (a bool) shifts the bar one line down when the
            # training bar occupies `2 * process_position`.
            position=(2 * self.process_position + has_main_bar),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def init_test_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for testing.

        Unlike the training and validation bars, the testing bar has a fixed
        width of 100 columns (`dynamic_ncols=False`, `ncols=100`).

        Returns
        -------
        tqdm
            A tqdm bar.
        """
        bar = tqdm(
            desc="Testing",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=False,
            ncols=100,
            file=sys.stdout,
        )
        return bar

    def get_metrics(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
        """Override this to customize the metrics displayed in the progress bar.

        Parameters
        ----------
        trainer : Trainer
            The trainer object whose `progress_bar_metrics` are displayed.
        pl_module : LightningModule
            The LightningModule object, unused.

        Returns
        -------
        dict
            A dictionary with the metrics to display in the progress bar.
        """
        pbar_metrics = trainer.progress_bar_metrics
        # Shallow copy, so changes to the returned dict do not touch the
        # mapping obtained from the trainer.
        return {**pbar_metrics}

get_metrics(trainer, pl_module)

Override this to customize the metrics displayed in the progress bar.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trainer` | `Trainer` | The trainer object. | *required* |
| `pl_module` | `LightningModule` | The LightningModule object, unused. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `dict` | A dictionary with the metrics to display in the progress bar. |

Source code in src/careamics/lightning/callbacks/progress_bar_callback.py
def get_metrics(
    self, trainer: Trainer, pl_module: LightningModule
) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
    """Collect the metrics shown next to the progress bar.

    Override this to change which metrics are displayed.

    Parameters
    ----------
    trainer : Trainer
        The trainer object whose `progress_bar_metrics` are displayed.
    pl_module : LightningModule
        The LightningModule object, unused.

    Returns
    -------
    dict
        A shallow copy of the trainer's progress-bar metrics.
    """
    metrics_to_show = dict(trainer.progress_bar_metrics)
    return metrics_to_show

init_test_tqdm()

Override this to customize the tqdm bar for testing.

Returns:

| Type | Description |
| --- | --- |
| `tqdm` | A tqdm bar. |

Source code in src/careamics/lightning/callbacks/progress_bar_callback.py
def init_test_tqdm(self) -> tqdm:
    """Build the tqdm bar used while testing.

    Override this to customize the testing bar. It writes to `sys.stdout`,
    stays visible once finished and has a fixed width of 100 columns.

    Returns
    -------
    tqdm
        A tqdm bar.
    """
    test_bar_options = dict(
        desc="Testing",
        position=2 * self.process_position,
        disable=self.is_disabled,
        leave=True,
        dynamic_ncols=False,
        ncols=100,
        file=sys.stdout,
    )
    return tqdm(**test_bar_options)

init_train_tqdm()

Override this to customize the tqdm bar for training.

Returns:

| Type | Description |
| --- | --- |
| `tqdm` | A tqdm bar. |

Source code in src/careamics/lightning/callbacks/progress_bar_callback.py
def init_train_tqdm(self) -> tqdm:
    """Build the tqdm bar used while training.

    Override this to customize the training bar. It writes to `sys.stdout`,
    adapts to the terminal width, stays visible once finished and reports
    the average speed over the whole run (`smoothing=0`).

    Returns
    -------
    tqdm
        A tqdm bar.
    """
    train_bar_options = dict(
        desc="Training",
        position=2 * self.process_position,
        disable=self.is_disabled,
        leave=True,
        dynamic_ncols=True,
        file=sys.stdout,
        smoothing=0,
    )
    return tqdm(**train_bar_options)

init_validation_tqdm()

Override this to customize the tqdm bar for validation.

Returns:

| Type | Description |
| --- | --- |
| `tqdm` | A tqdm bar. |

Source code in src/careamics/lightning/callbacks/progress_bar_callback.py
def init_validation_tqdm(self) -> tqdm:
    """Override this to customize the tqdm bar for validation.

    When a training bar exists, the validation bar is placed on the line
    directly below it. Otherwise it takes the training bar's slot. The bar
    is removed from the screen once finished (`leave=False`).

    Returns
    -------
    tqdm
        A tqdm bar.
    """
    # The main progress bar doesn't exist in `trainer.validate()`. In that
    # case Lightning's `train_progress_bar` property raises a `TypeError`
    # instead of returning `None`, so treat the error as "no main bar"
    # rather than letting it abort validation.
    try:
        has_main_bar = self.train_progress_bar is not None
    except TypeError:
        has_main_bar = False
    bar = tqdm(
        desc="Validating",
        # `has_main_bar` (a bool) shifts the bar one line down when the
        # training bar occupies `2 * process_position`.
        position=(2 * self.process_position + has_main_bar),
        disable=self.is_disabled,
        leave=False,
        dynamic_ncols=True,
        file=sys.stdout,
    )
    return bar