Skip to content

tbplot_widget

A widget displaying losses and a button to open TensorBoard in the browser.

TBPlotWidget #

Bases: Container

A widget displaying losses and a button to open TensorBoard in the browser.

Parameters:

Name Type Description Default
min_width int or None

Minimum width of the widget.

None
min_height int or None

Minimum height of the widget.

None
max_width int or None

Maximum width of the widget.

None
max_height int or None

Maximum height of the widget.

None
train_signal TrainingSignal or None

Signal containing training parameters.

None
Source code in src/careamics_napari/widgets/tbplot_widget.py
class TBPlotWidget(Container):
    """A widget displaying losses and a button to open TensorBoard in the browser.

    Parameters
    ----------
    min_width : int or None, default=None
        Minimum width of the widget.
    min_height : int or None, default=None
        Minimum height of the widget.
    max_width : int or None, default=None
        Maximum width of the widget.
    max_height : int or None, default=None
        Maximum height of the widget.
    train_signal : TrainingSignal or None, default=None
        Signal containing training parameters.
    """

    # TODO what is this method used for?
    def __setitem__(self: Self, key: Any, value: Any) -> None:
        """Ignore set item.

        Parameters
        ----------
        key : Any
            Ignored.
        value : Any
            Ignored.
        """
        pass

    def __init__(
        self: Self,
        min_width: Optional[int] = None,
        min_height: Optional[int] = None,
        max_width: Optional[int] = None,
        max_height: Optional[int] = None,
        train_signal: Optional[TrainingSignal] = None,
    ):
        """Initialize the widget.

        Parameters
        ----------
        min_width : int or None, default=None
            Minimum width of the widget.
        min_height : int or None, default=None
            Minimum height of the widget.
        max_width : int or None, default=None
            Maximum width of the widget.
        max_height : int or None, default=None
            Maximum height of the widget.
        train_signal : TrainingSignal or None, default=None
            Signal containing training parameters.
        """
        super().__init__()

        self.train_signal = train_signal

        if max_width:
            self.native.setMaximumWidth(max_width)
        if max_height:
            self.native.setMaximumHeight(max_height)
        if min_width:
            self.native.setMinimumWidth(min_width)
        if min_height:
            self.native.setMinimumHeight(min_height)

        self.graphics_widget = pg.GraphicsLayoutWidget()
        self.graphics_widget.setBackground(None)
        self.native.layout().addWidget(self.graphics_widget)

        # plot widget
        self.plot = self.graphics_widget.addPlot()
        self.plot.setLabel("bottom", "epoch")
        self.plot.setLabel("left", "loss")
        self.plot.addLegend(offset=(125, -50))

        # tensorboard button
        tb_button = QPushButton("Open in TensorBoard")
        tb_button.setToolTip("Open TensorBoard in your browser")
        tb_button.setIcon(QIcon(QPixmap(ICON_TF)))
        tb_button.setLayoutDirection(Qt.LeftToRight)
        tb_button.setIconSize(QSize(32, 29))
        tb_button.setCursor(QCursor(Qt.CursorShape.PointingHandCursor))
        tb_button.clicked.connect(self.open_tb)

        # add to layout on the bottom left
        button_widget = QWidget()
        button_widget.setLayout(QHBoxLayout())
        button_widget.layout().addWidget(tb_button)
        button_widget.layout().addWidget(QLabel(""))
        self.native.layout().addWidget(button_widget)

        # set empty references
        self.epochs: list[int] = []
        self.train_loss: list[float] = []
        self.val_loss: list[float] = []
        self.url: Optional[str] = None
        self.tb = None

    def stop_tb(self: Self) -> None:
        """Stop the TensorBoard process.

        Currently not implemented.
        """
        # haven't found any good way to stop the tb process, there's currently no API
        # for it
        pass

    def open_tb(self: Self) -> None:
        """Open TensorBoard in the browser."""
        if self.tb is None and self.train_signal is not None:
            from tensorboard import program

            self.tb = program.TensorBoard()

            path = str(self.train_signal.work_dir / "logs" / "lightning_logs")
            self.tb.configure(argv=[None, "--logdir", path])  # type: ignore
            self.url = self.tb.launch()  # type: ignore

            if self.url is not None:
                webbrowser.open(self.url)
        else:
            if self.url is not None:
                webbrowser.open(self.url)

    def update_plot(self: Self, epoch: int, train_loss: float, val_loss: float) -> None:
        """Update the plot with new data.

        Parameters
        ----------
        epoch : int
            Epoch number.
        train_loss : float
            Training loss.
        val_loss : float
            Validation loss.
        """
        # clear the plot
        self.plot.clear()

        # add the new points
        self.epochs.append(epoch)
        self.train_loss.append(train_loss)
        self.val_loss.append(val_loss)

        # replot
        self.plot.plot(
            self.epochs,
            self.train_loss,
            pen=pg.mkPen(color=(204, 221, 255)),
            symbol="o",
            symbolSize=2,
            name="Train",
        )
        self.plot.plot(
            self.epochs,
            self.val_loss,
            pen=pg.mkPen(color=(244, 173, 173)),
            symbol="o",
            symbolSize=2,
            name="Val",
        )

    def clear_plot(self: Self) -> None:
        """Clear the plot."""
        self.plot.clear()
        self.epochs = []
        self.train_loss = []
        self.val_loss = []

__init__(min_width=None, min_height=None, max_width=None, max_height=None, train_signal=None) #

Initialize the widget.

Parameters:

Name Type Description Default
min_width int or None

Minimum width of the widget.

None
min_height int or None

Minimum height of the widget.

None
max_width int or None

Maximum width of the widget.

None
max_height int or None

Maximum height of the widget.

None
train_signal TrainingSignal or None

Signal containing training parameters.

None
Source code in src/careamics_napari/widgets/tbplot_widget.py
def __init__(
    self: Self,
    min_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_width: Optional[int] = None,
    max_height: Optional[int] = None,
    train_signal: Optional[TrainingSignal] = None,
):
    """Initialize the widget.

    Parameters
    ----------
    min_width : int or None, default=None
        Minimum width of the widget.
    min_height : int or None, default=None
        Minimum height of the widget.
    max_width : int or None, default=None
        Maximum width of the widget.
    max_height : int or None, default=None
        Maximum height of the widget.
    train_signal : TrainingSignal or None, default=None
        Signal containing training parameters.
    """
    super().__init__()

    self.train_signal = train_signal

    if max_width:
        self.native.setMaximumWidth(max_width)
    if max_height:
        self.native.setMaximumHeight(max_height)
    if min_width:
        self.native.setMinimumWidth(min_width)
    if min_height:
        self.native.setMinimumHeight(min_height)

    self.graphics_widget = pg.GraphicsLayoutWidget()
    self.graphics_widget.setBackground(None)
    self.native.layout().addWidget(self.graphics_widget)

    # plot widget
    self.plot = self.graphics_widget.addPlot()
    self.plot.setLabel("bottom", "epoch")
    self.plot.setLabel("left", "loss")
    self.plot.addLegend(offset=(125, -50))

    # tensorboard button
    tb_button = QPushButton("Open in TensorBoard")
    tb_button.setToolTip("Open TensorBoard in your browser")
    tb_button.setIcon(QIcon(QPixmap(ICON_TF)))
    tb_button.setLayoutDirection(Qt.LeftToRight)
    tb_button.setIconSize(QSize(32, 29))
    tb_button.setCursor(QCursor(Qt.CursorShape.PointingHandCursor))
    tb_button.clicked.connect(self.open_tb)

    # add to layout on the bottom left
    button_widget = QWidget()
    button_widget.setLayout(QHBoxLayout())
    button_widget.layout().addWidget(tb_button)
    button_widget.layout().addWidget(QLabel(""))
    self.native.layout().addWidget(button_widget)

    # set empty references
    self.epochs: list[int] = []
    self.train_loss: list[float] = []
    self.val_loss: list[float] = []
    self.url: Optional[str] = None
    self.tb = None

__setitem__(key, value) #

Ignore set item.

Parameters:

Name Type Description Default
key Any

Ignored.

required
value Any

Ignored.

required
Source code in src/careamics_napari/widgets/tbplot_widget.py
def __setitem__(self: Self, key: Any, value: Any) -> None:
    """Ignore set item.

    Parameters
    ----------
    key : Any
        Ignored.
    value : Any
        Ignored.
    """
    pass

clear_plot() #

Clear the plot.

Source code in src/careamics_napari/widgets/tbplot_widget.py
def clear_plot(self: Self) -> None:
    """Clear the plot."""
    self.plot.clear()
    self.epochs = []
    self.train_loss = []
    self.val_loss = []

open_tb() #

Open TensorBoard in the browser.

Source code in src/careamics_napari/widgets/tbplot_widget.py
def open_tb(self: Self) -> None:
    """Open TensorBoard in the browser."""
    if self.tb is None and self.train_signal is not None:
        from tensorboard import program

        self.tb = program.TensorBoard()

        path = str(self.train_signal.work_dir / "logs" / "lightning_logs")
        self.tb.configure(argv=[None, "--logdir", path])  # type: ignore
        self.url = self.tb.launch()  # type: ignore

        if self.url is not None:
            webbrowser.open(self.url)
    else:
        if self.url is not None:
            webbrowser.open(self.url)

stop_tb() #

Stop the TensorBoard process.

Currently not implemented.

Source code in src/careamics_napari/widgets/tbplot_widget.py
def stop_tb(self: Self) -> None:
    """Stop the TensorBoard process.

    Currently not implemented.
    """
    # haven't found any good way to stop the tb process, there's currently no API
    # for it
    pass

update_plot(epoch, train_loss, val_loss) #

Update the plot with new data.

Parameters:

Name Type Description Default
epoch int

Epoch number.

required
train_loss float

Training loss.

required
val_loss float

Validation loss.

required
Source code in src/careamics_napari/widgets/tbplot_widget.py
def update_plot(self: Self, epoch: int, train_loss: float, val_loss: float) -> None:
    """Update the plot with new data.

    Parameters
    ----------
    epoch : int
        Epoch number.
    train_loss : float
        Training loss.
    val_loss : float
        Validation loss.
    """
    # clear the plot
    self.plot.clear()

    # add the new points
    self.epochs.append(epoch)
    self.train_loss.append(train_loss)
    self.val_loss.append(val_loss)

    # replot
    self.plot.plot(
        self.epochs,
        self.train_loss,
        pen=pg.mkPen(color=(204, 221, 255)),
        symbol="o",
        symbolSize=2,
        name="Train",
    )
    self.plot.plot(
        self.epochs,
        self.val_loss,
        pen=pg.mkPen(color=(244, 173, 173)),
        symbol="o",
        symbolSize=2,
        name="Val",
    )