UNet

UNet architecture implementation.

UNet

Bases: Module

UNet model.

Adapted for PyTorch from: https://github.com/juglab/n2v/blob/main/n2v/nets/unet_blocks.py.

Parameters:

  • conv_dims (int) –

    Number of dimensions of the convolution layers (2 or 3).

  • num_classes (int, default: 1 ) –

    Number of classes to predict, by default 1.

  • in_channels (int, default: 1 ) –

    Number of input channels, by default 1.

  • depth (int, default: 3 ) –

    Number of downsampling steps, by default 3.

  • num_channels_init (int, default: 64 ) –

    Number of filters in the first convolution layer, by default 64.

  • use_batch_norm (bool, default: True ) –

    Whether to use batch normalization, by default True.

  • dropout (float, default: 0.0 ) –

    Dropout probability, by default 0.0.

  • pool_kernel (int, default: 2 ) –

    Kernel size of the pooling layers, by default 2.

  • residual (bool, default: False ) –

    Whether to add a residual connection from the input to the output, by default False.

  • final_activation (Optional[Callable], default: SupportedActivation.NONE ) –

    Activation function applied to the output of the last layer, by default SupportedActivation.NONE (no activation).

  • n2v2 (bool, default: False ) –

    Whether to use the N2V2 architecture, by default False.

  • independent_channels (bool, default: True ) –

    Whether to train parallel independent networks for each channel, by default True.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments, unused.
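The depth, num_channels_init and pool_kernel parameters together determine the resolution and width of every level of the network. The sketch below computes that layout for a given patch. It assumes (these details are not taken from this implementation) that the number of filters doubles after each downsampling, as in the common UNet design, and that each pooling divides every spatial axis by pool_kernel. Under those assumptions, each spatial size must be divisible by pool_kernel**depth for the decoder to restore the input shape. `level_layout` is a hypothetical helper, not part of this module.

```python
from typing import List, Tuple


def level_layout(
    patch_shape: Tuple[int, ...],
    depth: int = 3,
    num_channels_init: int = 64,
    pool_kernel: int = 2,
) -> List[Tuple[int, Tuple[int, ...]]]:
    """Return (channels, spatial shape) at each resolution level.

    ASSUMPTIONS (not read from the UNet source): channels double after every
    downsampling, and each pooling divides every spatial axis by pool_kernel.
    """
    divisor = pool_kernel**depth
    if any(size % divisor for size in patch_shape):
        raise ValueError(
            f"every spatial size must be divisible by pool_kernel**depth = {divisor}"
        )
    layout = []
    shape = tuple(patch_shape)
    # depth downsamplings give depth + 1 resolution levels (the last is the bottleneck)
    for level in range(depth + 1):
        layout.append((num_channels_init * 2**level, shape))
        shape = tuple(size // pool_kernel for size in shape)
    return layout


for channels, shape in level_layout((64, 64)):
    print(channels, shape)
# 64 (64, 64)
# 128 (32, 32)
# 256 (16, 16)
# 512 (8, 8)
```

With the defaults, a 64x64 patch reaches 8x8 at the coarsest level, while a patch with a 60-pixel axis would be rejected because 60 is not a multiple of 8.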

__init__(conv_dims, num_classes=1, in_channels=1, depth=3, num_channels_init=64, use_batch_norm=True, dropout=0.0, pool_kernel=2, residual=False, final_activation=SupportedActivation.NONE, n2v2=False, independent_channels=True, **kwargs)

Constructor.

Parameters:

  • conv_dims (int) –

    Number of dimensions of the convolution layers (2 or 3).

  • num_classes (int, default: 1 ) –

    Number of classes to predict, by default 1.

  • in_channels (int, default: 1 ) –

    Number of input channels, by default 1.

  • depth (int, default: 3 ) –

    Number of downsampling steps, by default 3.

  • num_channels_init (int, default: 64 ) –

    Number of filters in the first convolution layer, by default 64.

  • use_batch_norm (bool, default: True ) –

    Whether to use batch normalization, by default True.

  • dropout (float, default: 0.0 ) –

    Dropout probability, by default 0.0.

  • pool_kernel (int, default: 2 ) –

    Kernel size of the pooling layers, by default 2.

  • residual (bool, default: False ) –

    Whether to add a residual connection from the input to the output, by default False.

  • final_activation (Optional[Callable], default: SupportedActivation.NONE ) –

    Activation function applied to the output of the last layer, by default SupportedActivation.NONE (no activation).

  • n2v2 (bool, default: False ) –

    Whether to use the N2V2 architecture, by default False.

  • independent_channels (bool, default: True ) –

    Whether to train parallel independent networks for each channel, by default True.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments, unused.
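The independent_channels and residual flags can be illustrated with a framework-free sketch. It assumes that independent_channels=True means each input channel is processed by its own sub-network with no cross-channel mixing (matching the description "parallel independent networks for each channel"). It also assumes that the residual connection is an element-wise sum of input and output, which only makes sense when num_classes equals in_channels. `forward_independent`, `double` and `negate` are hypothetical stand-ins, not part of this module. In PyTorch, grouped convolutions (`groups=in_channels` on `nn.Conv2d`/`nn.Conv3d`) are one way to get this channel separation, but this sketch does not claim that is how the class implements it.

```python
from typing import Callable, List, Sequence

Channel = List[float]  # one flattened image channel (toy stand-in for a tensor)
ChannelNet = Callable[[Channel], Channel]  # hypothetical per-channel network


def forward_independent(
    x: Sequence[Channel], nets: Sequence[ChannelNet], residual: bool = False
) -> List[Channel]:
    """Process channel i with nets[i] only, so no information crosses channels.

    With residual=True the input channel is added element-wise to the output
    (ASSUMPTION: residual means a sum, so input and output channel counts match).
    """
    if len(x) != len(nets):
        raise ValueError("need exactly one network per input channel")
    out = []
    for channel, net in zip(x, nets):
        y = net(channel)
        if residual:
            y = [a + b for a, b in zip(y, channel)]
        out.append(y)
    return out


def double(c: Channel) -> Channel:
    return [2 * v for v in c]


def negate(c: Channel) -> Channel:
    return [-v for v in c]


x = [[1.0, 2.0], [3.0, 4.0]]
print(forward_independent(x, [double, negate]))                 # [[2.0, 4.0], [-3.0, -4.0]]
print(forward_independent(x, [double, negate], residual=True))  # [[3.0, 6.0], [0.0, 0.0]]
```

Keeping one network per channel means a channel's output cannot depend on the other channels. Setting independent_channels=False would instead let the network share information across channels.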

forward(x)

Forward pass.

Parameters:

  • x (torch.Tensor) –

    Input tensor.

Returns:

  • Tensor

    Output of the model.
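For forward, the main contract is between the input and output shapes. The sketch below spells it out under stated assumptions: x follows the usual PyTorch convolution layout (N, C, *spatial) with conv_dims spatial axes, padding keeps the spatial size unchanged, and only the channel axis changes from in_channels to num_classes. `expected_output_shape` is a hypothetical helper for illustration, not part of this module.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def expected_output_shape(
    input_shape: Shape, conv_dims: int, in_channels: int = 1, num_classes: int = 1
) -> Shape:
    """Shape contract of forward(x), ASSUMING spatial dims are preserved.

    x is laid out as PyTorch convolutions expect: (N, C, *spatial) with
    len(spatial) == conv_dims, e.g. (N, C, Y, X) in 2D or (N, C, Z, Y, X) in 3D.
    """
    if conv_dims not in (2, 3):
        raise ValueError("conv_dims must be 2 or 3")
    if len(input_shape) != conv_dims + 2:
        raise ValueError(f"expected a {conv_dims + 2}D tensor, got {len(input_shape)}D")
    batch, channels, *spatial = input_shape
    if channels != in_channels:
        raise ValueError(f"expected {in_channels} input channels, got {channels}")
    # Only the channel axis changes: in_channels -> num_classes.
    return (batch, num_classes, *spatial)


print(expected_output_shape((4, 1, 64, 64), conv_dims=2))  # (4, 1, 64, 64)
print(expected_output_shape((2, 3, 16, 64, 64), conv_dims=3, in_channels=3, num_classes=3))
# (2, 3, 16, 64, 64)
```

Checking the shape before calling the model gives a clearer error message than a failure deep inside a convolution layer.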