Brief introduction to Unet

Structure


The structure of Unet is symmetric, shaped like the letter 'U'.

In the architecture diagram, each blue box on the right side is a feature map produced by the decoder's convolution and deconvolution, and each white box is a copy of the corresponding feature map from the left side. The grey arrows are the skip connections that join the two halves together.

The encoder path is quite simple: repeated 3x3 convolutions followed by 2x2 max pooling.
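
A rough PyTorch sketch of one encoder stage (the class name and channel arguments are illustrative, not from the original):

import torch.nn as nn

# One encoder stage: two 3x3 convolutions followed by 2x2 max pooling.
# The pre-pooling features are also returned, for reuse by the skip connection.
class EncoderBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        features = self.conv(x)            # saved for the skip connection
        return self.pool(features), features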

The decoder path restores the original spatial resolution of the image through upsampling.

In many Unet implementations, upsampling is implemented with bilinear interpolation (the Keras code below uses learned transposed convolutions instead), e.g. in PyTorch:

torch.nn.Upsample(scale_factor=2, mode='bilinear')
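
A quick shape check of that call (a minimal sketch; scale_factor=2 and align_corners=False are assumptions):

import torch
import torch.nn as nn

up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
x = torch.randn(1, 256, 32, 32)   # (batch, channels, height, width)
print(up(x).shape)                # torch.Size([1, 256, 64, 64])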

Since upsampling cannot recover the fine detail of the original image, Unet uses skip connections to compensate:

torch.cat([low_layer_features, deep_layer_features], dim=1)

(Note: in FCN the two feature maps are fused by element-wise (pixel-level) addition, whereas in Unet they are concatenated; concatenation preserves more detail at the cost of extra memory and computation.)
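
The difference is easy to see on tensor shapes (a minimal sketch, assuming the two feature maps already match spatially):

import torch

low = torch.randn(1, 64, 64, 64)    # shallow encoder features
deep = torch.randn(1, 64, 64, 64)   # upsampled decoder features

# FCN-style fusion: element-wise addition, channel count unchanged.
print((low + deep).shape)                   # torch.Size([1, 64, 64, 64])

# Unet-style fusion: channel concatenation, channel count doubles,
# so the following convolution sees both feature sets intact.
print(torch.cat([low, deep], dim=1).shape)  # torch.Size([1, 128, 64, 64])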

Code (Keras)

from keras import backend as K
from keras.models import Model
from keras.layers import (Input, Conv2D, MaxPooling2D, Conv2DTranspose,
                          concatenate, Dropout, BatchNormalization)
from keras.optimizers import Adam


def dice_coef(y_true, y_pred, smooth=1.0):
    # Standard Dice coefficient; assumed here because the original
    # snippet references dice_coef without defining it.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred)


def Unet(img_rows=256, img_cols=256):  # default input size is an assumption
    inputs = Input((img_rows, img_cols, 3))  # RGB input

    # Encoder: repeated (3x3 conv -> 3x3 conv -> 2x2 max pool) stages.
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    # pool1 = Dropout(0.25)(pool1)
    # pool1 = BatchNormalization()(pool1)

    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    # pool2 = Dropout(0.5)(pool2)
    # pool2 = BatchNormalization()(pool2)

    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    # pool3 = Dropout(0.5)(pool3)
    # pool3 = BatchNormalization()(pool3)

    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    # pool4 = Dropout(0.5)(pool4)
    # pool4 = BatchNormalization()(pool4)

    # Bottleneck.
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)

    # Decoder: each transposed convolution doubles the resolution, then the
    # matching encoder feature map is concatenated along the channel axis.
    up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2),
                                       padding='same')(conv5), conv4], axis=3)
    # up6 = Dropout(0.5)(up6)
    # up6 = BatchNormalization()(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)

    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2),
                                       padding='same')(conv6), conv3], axis=3)
    # up7 = Dropout(0.5)(up7)
    # up7 = BatchNormalization()(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)

    up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2),
                                       padding='same')(conv7), conv2], axis=3)
    # up8 = Dropout(0.5)(up8)
    # up8 = BatchNormalization()(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2),
                                       padding='same')(conv8), conv1], axis=3)
    # up9 = Dropout(0.5)(up9)
    # up9 = BatchNormalization()(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)

    # conv9 = Dropout(0.5)(conv9)

    # 1x1 convolution maps the features to a single-channel sigmoid mask.
    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)

    model = Model(inputs=[inputs], outputs=[conv10])

    model.compile(optimizer=Adam(lr=1e-5),
                  loss=dice_coef_loss, metrics=[dice_coef])

    return model
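
A minimal usage sketch (the default 256x256 input size above is an assumption):

model = Unet()
model.summary()   # prints the symmetric encoder/decoder layout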