KerasでU-Netを実装してみる

以前YOLOなどを試していたのですが、領域抽出をやったことがなかったので、
U-Netを試してみました。

論文は下記ですが、しっかり読んではおりません。
ネットワーク構造の絵を見て実装してみた感じです。
https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/

学習と評価に使うデータは下記のものを使用しました。
Daimler Pedestrian Segmentation Benchmark

GithubのGistという機能を使うと、はてなブログにJupyter Notebookのコードを貼れるそうですが、
アカウント持ってないので、普通に貼ります。

import cv2
import glob
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

img_size=256

input_path=glob.glob("data/testData/left_images/*")
output_path=glob.glob("data/testData/left_groundTruth/*")
img_num=len(input_path)

train_x=[]
train_y=[]

for path in input_path:
    image=cv2.imread(path)/255.
    image=cv2.resize(image, (img_size, img_size))
    train_x.append(image)

for path in output_path:
    image=cv2.imread(path,0)/255.
    image=cv2.resize(image, (img_size, img_size))
    train_y.append(image)
    
train_x=np.array(train_x)
train_y=np.array(train_y)

from keras.layers import Conv2D, Input,  MaxPooling2D, UpSampling2D, Dropout
from keras.models import Model
from keras.layers.merge import concatenate
from keras.optimizers import RMSprop, Adam
from keras import backend as K

inputs=Input((img_size, img_size,3))

enc1=Conv2D(64, (3, 3), padding='same', activation='relu')(inputs)
enc2=Conv2D(64, (3, 3), padding='same', activation='relu')(enc1)

enc3=MaxPooling2D(pool_size=(2, 2))(enc2)
#enc3=Dropout(0.5)(enc3)
enc4=Conv2D(128, (3, 3), padding='same', activation='relu')(enc3)
enc5=Conv2D(128, (3, 3), padding='same', activation='relu')(enc4)
enc6=Conv2D(128, (3, 3), padding='same', activation='relu')(enc5)

enc7=MaxPooling2D(pool_size=(2, 2))(enc6)
#enc7=Dropout(0.5)(enc7)
enc8=Conv2D(256, (3, 3), padding='same', activation='relu')(enc7)
enc9=Conv2D(256, (3, 3), padding='same', activation='relu')(enc8)
enc10=Conv2D(256, (3, 3), padding='same', activation='relu')(enc9)

enc11=MaxPooling2D(pool_size=(2, 2))(enc10)
#enc11=Dropout(0.5)(enc11)
enc12=Conv2D(512, (3, 3), padding='same', activation='relu')(enc11)
enc13=Conv2D(512, (3, 3), padding='same', activation='relu')(enc12)
enc14=Conv2D(512, (3, 3), padding='same', activation='relu')(enc13)

enc15=MaxPooling2D(pool_size=(2, 2))(enc14)
#enc15=Dropout(0.5)(enc15)
enc16=Conv2D(1024, (3, 3), padding='same', activation='relu')(enc15)
enc17=Conv2D(1024, (3, 3), padding='same', activation='relu')(enc16)
enc18=Conv2D(1024, (3, 3), padding='same', activation='relu')(enc17)

dec1=UpSampling2D(size=(2, 2))(enc18)
dec2=concatenate([dec1, enc14], axis=-1)
dec2=Dropout(0.5)(dec2)
dec3=Conv2D(512, (3, 3), padding='same', activation='relu')(dec2)
dec4=Conv2D(512, (3, 3), padding='same', activation='relu')(dec3)

dec5=UpSampling2D(size=(2, 2))(dec4)
dec6=concatenate([dec5, enc10], axis=-1)
dec6=Dropout(0.5)(dec6)
dec7=Conv2D(256, (3, 3), padding='same', activation='relu')(dec6)
dec8=Conv2D(256, (3, 3), padding='same', activation='relu')(dec7)

dec9=UpSampling2D(size=(2, 2))(dec8)
dec10=concatenate([dec9, enc6], axis=-1)
dec10=Dropout(0.5)(dec10)
dec11=Conv2D(128, (3, 3), padding='same', activation='relu')(dec10)
dec12=Conv2D(128, (3, 3), padding='same', activation='relu')(dec11)

dec13=UpSampling2D(size=(2, 2))(dec12)
dec14=concatenate([dec13, enc2], axis=-1)
dec14=Dropout(0.5)(dec14)
dec15=Conv2D(64, (3, 3), padding='same', activation='relu')(dec14)
dec16=Conv2D(64, (3, 3), padding='same', activation='relu')(dec15)

dec17=Conv2D(1, (3, 3), padding='same', activation='sigmoid')(dec16)

model=Model(input=inputs, output=dec17)

def dice_coef(y_true, y_pred):
    y_true = K.flatten(y_true)
    y_pred = K.flatten(y_pred)
    intersection = K.sum(y_true * y_pred)
    
    if (K.sum(y_true) + K.sum(y_pred) == 0):
        return 1.0
    else:
        return (2.0 * intersection)/ (K.sum(y_true) + K.sum(y_pred))

def dice_coef_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred)

model.compile(loss=dice_coef_loss, optimizer=RMSprop(lr=1e-4), metrics=[dice_coef])

batch_size=8
epoch_num=100

history = model.fit(train_x, train_y, batch_size=batch_size, epochs=epoch_num)

いろいろ試した結果、Dropoutを入れています。

学習画像

f:id:oki-lab:20190923090355j:plain

推定画像(学習済画像)

f:id:oki-lab:20190923090408j:plain

評価用画像1

f:id:oki-lab:20190923090452j:plain

推定画像1(評価用画像1)

f:id:oki-lab:20190923090534j:plain

評価用画像2

f:id:oki-lab:20190923090548j:plain

推定画像2(評価用画像2)

f:id:oki-lab:20190923090602j:plain

割とうまくいった気がします。