U-Net: A Network Specialized for Segmentation

In this article, we introduce a network called U-Net, developed in the lab of Thomas Brox at the University of Freiburg, Germany.

U-Net is specialized for image segmentation; at ISBI 2015 it won both the Dental X-Ray Image Segmentation Challenge and the Cell Tracking Challenge (https://bit.ly/2Qu2CVz).

U-Net is a kind of fully convolutional network (FCN). What sets U-Net apart from a typical FCN is that, when decoding the downsampled feature maps, it reuses information from the encoder: feature maps are passed from the encoder to the decoder via skip connections (the gray arrows in the U-Net architecture diagram). This trick enables more accurate pixel-level classification.
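
The idea behind these skip connections can be sketched in a few lines of PyTorch (the tensor shapes below are illustrative, not those of the actual network): the decoder upsamples a deep feature map and concatenates it channel-wise with the encoder feature map of the same spatial size.

import torch
import torch.nn.functional as F

# Illustrative shapes only: an encoder feature map and the deeper decoder map
skip = torch.randn(1, 64, 128, 128)   # from the contracting path
deep = torch.randn(1, 128, 64, 64)    # from one level further down

# Upsample the deep map, then concatenate channel-wise with the skip map
up = F.interpolate(deep, scale_factor=2, mode='bilinear', align_corners=True)
merged = torch.cat([skip, up], dim=1)
print(merged.shape)  # torch.Size([1, 192, 128, 128])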

Dataset

We implement U-Net using the following dataset: the Carvana Image Masking Challenge (https://www.kaggle.com/c/carvana-image-masking-challenge/data).
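
Judging from the directory variables defined below, the archives are assumed to be extracted so that training images sit in data/train/ (as <id>.jpg), the corresponding masks in data/train_mask/ (as <id>_mask.gif), and test images in data/test/.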

The U-Net implemented here returns results like the following: it generates mask images that cleanly outline the car region.

Output

Implementation



First, we import the required libraries. This implementation uses PyTorch.
import os
import random

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Function
import torchvision.transforms as transforms

import pydensecrf.densecrf as dcrf
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
import_libraries.py
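
Note that pydensecrf is not part of the standard PyTorch stack: it is a Python wrapper around Krähenbühl and Koltun's fully connected CRF, typically installed from PyPI or from the lucasb-eyer/pydensecrf repository on GitHub. It is only needed for the CRF post-processing in the test section.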


Next, we define the variables used throughout the program.
args = {}
args['dir_img'] = 'data/train/'          # training images
args['dir_mask'] = 'data/train_mask/'    # ground-truth masks
args['dir_img_test'] = 'data/test/'      # test images
args['dir_checkpoint'] = 'checkpoint/'   # where model checkpoints are saved
args['val_percent'] = 0.05               # fraction of data held out for validation
args['scale'] = 0.5                      # image resize factor
args['n'] = 2                            # crops per image (left and right square crops)
args['batch_size'] = 2
args['epoch'] = 5
args['threshold'] = 0.5                  # probability threshold for binarizing masks
define_variables.py


Next, we define the functions that load the images and group them into batches.
def to_cropped(args, ids, directory, suffix):
    """Load, resize, and square-crop images one at a time."""
    for id, pos in ids:
        img = Image.open(directory + id + suffix)

        w = img.size[0]
        h = img.size[1]
        newW = int(w * args['scale'])
        newH = int(h * args['scale'])

        img = img.resize((newW, newH))
        img = np.array(img, dtype=np.float32)

        # Take a square crop: pos == 0 -> left side, pos == 1 -> right side
        h = img.shape[0]
        if pos == 0:
            img = img[:, :h]
        else:
            img = img[:, -h:]

        yield img


def get_img_mask(args, ids):
    """Yield (image, mask) pairs, with images normalized to [0, 1], channels first."""
    img = to_cropped(args, ids, args['dir_img'], '.jpg')
    img = map(lambda x: np.transpose(x, axes=[2, 0, 1]), img)  # HWC -> CHW
    img = map(lambda x: x / 255, img)

    mask = to_cropped(args, ids, args['dir_mask'], '_mask.gif')

    return zip(img, mask)


def batch(iterable, batch_size):
    """Group an iterable into lists of batch_size (the last batch may be smaller)."""
    b = []
    for i, t in enumerate(iterable):
        b.append(t)
        if (i + 1) % batch_size == 0:
            yield b
            b = []

    if len(b) > 0:
        yield b
define_functions.py
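
As a toy illustration (independent of the image data), batch() groups any iterable and lets the last batch be smaller:

list(batch(range(5), 2))
# -> [[0, 1], [2, 3], [4]]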


We then load the dataset: we build the list of image IDs and split it into training and validation sets.
# Strip the '.jpg' extension to get image IDs, then pair each ID with a crop
# position (0 = left, 1 = right), doubling the dataset
ids_all = [f[:-4] for f in os.listdir(args['dir_img'])]
ids_all = [(id, i) for i in range(args['n']) for id in ids_all]
random.shuffle(ids_all)

# Hold out the last val_percent of the shuffled list for validation
n = int(len(ids_all) * args['val_percent'])
ids = {'train': ids_all[:-n], 'val': ids_all[-n:]}

len_train = len(ids['train'])
len_val = len(ids['val'])
load_dataset.py


Now we define the U-Net itself.
class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        #  Upsampling could also be learned (e.g. with a transposed
        #  convolution), but that needs more memory than bilinear interpolation
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad (or crop, via negative padding) the encoder feature map x2
        # so its spatial size matches the upsampled x1
        diffY = x1.size()[2] - x2.size()[2]
        diffX = x1.size()[3] - x2.size()[3]
        x2 = F.pad(x2, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))
        # Concatenate along the channel dimension (the skip connection)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x

    
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return x
define_unet.py
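
Before training, it is worth sanity-checking the architecture with a dummy input; a minimal sketch (the 256x256 input size is arbitrary):

unet = UNet(n_channels=3, n_classes=1)
x = torch.randn(1, 3, 256, 256)  # one dummy RGB image
print(unet(x).shape)             # torch.Size([1, 1, 256, 256])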


We define a class that computes the Dice coefficient, which measures the overlap between a predicted mask P and the ground truth T as 2|P ∩ T| / (|P| + |T|): 1 for a perfect match, 0 for no overlap.
class DiceCoeff(Function):
    """Dice coeff for individual examples"""

    def forward(self, input, target):
        self.save_for_backward(input, target)
        eps = 0.0001
        self.inter = torch.dot(input.view(-1), target.view(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):

        input, target = self.saved_tensors
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                         / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target


def dice_coeff(input, target):
    """Dice coeff for batches"""
    if input.is_cuda:
        s = torch.FloatTensor(1).cuda().zero_()
    else:
        s = torch.FloatTensor(1).zero_()

    for i, c in enumerate(zip(input, target)):
        s = s + DiceCoeff().forward(c[0], c[1])

    return s / (i + 1)
define_dice_coeff.py
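
As a quick check on toy tensors (hypothetical 2x2 masks, following the same direct-call path as the validation loop below): identical masks give a Dice coefficient of about 1, and a partial overlap gives 2|P ∩ T| / (|P| + |T|).

a = torch.tensor([[1., 0.], [1., 1.]])
b = torch.tensor([[1., 0.], [1., 0.]])
print(dice_coeff(a.unsqueeze(0), a.unsqueeze(0)).item())  # ~1.0
print(dice_coeff(a.unsqueeze(0), b.unsqueeze(0)).item())  # 2*2 / (3+2) = ~0.8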


We now train the network.
net = UNet(n_channels=3, n_classes=1).cuda()

optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=0.0005
)

criterion = nn.BCELoss()

for epoch in range(args['epoch']):
    
    train = get_img_mask(args, ids['train'])
    val = get_img_mask(args, ids['val'])
    
    #---- Train section
    epoch_loss = 0
    for i, b in enumerate(batch(train, args['batch_size'])):
        img = np.array([x[0] for x in b]).astype(np.float32)
        mask = np.array([x[1] for x in b]).astype(np.float32)

        img = torch.from_numpy(img).cuda()
        mask = torch.from_numpy(mask).cuda()
        mask_flat = mask.view(-1)

        mask_pred = net(img)
        mask_prob = torch.sigmoid(mask_pred)
        mask_prob_flat = mask_prob.view(-1)

        loss = criterion(mask_prob_flat, mask_flat)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        
        if i % 10 == 0:
            print('{}/{} ---- loss: {}'.format(i, int(len_train / args['batch_size']), loss.item()))

    print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))

    
    #---- Validation section
    net.eval()
    val_dice = 0
    with torch.no_grad():
        for j, b in enumerate(val):
            img = torch.from_numpy(b[0]).unsqueeze(0).cuda()
            mask = torch.from_numpy(b[1]).unsqueeze(0).cuda()

            mask_pred = net(img)[0]
            mask_prob = torch.sigmoid(mask_pred)
            mask_bin = (mask_prob > 0.5).float()
            val_dice += dice_coeff(mask_bin, mask).item()

            if j % 10 == 0:
                print('val: {}/{}'.format(j, len_val))
    net.train()
    
    torch.save(net.state_dict(), '{}CP{}.pth'.format(args['dir_checkpoint'], epoch + 1))
    print('Checkpoint {} saved !'.format(epoch + 1))
    print('Validation Dice Coeff: {}'.format(val_dice / len_val))
train.py
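
If you run the evaluation below in a separate session, restore the saved weights first; a minimal sketch (CP5.pth assumes training completed all 5 epochs):

net = UNet(n_channels=3, n_classes=1).cuda()
net.load_state_dict(torch.load(args['dir_checkpoint'] + 'CP5.pth'))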


Finally, we evaluate the trained network on the test data.
file_img_test = os.listdir(args['dir_img_test'])
random.shuffle(file_img_test)
net.eval()  # inference mode: freeze batch-norm statistics

for i, file in enumerate(file_img_test):
    img_original = Image.open(args['dir_img_test']+file)
    img = img_original
    
    w = img.size[0]
    h = img.size[1]
    
    newW = int(w * args['scale'])
    newH = int(h * args['scale'])

    img = img.resize((newW, newH))    
    img = img.crop((0, 0, newW, newH))
    img = np.array(img, dtype=np.float32)
    img = img / 255
    
    # As in training, take square crops from the left and right ends
    img_left = img[:, :newH]
    img_right = img[:, -newH:]

    img_left = np.transpose(img_left, axes=[2, 0, 1])   # HWC -> CHW
    img_right = np.transpose(img_right, axes=[2, 0, 1])
    
    img_left = torch.from_numpy(img_left).unsqueeze(0).cuda()
    img_right = torch.from_numpy(img_right).unsqueeze(0).cuda()

    
    with torch.no_grad():
        mask_left = net(img_left)
        mask_right = net(img_right)

        mask_prob_left = torch.sigmoid(mask_left).squeeze(0)
        mask_prob_right = torch.sigmoid(mask_right).squeeze(0)
        
        tf = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize(h),
                transforms.ToTensor()
        ])
        
        mask_prob_left = tf(mask_prob_left.cpu())
        mask_prob_right = tf(mask_prob_right.cpu())
        
        mask_prob_left_np = mask_prob_left.squeeze().cpu().numpy()
        mask_prob_right_np = mask_prob_right.squeeze().cpu().numpy()
        
        # Stitch the two half-image predictions back into one full-width mask
        mask_prob_np = np.zeros((h, w), np.float32)
        mask_prob_np[:, :w//2+1] = mask_prob_left_np[:, :w//2+1]
        mask_prob_np[:, w//2+1:] = mask_prob_right_np[:, -(w//2-1):]
            
        
        h = mask_prob_np.shape[0]
        w = mask_prob_np.shape[1]

        # Stack background (1 - p) and foreground (p) probabilities
        mask_prob_np = np.expand_dims(mask_prob_np, 0)
        mask_prob_np = np.append(1 - mask_prob_np, mask_prob_np, axis=0)

        # Dense CRF refinement: unary energies come from the network's
        # probabilities; pairwise terms encourage smooth, appearance-consistent labels
        d = dcrf.DenseCRF2D(w, h, 2)
        U = -np.log(np.clip(mask_prob_np, 1e-8, 1.0))  # clip to avoid log(0)
        U = U.reshape((2, -1))
        U = np.ascontiguousarray(U)
        img = np.ascontiguousarray(np.array(img_original).astype(np.uint8))

        d.setUnaryEnergy(U)

        d.addPairwiseGaussian(sxy=20, compat=3)
        d.addPairwiseBilateral(sxy=30, srgb=20, rgbim=img, compat=10)

        # Run 5 iterations of mean-field inference and take the per-pixel argmax
        mask = d.inference(5)
        mask = np.argmax(np.array(mask), axis=0).reshape((h, w))
        # To skip the CRF, you could threshold the foreground probabilities instead:
        # mask = mask_prob_np[1] > args['threshold']
        mask = Image.fromarray((mask * 255).astype(np.uint8))
        
        
        fig = plt.figure(figsize=(27, 7))

        ax1 = fig.add_subplot(131)    
        ax1.imshow(img_original)
        ax1.set_title('input', fontsize=28)

        ax2 = fig.add_subplot(132)
        ax2.imshow(mask)
        ax2.set_title('output', fontsize=28)

        ax3 = fig.add_subplot(133)
        ax3.imshow(img_original)
        ax3.imshow(mask, alpha=0.8)
        ax3.set_title('input and output', fontsize=28)
        
        plt.show()
test.py


The following results are produced, showing that our U-Net works as intended.

Output