Keras で GAN の練習

今週強化学習アーキテクチャ勉強会で GAN の話を聴いてきたので(勉強会自体は GAN ではなくて GAN の手法の強化学習への応用が主題ですが)、GAN を手元で動かしてみたいと思います。

参考文献

「keras gan example」と検索すると色々出てきますが、以下の記事を参考にしたいと思います。今回書いたスクリプトはほぼ以下の記事と同じです(ただし、訓練ルールがおかしい可能性があります;後述)。
GAN by Example using Keras on Tensorflow Backend – Towards Data Science

GAN(Generative Adversarial Networks)って何

※ これは参考文献の記述というより自分で適当に書いています。

  • 何かの模造品が出てくるジェネレータです。例えば「日本人の顔っぽい画像」を生成します。もちろん、このジェネレータをつくるには日本人の顔の画像のサンプルデータセットを用意して、どんな画像を出せば日本人の顔っぽいのかを訓練する必要があります。
  • このジェネレータを訓練する過程で、模造品と本物を識別するディスクリミネータというのを一緒に合わせて訓練するのがミソです。
    • ジェネレータは乱数から模造データへの写像です。
    • ディスクリミネータはデータ(模造でも本物でも)から本物らしい確率 0~1 への写像です。
    • つまり、ジェネレータ とディスクリミネータをがっちゃんこすると乱数から本物らしい確率 0~1 への写像です(以下、がっちゃんこしたのをアドバーサリアルモデルとよびます)。
    • 訓練は以下の繰り返しです。
      • まず n 個の乱数を用意して、ジェネレータに入れて n 個の模造データを生成します。
      • 次にディスクリミネータを訓練します; n 個の模造データと n 個の本物データを受け取って、模造品を入れたら0、本物を入れたら1が出てくるように訓練します。
      • 次にアドバーサリアルモデルを訓練します; また新しく n 個の乱数を用意して、どの乱数を入れても1が出てくるように訓練します。

※ ただし、アドバーサリアルモデルの訓練のときディスクリミネータ部分のネットワークの重みは固定しておかないとディスクリミネータまで更新されてしまいます。参考の記事のスクリプトはディスクリミネータが2回更新されているようにみえるので確認中です。

今回やること
  • MNIST から訓練して模造手書き数字を生成します。
    • ジェネレータは100次元の乱数から模造手書き数字を生成します。最後の活性化を除き活性化の直前に必ず Batch Normalization します(通常の訓練データからの訓練と違って、入力が一様乱数だから入念な安定化が必要なのですかね?)。
    • ディスクリミネータは普通の手書き数字分類と違って、maxプーリングしません。maxプーリングは手書き数字の「4」が少しずれたものでも「4」と識別てきるようにする効果がありますが、今回ディスクリミネータに求められるのはそういう能力じゃないからということなのですかね?
実行結果

ディスクリミネータ → アドバーサリアルモデルの順に何か訓練は進んでいるようです(?)。

0: [D loss: 0.693790, acc: 0.447266]  [A loss: 2.802749, acc: 0.000000]
1: [D loss: 0.605273, acc: 0.958984]  [A loss: 5.334227, acc: 0.000000]
2: [D loss: 0.443096, acc: 0.882812]  [A loss: 1.039907, acc: 0.058594]
3: [D loss: 1.566417, acc: 0.500000]  [A loss: 14.592390, acc: 0.000000]
4: [D loss: 0.264856, acc: 0.890625]  [A loss: 0.048600, acc: 1.000000]
5: [D loss: 0.341633, acc: 0.830078]  [A loss: 7.195857, acc: 0.000000]
6: [D loss: 0.095739, acc: 0.990234]  [A loss: 0.227625, acc: 0.953125]
7: [D loss: 0.088694, acc: 1.000000]  [A loss: 0.125637, acc: 0.992188]
8: [D loss: 0.090487, acc: 0.992188]  [A loss: 0.058483, acc: 1.000000]

41ステップ目からなんかディスクリミネータが識別しづらくなっているようです。41ステップ目や43ステップ目は D の acc が 0.5 で A の acc が 1.0 なのでディスクリミネータが全てのデータを本物と判定してしまっているようです。

36: [D loss: 0.008699, acc: 0.998047]  [A loss: 0.000055, acc: 1.000000]
37: [D loss: 0.009811, acc: 0.998047]  [A loss: 0.000191, acc: 1.000000]
38: [D loss: 0.012975, acc: 0.998047]  [A loss: 0.000129, acc: 1.000000]
39: [D loss: 0.010878, acc: 0.998047]  [A loss: 0.016210, acc: 0.992188]
40: [D loss: 0.130817, acc: 0.957031]  [A loss: 16.118101, acc: 0.000000]
41: [D loss: 5.751718, acc: 0.500000]  [A loss: 0.000000, acc: 1.000000]
42: [D loss: 1.771331, acc: 0.587891]  [A loss: 4.267523, acc: 0.250000]
43: [D loss: 7.854210, acc: 0.500000]  [A loss: 0.000000, acc: 1.000000]
44: [D loss: 7.210688, acc: 0.500000]  [A loss: 4.953451, acc: 0.253906]
45: [D loss: 6.669549, acc: 0.500000]  [A loss: 11.922536, acc: 0.015625]

50ステップ後のジェネレータ出力は以下です。なんかもにょもにょしていて全然手書き数字ではないです。

f:id:cookie-box:20171216161856p:plain:w380

参考文献には1000ステップくらい回して手書き数字っぽい出力を学習できている例が載っていますが、このまま学習を続けるとよくなるのでしょうか。いま作業しているマシンには GPU 積んでいないのでやるなら一晩かけてみないとよくわかりません。
しかし、GPU がなくても、もうちょっと学習が上手くいっているのかどうか知りたいものです。そこで、以下の強硬手段をとります。

  • 数字が10種類もあるのがよくない。ここはもう数字の「1」のみに絞る。棒くらい学んでほしい。
  • ディスクリミネータが4回も畳み込んでいる。畳み込みすぎ。時間がかかるので最後の畳込みを削る。

このようにして実行してみると着実に数字の「1」への道を歩んでいるように見えます。よかった。

25ステップ
f:id:cookie-box:20171216210234p:plain:w210
50ステップ
f:id:cookie-box:20171216210259p:plain:w210
75ステップ
f:id:cookie-box:20171216210320p:plain:w210
100ステップ
f:id:cookie-box:20171216210609p:plain:w210
125ステップ
f:id:cookie-box:20171216210648p:plain:w210
150ステップ
f:id:cookie-box:20171216210801p:plain:w210
175ステップ
f:id:cookie-box:20171216211527p:plain:w210
200ステップ
f:id:cookie-box:20171216211539p:plain:w210
225ステップ
f:id:cookie-box:20171216211550p:plain:w210

スクリプト

変数名と処理の順序を一部変更している以外参考文献のコードと同じです。以下注意書きです。

  • 参考文献の記事は後半以降 model という言葉を「ネットワーク構造 + 訓練ルール(損失関数と勾配法)」という意味合いでつかっているようです。スクリプトでもネットワーク構造のみ(self.D)とネットワーク構造+訓練ルール(self.DM)を別のメンバとして持っています。
    • 通常の訓練データの識別や回帰ではこれをわざわざ分けないですが、GAN の訓練では「ディスクリミネータ構造 + ディスクリミネータの訓練ルール」による訓練と「ジェネレータ構造 + ディスクリミネータ構造 + アドバーサリアルモデルの訓練ルール」による訓練の2種類の訓練をしなければならないので、構造は構造単体で持っておかなければならないのですね。
  • Windows などでユーザ名が日本語になっているなどすると(変えたいのですが…)一時ファイルのパスに日本語が交じりデータのダウンロードに失敗します。環境変数 TMP, TEMP を日本語のないパスに変更すると解決します(import - windows10環境でtensorflowを動かしたい(69477)|teratail)。
  • MNIST_DCGAN クラスに1種類の数字に絞るかどうかのコメントアウトがあるので適宜変更してください。さらに計算量を削りたい人は、上でやったようにディスクリミネータの最後の Conv2D と Dropout を削るとか、バッチサイズを小さくするとかするといいと思います。
# -*- coding: utf-8 -*-
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras.layers import Conv2D, LeakyReLU, Dense, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, Conv2DTranspose, UpSampling2D, Reshape
from keras.models import Sequential
from keras.optimizers import Adam, RMSprop
import matplotlib.pyplot as plt

class DCGAN:
  def __init__(self, img_rows=28, img_cols=28, channel=1):
    self.img_rows = img_rows
    self.img_cols = img_cols
    self.channel = channel
    self.D = None   # ディスクリミネータ(のネットワーク構造だけ)
    self.G = None   # ジェネレータ(のネットワーク構造だけ)
    self.DM = None  # ディスクリミネータモデル(ディスクリミネータ + 訓練ルール)
    self.AM = None  # アドバーサリアルモデル(ジェネレータ + ディスクリミネータ + 訓練ルール)

  def discriminator_network(self): # ディスクリミネータ: 畳み込み x 4回
    if self.D:
      return self.D
    self.D = Sequential()
    depth = 64
    dropout = 0.4
    input_shape = (self.img_rows, self.img_cols, self.channel)
    self.D.add(Conv2D(depth*1, 5, strides=2, padding='same', activation=LeakyReLU(alpha=0.2),
                      input_shape=input_shape))                 # 28 x 28 x 1 --> 14 x 14 x 64
    self.D.add(Dropout(dropout))
    self.D.add(Conv2D(depth*2, 5, strides=2, padding='same',
                      activation=LeakyReLU(alpha=0.2)))         # 14 x 14 x 64 --> 7 x 7 x 128
    self.D.add(Dropout(dropout))
    self.D.add(Conv2D(depth*4, 5, strides=2, padding='same', 
                      activation=LeakyReLU(alpha=0.2)))         # 7 x 7 x 128 --> 4 x 4 x 256
    self.D.add(Dropout(dropout))
    self.D.add(Conv2D(depth*8, 5, strides=1, padding='same', 
                      activation=LeakyReLU(alpha=0.2)))         # 4 x 4 x 256 --> 4 x 4 x 512
    self.D.add(Dropout(dropout))
    self.D.add(Flatten())                                       # 4 x 4 x 512 --> 8192
    self.D.add(Dense(1, activation='sigmoid'))                  # 8192 --> 1
    return self.D

  def generator_network(self): # ジェネレータ: 逆畳み込み x 4回
    if self.G:
      return self.G
    self.G = Sequential()
    dropout = 0.4
    depth = 64 + 64 + 64 + 64
    dim = 7
    self.G.add(Dense(dim*dim*depth, input_dim=100))             # 100 --> 12544
    self.G.add(BatchNormalization(momentum=0.9))
    self.G.add(Activation('relu'))
    self.G.add(Reshape((dim, dim, depth)))                       # 12544 --> 7 x 7 x 256
    self.G.add(Dropout(dropout))
    self.G.add(UpSampling2D())                                   # 7 x 7 x 256 --> 14 x 14 x 256
    self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same')) # 14 x 14 x 256 --> 14 x 14 x 128
    self.G.add(BatchNormalization(momentum=0.9))
    self.G.add(Activation('relu'))
    self.G.add(UpSampling2D())                                   # 14 x 14 x 128 --> 28 x 28 x 128
    self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same')) # 28 x 28 x 128 --> 28 x 28 x 64
    self.G.add(BatchNormalization(momentum=0.9))
    self.G.add(Activation('relu'))
    self.G.add(Conv2DTranspose(int(depth/8), 5, padding='same')) # 28 x 28 x 64 --> 28 x 28 x 32
    self.G.add(BatchNormalization(momentum=0.9))
    self.G.add(Activation('relu'))
    self.G.add(Conv2DTranspose(1, 5, padding='same', 
               activation='sigmoid'))                            # 28 x 28 x 32 --> 28 x 28 x 1
    return self.G

  def discriminator_model(self):
    if self.DM:
      return self.DM
    optimizer = RMSprop(lr=0.0002, decay=6e-8)
    self.DM = Sequential()
    self.DM.add(self.discriminator_network())
    self.DM.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return self.DM

  def adversarial_model(self):
    if self.AM:
      return self.AM
    optimizer = RMSprop(lr=0.0001, decay=3e-8)
    self.AM = Sequential()
    self.AM.add(self.generator_network())
    self.AM.add(self.discriminator_network())
    self.AM.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return self.AM

class MNIST_DCGAN(object):
  def __init__(self):
    self.img_rows = 28
    self.img_cols = 28
    self.channel = 1

    # データ読み込み
    # データをしぼらない場合
    #self.x_train = input_data.read_data_sets("mnist", one_hot=True).train.images
    # データをしぼる場合
    mnist = input_data.read_data_sets("mnist", one_hot=False)
    images = mnist.train.images
    labels = mnist.train.labels
    images = images[np.where(labels == 1)] # 「1」だけにしぼる場合
    
    self.x_train = images
    self.x_train = self.x_train.reshape(-1, self.img_rows, self.img_cols, 1).astype(np.float32)
    self.DCGAN = DCGAN()
    self.discriminator_model =  self.DCGAN.discriminator_model()
    self.adversarial_model = self.DCGAN.adversarial_model()
    self.generator_network = self.DCGAN.generator_network()

  def train(self, train_steps=2000, batch_size=256, save_interval=0):
    # 学習の途中でジェネレータ出力を吐き出す場合、それ用の乱数を確保しておく
    noise_input = None
    if save_interval > 0:
      noise_input = np.random.uniform(-1.0, 1.0, size=[16, 100])
    
    # GAN の訓練
    for i in range(train_steps):
      # (1) batch_size 個の乱数から batch_size 個の模造データ作成
      noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
      images_fake = self.generator_network.predict(noise)
      # (2) batch_size 個の本物データと batch_size 個の模造データでディスクリミネータを訓練
      images_train = self.x_train[np.random.randint(0, self.x_train.shape[0], size=batch_size), :, :, :]
      x = np.concatenate((images_train, images_fake))
      y = np.ones([2*batch_size, 1])
      y[batch_size:, :] = 0
      d_loss = self.discriminator_model.train_on_batch(x, y)
      # (3) batch_size 個の乱数でアドバーサリアルモデル(ジェネレータ+ディスクリミネータ)を訓練
      y = np.ones([batch_size, 1])
      noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
      a_loss = self.adversarial_model.train_on_batch(noise, y)
      
      log_mesg = "%d: [D loss: %f, acc: %f]" % (i, d_loss[0], d_loss[1])
      log_mesg = "%s  [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
      print(log_mesg)
      
      if save_interval > 0:
        if (i+1) % save_interval == 0:
          self.plot_images(samples=noise_input.shape[0], noise=noise_input, step=(i+1))

  def plot_images(self, fake=True, samples=16, noise=None, step=0):
    filename = 'mnist.png'
    if fake:
      if noise is None:
        noise = np.random.uniform(-1.0, 1.0, size=[samples, 100])
      else:
        filename = "mnist_%d.png" % step
      images = self.generator_network.predict(noise)
    else:
      i = np.random.randint(0, self.x_train.shape[0], samples)
      images = self.x_train[i, :, :, :]
    plt.figure(figsize=(10,10))
    for i in range(images.shape[0]):
      plt.subplot(4, 4, i+1)
      image = images[i, :, :, :]
      image = np.reshape(image, [self.img_rows, self.img_cols])
      plt.imshow(image, cmap='gray')
      plt.axis('off')
    plt.tight_layout()
    plt.savefig(filename)
    plt.close('all')

if __name__ == '__main__':
  mnist_dcgan = MNIST_DCGAN()
  mnist_dcgan.train(train_steps=10000, batch_size=256, save_interval=25)
  mnist_dcgan.plot_images(fake=True)
  mnist_dcgan.plot_images(fake=False)