(著)山たー
3D U-netの実装を見ているとforループでmodelを定義していた。
inputs = Input(input_shape) current_layer = inputs levels = list() # add levels with max pooling for layer_depth in range(depth): layer1 = create_convolution_block(input_layer=current_layer, n_filters=n_base_filters*(2**layer_depth), batch_normalization=batch_normalization) layer2 = create_convolution_block(input_layer=layer1, n_filters=n_base_filters*(2**layer_depth)*2, batch_normalization=batch_normalization) if layer_depth < depth - 1: current_layer = MaxPooling3D(pool_size=pool_size)(layer2) levels.append([layer1, layer2, current_layer]) else: current_layer = layer2 levels.append([layer1, layer2]) # add levels with up-convolution or up-sampling for layer_depth in range(depth-2, -1, -1): up_convolution = get_up_convolution(pool_size=pool_size, deconvolution=deconvolution, n_filters=current_layer._keras_shape[1])(current_layer) concat = concatenate([up_convolution, levels[layer_depth][1]], axis=1) current_layer = create_convolution_block(n_filters=levels[layer_depth][1]._keras_shape[1], input_layer=concat, batch_normalization=batch_normalization) current_layer = create_convolution_block(n_filters=levels[layer_depth][1]._keras_shape[1], input_layer=current_layer, batch_normalization=batch_normalization) final_convolution = Conv3D(n_labels, (1, 1, 1))(current_layer) act = Activation(activation_name)(final_convolution) model = Model(inputs=inputs, outputs=act)
こんな感じ。ぱっと見てよく分からなかったので、2DのU-netのmodelをforループを用いて定義してみる。
2D U-Netのmodel
まずU-netの実装はこれを用いた。モデルをplot_modelを用いて描画すると、以下のようになる。
文字が潰れて読めないが、縦にすると長いのと、"U"-Netっぽさのために横向きにした。
forループを用いてmodelを構築する
というわけで上の2つの実装をまとめてみると次のようになった。なお、モデル中からDropoutを省いた。
from keras.models import Input, Model from keras.layers import Conv2D, Concatenate, MaxPooling2D from keras.layers import UpSampling2D #Dropout from keras import optimizers from keras.utils import plot_model def unet(pretrained_weights = None, input_size = (256,256,1), depth = 5): inputs = Input(input_size) current_layer = inputs #forループ中でレイヤ名を保存する変数 levels = list() #concatのためにレイヤ名を記録するリスト # add levels with max pooling for layer_depth in range(depth): n_filters = 64*(2**layer_depth) layer1 = Conv2D(n_filters, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(current_layer) layer2 = Conv2D(n_filters, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(layer1) if layer_depth < depth - 1: current_layer = MaxPooling2D(pool_size=(2,2))(layer2) levels.append([layer1, layer2, current_layer]) else: current_layer = layer2 levels.append([layer1, layer2]) # add levels with up-sampling for layer_depth in range(depth-2, -1, -1): #depth-2から0まで-1ずつ減る n_filters = 64*(2**layer_depth) up = Conv2D(n_filters, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(current_layer)) concat = Concatenate()([up, levels[layer_depth][1]]) current_layer = Conv2D(n_filters, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(concat) current_layer = Conv2D(n_filters, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(current_layer) current_layer = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(current_layer) output = Conv2D(1, 1, activation='sigmoid')(current_layer) model = Model(input = inputs, output = output) model.compile(optimizer = optimizers.Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy']) model.summary() if(pretrained_weights): model.load_weights(pretrained_weights) return model
depthを2にすると次のようになった。
一応できてる。
levelsの中身は次のようになっている(depth=2のとき)。
[[<tf.Tensor 'conv2d_1/Relu:0' shape=(?, 256, 256, 64) dtype=float32>,
<tf.Tensor 'conv2d_2/Relu:0' shape=(?, 256, 256, 64) dtype=float32>,
<tf.Tensor 'max_pooling2d_1/MaxPool:0' shape=(?, 128, 128, 64) dtype=float32>],
[<tf.Tensor 'conv2d_3/Relu:0' shape=(?, 128, 128, 128) dtype=float32>,
<tf.Tensor 'conv2d_4/Relu:0' shape=(?, 128, 128, 128) dtype=float32>]]
levels[0][1]などとすると、
<tf.Tensor 'conv2d_2/Relu:0' shape=(?, 256, 256, 64) dtype=float32>
と帰ってくる。
まとめ
Kerasは層を追加する感じでネットワークを構築できるので、同じ構造が続く場合はforループを使うとコンパクトに実装できる。…が、実装が読みづらい。これは慣れなのだろうか。
あと、もっと簡潔な実装があった。こっちもforループを用いて定義している。
コメントをお書きください