3D UNet

Understanding and Implementing 3D UNet for Medical Image Segmentation in PyTorch

Introduction to 3D UNet

3D UNet is a powerful convolutional neural network architecture widely utilized for image segmentation tasks, particularly in medical imaging applications such as MRI and CBCT scans. It has proven to be one of the most effective methods for delineating structures within volumetric data.

This article provides an in-depth introduction to the architecture of 3D UNet and presents a PyTorch implementation along with detailed explanations of each part of the code.

Prerequisites

Before diving into the implementation, make sure NiBabel, NumPy, and PyTorch are installed. The code in this article uses the following imports:

```python
import nibabel as nib
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.pad in the Up block
from torch.utils.data import Dataset, DataLoader
```

GPU Usage: Given the computational intensity of 3D UNet, it is highly recommended to leverage GPU acceleration for training:

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```
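To put the computational cost in numbers, the sketch below estimates the memory needed to store a single activation tensor. The helper name feature_map_bytes, float32 storage, and the 128 × 128 × 128 patch size are illustrative assumptions, not values taken from this article.

```python
# Rough activation-memory estimate for one feature map of shape (N, C, D, H, W).
# Assumptions: float32 storage (4 bytes per element) and an illustrative patch size.

def feature_map_bytes(channels, depth, height, width, batch=1, bytes_per_elem=4):
    """Bytes needed to store one activation tensor."""
    return batch * channels * depth * height * width * bytes_per_elem


MiB = 2 ** 20

# Output of the first DoubleConv in UNet3D (64 channels) on a 128 x 128 x 128 patch
volume = feature_map_bytes(64, 128, 128, 128)
# The same layer applied to a single 128 x 128 slice, as a 2D UNet would see it
slice_2d = feature_map_bytes(64, 1, 128, 128)

print(f"3D patch: {volume / MiB:.0f} MiB")    # 3D patch: 512 MiB
print(f"2D slice: {slice_2d / MiB:.0f} MiB")  # 2D slice: 4 MiB
```

Training keeps many such tensors alive at once (the outputs of every convolution, batch-norm, and decoder stage are needed for backpropagation), so in practice GPU memory tends to limit both patch size and batch size.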

UNet Components

```python
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),  # keep the same size
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),  # in-place to save memory
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool3d(2),  # halve each spatial dimension: output size = input size // 2
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        # ConvTranspose3d = dilating the input (inserting zeros) + a standard convolution
        # stride: upsampling factor of the output size; kernel_size: controls overlap between outputs
        self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)  # upsampling
        # pad x1 to the shape of x2 (sizes can differ when an encoder size was odd)
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]
        # F.pad (torch.nn.functional) pads from the last dimension backwards: X (W), Y (H), Z (D)
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])
        x = torch.cat([x2, x1], dim=1)  # concatenate along the channel dimension
        return self.conv(x)
```
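The padding step in Up.forward matters when an encoder feature map has an odd size: MaxPool3d floors the size, and the transposed convolution can only double it back. The sketch below reproduces that mismatch with standalone layers and an illustrative odd depth of 33 slices. It also shows why the pad list starts with diffX: F.pad reads its pad list from the last dimension (width) backwards to depth.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

skip = torch.randn(1, 4, 33, 32, 32)  # illustrative encoder feature map with an odd depth

pooled = nn.MaxPool3d(2)(skip)  # depth: 33 // 2 = 16
upsampled = nn.ConvTranspose3d(4, 4, kernel_size=2, stride=2)(pooled)  # depth: 16 * 2 = 32

# Same arithmetic as Up.forward: split each size difference between the two sides
diff_z = skip.size(2) - upsampled.size(2)  # 1
diff_y = skip.size(3) - upsampled.size(3)  # 0
diff_x = skip.size(4) - upsampled.size(4)  # 0
# Pad list order: [W_left, W_right, H_top, H_bottom, D_front, D_back]
aligned = F.pad(upsampled, [diff_x // 2, diff_x - diff_x // 2,
                            diff_y // 2, diff_y - diff_y // 2,
                            diff_z // 2, diff_z - diff_z // 2])

print(upsampled.shape)  # torch.Size([1, 4, 32, 32, 32])
print(aligned.shape)    # torch.Size([1, 4, 33, 32, 32])
```

Without the padding, torch.cat would fail because the depths (32 and 33) differ.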

Understanding UNet Components:

  1. DoubleConv: The DoubleConv class defines a basic unit consisting of two 3D convolutional layers with batch normalization and ReLU activation. It is the fundamental building block for both encoding and decoding stages of the UNet.

  2. Down: The Down class represents the down-sampling or encoding part of the UNet. It combines a 3D max-pooling layer, which halves each spatial dimension, with a DoubleConv unit that increases the number of feature channels to capture hierarchical features.

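  3. Up: The Up class corresponds to the up-sampling or decoding section. It uses a 3D transposed convolution (sometimes called deconvolution) that doubles each spatial dimension and halves the channel count, pads the result to match the feature map from the corresponding encoding layer, concatenates the two along the channel dimension, and then applies the DoubleConv unit.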
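The channel bookkeeping in Up can be checked with standalone layers. This sketch uses the channel sizes of up4 = Up(128, 64) from the UNet3D model and small illustrative spatial sizes; only the first convolution of DoubleConv is reproduced, to keep it short.

```python
import torch
import torch.nn as nn

# Channel sizes of the last decoder stage, up4 = Up(128, 64)
in_channels, out_channels = 128, 64

deep = torch.randn(1, in_channels, 4, 4, 4)        # output of the stage below
skip = torch.randn(1, in_channels // 2, 8, 8, 8)   # matching encoder output (x1 in UNet3D)

up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
first_conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)  # first layer of DoubleConv

upsampled = up(deep)                          # (1, 64, 8, 8, 8): size doubled, channels halved
merged = torch.cat([skip, upsampled], dim=1)  # (1, 128, 8, 8, 8): channels back to in_channels
out = first_conv(merged)                      # (1, 64, 8, 8, 8)

print(upsampled.shape, merged.shape, out.shape)
```

Halving the channels in the transposed convolution is what lets DoubleConv(in_channels, out_channels) accept the concatenated tensor directly.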

UNet Model

Putting All the Components Together:

```python
class UNet3D(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet3D, self).__init__()
        self.n_channels = n_channels  # input channels = 1 for grayscale image
        self.n_classes = n_classes  # output classes = 2 for teeth/non-teeth
        # input
        self.input = DoubleConv(n_channels, 64)
        # downsample
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        # upsample
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        # output
        self.output = nn.Conv3d(64, n_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x1 = self.input(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        output = self.output(x)
        output = self.sigmoid(output)
        return output
```

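The main UNet3D class assembles the entire 3D UNet model. It consists of an input DoubleConv block, four encoding blocks (Down), four decoding blocks (Up) that each receive a skip connection from the matching encoder level, and an output layer: a 1×1×1 convolution followed by a sigmoid that produces per-voxel probabilities for binary segmentation.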
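The sketch below tabulates the channel count and spatial size at each level of UNet3D for a cubic input patch. The helper level_sizes and the patch sizes 96 and 100 are illustrative assumptions. When the size is not divisible by 2^4 = 16, at least one MaxPool3d floors an odd size, and the F.pad step in Up has to make up the missing voxel on the way back up.

```python
def level_sizes(size, depth=4):
    """Spatial size at each encoder level; MaxPool3d(2) floors odd sizes."""
    sizes = [size]
    for _ in range(depth):
        sizes.append(sizes[-1] // 2)
    return sizes


channels = [64 * 2 ** i for i in range(5)]  # 64, 128, 256, 512, 1024 as in UNet3D

for patch in (96, 100):
    sizes = level_sizes(patch)
    needs_padding = patch % 2 ** 4 != 0  # True when some pooled level had an odd size
    print(patch, list(zip(channels, sizes)), "padding needed:", needs_padding)
```

Cropping training patches to multiples of 16 is one simple way to avoid that padding entirely.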

Loss Function

The dice_loss function computes the soft Dice coefficient between the predicted and ground-truth masks and returns 1 - Dice, so minimizing the loss maximizes their overlap.

```python
def dice_loss(pred, target, smooth=1e-5):
    assert target.size() == pred.size()
    pred, target = pred.contiguous(), target.contiguous()
    # flatten everything except the batch dimension
    pred_flat, target_flat = pred.view(pred.size(0), -1), target.view(target.size(0), -1)
    # Dice = 2|A ∩ B| / (|A| + |B|), loss = 1 - Dice; smooth avoids division by zero
    intersection = (pred_flat * target_flat).sum(1)       # |A ∩ B| per sample
    denominator = pred_flat.sum(1) + target_flat.sum(1)   # |A| + |B| per sample (not the union)

    dice = (2 * intersection + smooth) / (denominator + smooth)
    return 1 - dice.mean()  # average Dice over the batch
```
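Written out, dice_loss computes the following for each sample n in a batch of N, where i runs over every channel and voxel of the flattened prediction p and target g, and s is the smooth term. The worked example uses illustrative soft predictions for four voxels and drops s, which is negligible at 1e-5.

```latex
\mathrm{Dice}(p, g) = \frac{2\sum_i p_i g_i + s}{\sum_i p_i + \sum_i g_i + s},
\qquad
\mathcal{L}_{\mathrm{Dice}} = 1 - \frac{1}{N}\sum_{n=1}^{N} \mathrm{Dice}\big(p^{(n)}, g^{(n)}\big)

% Worked example (s \approx 0): p = (0.9, 0.2, 0.1, 0.0), \; g = (1, 0, 0, 0)
\sum_i p_i g_i = 0.9, \qquad \sum_i p_i = 1.2, \qquad \sum_i g_i = 1

\mathrm{Dice} = \frac{2 \cdot 0.9}{1.2 + 1} = \frac{1.8}{2.2} \approx 0.818,
\qquad
\mathcal{L}_{\mathrm{Dice}} \approx 0.182
```

Because the sigmoid outputs are used directly instead of thresholded masks, the loss stays differentiable and can be minimized with gradient descent.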

Closing Thoughts

3D UNet stands out as a robust and straightforward model architecture, particularly well-suited for medical image segmentation tasks such as MRI and CBCT scans. Its strength lies in capturing spatial context and hierarchical features across all three dimensions of volumetric data. In practice, however, implementing and training it well involves challenges such as careful hyperparameter tuning, handling class imbalance, and choosing suitable data augmentation strategies.
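As one example of the augmentation work mentioned above, the sketch below applies the same random flips to a volume and its mask so the labels stay aligned with the image. The helper random_flip_3d, the (C, D, H, W) layout, and the flip probability are hypothetical choices, not part of this article's code.

```python
import torch


def random_flip_3d(image, mask, p=0.5, generator=None):
    """Hypothetical helper: flip a (C, D, H, W) image and its mask along the same
    randomly chosen spatial axes so that voxel labels stay aligned."""
    for dim in (1, 2, 3):  # spatial axes D, H, W; dim 0 is the channel axis
        if torch.rand(1, generator=generator).item() < p:
            image = torch.flip(image, dims=(dim,))
            mask = torch.flip(mask, dims=(dim,))
    return image, mask


# Usage: a tiny 1 x 2 x 2 x 2 volume whose mask marks voxels with value > 3
image = torch.arange(8.0).reshape(1, 2, 2, 2)
mask = (image > 3).float()
gen = torch.Generator().manual_seed(0)
aug_image, aug_mask = random_flip_3d(image, mask, p=0.5, generator=gen)
assert torch.equal(aug_mask, (aug_image > 3).float())  # mask still matches the image
```

Whether flips are appropriate depends on the anatomy: a left-right flip, for example, can swap structures that are labeled differently on each side.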

Pros:

  • Robust spatial feature capture: 3D UNet excels in capturing spatial dependencies, making it effective for volumetric medical image segmentation.
  • Hierarchical feature extraction: The model's architecture allows for the extraction of hierarchical features, aiding in the accurate delineation of structures.

Cons:

  • Computational Intensity: Training a 3D UNet can be computationally intensive, necessitating powerful hardware resources, preferably GPUs.
  • Implementation Nuances: Achieving optimal performance may require fine-tuning various aspects, such as learning rates and regularization techniques.

Current Applications:

3D UNet finds widespread application in medical imaging, contributing to tasks such as organ segmentation, tumor detection, and anatomical structure delineation. Its ability to handle volumetric data makes it a go-to choice for three-dimensional medical image analysis.

For those seeking a more convenient solution, nnU-Net (nnUNet) is a self-configuring framework that automatically adapts preprocessing, network architecture, and training settings to a new dataset, encapsulating best practices in a ready-to-use segmentation pipeline.