Keras で GAN の練習(その2)

前回の記事で GAN を動かしてみたのですが、実装があやしいのでまた別の記事を参考にしてみます。

参考文献

以下の記事を参考にします。やることは前回と同じで手書き数字の模造です。
MNIST Generative Adversarial Model in Keras

  • この記事ではディスクリミネータのことをアドバーサリアルモデルといっている。
    • 昨日の記事ではジェネレータ+ディスクリミネータのことをアドバーサリアルモデルといっていた。
  • 記事の初っ端から「freeze the weights in the adversarial part of the network, and train the generative network weights」、つまりディスクリミネータ部分のネットワークの重みを固定してジェネレータを訓練するといっているので、昨日の記事で気になった部分はこちらの記事では大丈夫そうです。
  • こちらの記事は Sequential モデルではなくて functional API をつかっています。
  • ディスクリミネータの出力層がの次元が2になっています。つまり、本物なら [0, 1] 、模造品なら [1, 0] を目標出力とします。
    • 2クラス分類では1次元にすることが多いと思っていました。どっちでもいいと思いますが。
    • ただ、2クラス分類で出力層を2次元にする場合、判定を誤ったときに「正しいクラスであると考えた度合いが小さかった」のか、「正しいクラスであると考えた度合いは大きかったが、それ以上に誤ったクラスだと考えた度合いが大きかった」のかの区別は付くと思います(最終出力がソフトマックスされる前をみれば)。それをみたい機会があるのかはわかりませんが。
実行結果

ジェネレータとディスクリミネータをがっちゃんこした状態が以下です。

_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_3 (InputLayer)         (None, 100)               0
_________________________________________________________________
model_1 (Model)              (None, 28, 28, 1)         4341801
_________________________________________________________________
model_2 (Model)              (None, 2)                 9707266
=================================================================
Total params: 14,049,067
Trainable params: 13,970,367
Non-trainable params: 78,700
_________________________________________________________________

ただ肝心の訓練が上手くいっていないので、API はこのままに前回の記事のモデル構造を適用してみたいと思います。

スクリプト

元の記事は Keras のバージョンが手元より古いので、一部サポートされなくなってしまっていた機能がありました(BatchNormalization の mode=2)。

TypeError: The `mode` argument of `BatchNormalization` no longer exists. `mode=1`  and  `mode=2` are no longer supported.

Normalization Layers - Keras 1.2.2 Documentation
元の記事に比べて色々関数にくくり出しています。また、TensorFlow バックエンドの Keras 2.0 で動くように全体的に変更してあります。

# -*- coding: utf-8 -*-
import numpy as np
from keras.models import Model
from keras.layers import Input
from keras.layers.core import Reshape, Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D, UpSampling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.normalization import *
from keras.activations import *
from keras.optimizers import *
from keras.datasets import mnist
import matplotlib.pyplot as plt
from tqdm import tqdm

# ---------- ネットワークの訓練可能オンオフを制御する ----------
def make_trainable(net, trainable=False):
  net.trainable = trainable
  for l in net.layers:
    l.trainable = trainable

# ---------- ジェネレータの生成 ----------
def create_generator(opt):
  g_input = Input(shape=[100])
  H = Dense(14*14*200, init='glorot_normal')(g_input)
  H = BatchNormalization()(H)
  H = Activation('relu')(H)
  H = Reshape([14, 14, 200])(H)
  H = UpSampling2D(size=(2, 2))(H)
  H = Conv2D(100, (3, 3), padding='same', init='glorot_uniform')(H)
  H = BatchNormalization()(H)
  H = Activation('relu')(H)
  H = Conv2D(50, (3, 3), padding='same', init='glorot_uniform')(H)
  H = BatchNormalization()(H)
  H = Activation('relu')(H)
  H = Conv2D(1, (1, 1), padding='same', init='glorot_uniform')(H)
  g_V = Activation('sigmoid')(H)
  generator = Model(g_input, g_V)
  generator.compile(loss='binary_crossentropy', optimizer=opt)
  generator.summary()
  return generator

# ---------- ディスクリミネータの生成 ----------
def create_discriminator(shp, dropout_rate, dopt):
  d_input = Input(shape=shp)
  H = Conv2D(256, (5, 5), subsample=(2, 2), padding='same', activation='relu')(d_input)
  H = LeakyReLU(0.2)(H)
  H = Dropout(dropout_rate)(H)
  H = Conv2D(512, (5, 5), subsample=(2, 2), padding='same', activation='relu')(H)
  H = LeakyReLU(0.2)(H)
  H = Dropout(dropout_rate)(H)
  H = Flatten()(H)
  H = Dense(256)(H)
  H = LeakyReLU(0.2)(H)
  H = Dropout(dropout_rate)(H)
  d_V = Dense(2, activation='softmax')(H)
  discriminator = Model(d_input,d_V)
  discriminator.compile(loss='categorical_crossentropy', optimizer=dopt)
  discriminator.summary()
  return discriminator

# ---------- アドバーサリアルモデルの生成 ----------
def create_adversarial_model(generator, discriminator, opt):
  gan_input = Input(shape=[100])
  H = generator(gan_input)
  gan_V = discriminator(H)
  GAN = Model(gan_input, gan_V)
  GAN.compile(loss='categorical_crossentropy', optimizer=opt)
  GAN.summary()
  return GAN

# ---------- ディスクリミネータを1バッチ分トレーニングする ----------
def train_discriminator_1batch(discriminator, X_train, batch_size):
  image_batch = X_train[np.random.randint(0, X_train.shape[0], size=batch_size),:,:,:]
  noise_gen = np.random.uniform(0, 1, size=[batch_size, 100])
  generated_images = generator.predict(noise_gen)
  
  X = np.concatenate((image_batch, generated_images))
  y = np.zeros([2*batch_size, 2])
  y[:batch_size, 1] = 1 # 本物データのときのディスクリミネータの期待出力は y=[0, 1]
  y[batch_size:, 0] = 1 # 模造データのときのディスクリミネータの期待出力は y=[1, 0]
  
  make_trainable(discriminator, True)
  d_loss = discriminator.train_on_batch(X, y)
  return d_loss

# ---------- アドバーサリアルモデルを1バッチ分トレーニングする ----------
def train_GAN_1batch(discriminator, GAN, X_train, batch_size):
  noise_tr = np.random.uniform(0, 1, size=[batch_size, 100])
  y = np.zeros([batch_size, 2])
  y[:, 1] = 1 # 本物データと判断してほしいので期待出力は y=[0, 1]
  
  make_trainable(discriminator, False) # ディスクリミネータの重み係数更新は忘れずにオフ
  g_loss = GAN.train_on_batch(noise_tr, y)
  return g_loss

# ---------- 損失をプロットする ----------
def plot_loss(losses, filename='loss.png'):
  plt.figure(figsize=(10,8))
  plt.plot(losses["d"], label='discriminitive loss')
  plt.plot(losses["g"], label='generative loss')
  plt.legend()
  plt.savefig(filename)
  plt.close('all')

# ---------- ジェネレータが出力する模造データをプロットする ----------
def plot_gen(noise, generator, filename='result.png'):
  generated_images = generator.predict(noise)
  plt.figure(figsize=(10,10))
  for i in range(generated_images.shape[0]):
    plt.subplot(4, 4, i+1)
    img = generated_images[i,:,:,:]
    img = np.reshape(img, [28, 28])
    plt.imshow(img, cmap='gray')
    plt.axis('off')
  plt.tight_layout()
  plt.savefig(filename)
  plt.close('all')

# ---------- ネットワークをトレーニングする(メイン) ----------
def train_GAN(generator, discriminator, GAN, X_train, batch_size=32, steps=50):
  # ディスクリミネータの事前学習
  for i in range(5):
    train_discriminator_1batch(discriminator, X_train, batch_size=10)
  
  losses = {"d":[], "g":[]} # 損失の記録用
  noise_for_plot = np.random.uniform(0, 1, size=[16, 100]) # 途中経過出力用のノイズ
  for i in tqdm(range(steps)):
    # ディスクリミネータの学習
    d_loss = train_discriminator_1batch(discriminator, X_train, batch_size)
    losses["d"].append(d_loss)
    
    # アドバーサリアルモデルの学習
    g_loss = train_GAN_1batch(discriminator, GAN, X_train, batch_size)
    losses["g"].append(g_loss)
    
    # プロット
    if i%25 == 25-1:
      plot_loss(losses)
      plot_gen(noise_for_plot, generator, "result_%d.png" % i)

# ========================= メイン処理 =========================
if __name__ == '__main__':
  # データ読み込みとプレ処理
  img_rows, img_cols = 28, 28
  (X_train, y_train), (X_test, y_test) = mnist.load_data()
  X_train = X_train[np.where(y_train == 1)] # 「1」のみにしぼる
  X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1).astype('float32')
  X_train /= 255.0
  shp = X_train.shape[1:]
  
  # 訓練設定
  dropout_rate = 0.25
  dopt = Adam(lr=1e-3)
  opt = Adam(lr=1e-4)
  
  # ネットワークの生成
  generator = create_generator(opt)
  discriminator = create_discriminator(shp, dropout_rate, dopt)
  GAN = create_adversarial_model(generator, discriminator, opt)
  
  # ネットワークの訓練
  train_GAN(generator, discriminator, GAN, X_train)
その他

tqdm という進捗バーを表示してくれるパッケージをはじめて知りました。
Pythonで進捗バーを表示する(tqdm) - naritoブログ