前回の記事で GAN を動かしてみたのですが、実装があやしいのでまた別の記事を参考にしてみます。
参考文献
以下の記事を参考にします。やることは前回と同じで手書き数字の模造です。
MNIST Generative Adversarial Model in Keras
- この記事ではディスクリミネータのことをアドバーサリアルモデルといっている。
- 昨日の記事ではジェネレータ+ディスクリミネータのことをアドバーサリアルモデルといっていた。
- 記事の初っ端から「freeze the weights in the adversarial part of the network, and train the generative network weights」、つまりディスクリミネータ部分のネットワークの重みを固定してジェネレータを訓練するといっているので、昨日の記事で気になった部分はこちらの記事では大丈夫そうです。
- こちらの記事は Sequential モデルではなくて functional API をつかっています。
- ディスクリミネータの出力層がの次元が2になっています。つまり、本物なら [0, 1] 、模造品なら [1, 0] を目標出力とします。
- 2クラス分類では1次元にすることが多いと思っていました。どっちでもいいと思いますが。
- ただ、2クラス分類で出力層を2次元にする場合、判定を誤ったときに「正しいクラスであると考えた度合いが小さかった」のか、「正しいクラスであると考えた度合いは大きかったが、それ以上に誤ったクラスだと考えた度合いが大きかった」のかの区別は付くと思います(最終出力がソフトマックスされる前をみれば)。それをみたい機会があるのかはわかりませんが。
実行結果
ジェネレータとディスクリミネータをがっちゃんこした状態が以下です。
_________________________________________________________________ _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_3 (InputLayer) (None, 100) 0 _________________________________________________________________ model_1 (Model) (None, 28, 28, 1) 4341801 _________________________________________________________________ model_2 (Model) (None, 2) 9707266 ================================================================= Total params: 14,049,067 Trainable params: 13,970,367 Non-trainable params: 78,700 _________________________________________________________________
ただ肝心の訓練が上手くいっていないので、API はこのままに前回の記事のモデル構造を適用してみたいと思います。
スクリプト
元の記事は Keras のバージョンが手元より古いので、一部サポートされなくなってしまっていた機能がありました(BatchNormalization の mode=2)。
TypeError: The `mode` argument of `BatchNormalization` no longer exists. `mode=1` and `mode=2` are no longer supported.
Normalization Layers - Keras 1.2.2 Documentation
元の記事に比べて色々関数にくくり出しています。また、TensorFlow バックエンドの Keras 2.0 で動くように全体的に変更してあります。
# -*- coding: utf-8 -*- import numpy as np from keras.models import Model from keras.layers import Input from keras.layers.core import Reshape, Dense, Dropout, Activation, Flatten from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D, UpSampling2D from keras.layers.advanced_activations import LeakyReLU from keras.layers.normalization import * from keras.activations import * from keras.optimizers import * from keras.datasets import mnist import matplotlib.pyplot as plt from tqdm import tqdm # ---------- ネットワークの訓練可能オンオフを制御する ---------- def make_trainable(net, trainable=False): net.trainable = trainable for l in net.layers: l.trainable = trainable # ---------- ジェネレータの生成 ---------- def create_generator(opt): g_input = Input(shape=[100]) H = Dense(14*14*200, init='glorot_normal')(g_input) H = BatchNormalization()(H) H = Activation('relu')(H) H = Reshape([14, 14, 200])(H) H = UpSampling2D(size=(2, 2))(H) H = Conv2D(100, (3, 3), padding='same', init='glorot_uniform')(H) H = BatchNormalization()(H) H = Activation('relu')(H) H = Conv2D(50, (3, 3), padding='same', init='glorot_uniform')(H) H = BatchNormalization()(H) H = Activation('relu')(H) H = Conv2D(1, (1, 1), padding='same', init='glorot_uniform')(H) g_V = Activation('sigmoid')(H) generator = Model(g_input, g_V) generator.compile(loss='binary_crossentropy', optimizer=opt) generator.summary() return generator # ---------- ディスクリミネータの生成 ---------- def create_discriminator(shp, dropout_rate, dopt): d_input = Input(shape=shp) H = Conv2D(256, (5, 5), subsample=(2, 2), padding='same', activation='relu')(d_input) H = LeakyReLU(0.2)(H) H = Dropout(dropout_rate)(H) H = Conv2D(512, (5, 5), subsample=(2, 2), padding='same', activation='relu')(H) H = LeakyReLU(0.2)(H) H = Dropout(dropout_rate)(H) H = Flatten()(H) H = Dense(256)(H) H = LeakyReLU(0.2)(H) H = Dropout(dropout_rate)(H) d_V = Dense(2, activation='softmax')(H) discriminator = Model(d_input,d_V) discriminator.compile(loss='categorical_crossentropy', optimizer=dopt) discriminator.summary() return discriminator # ---------- アドバーサリアルモデルの生成 ---------- def create_adversarial_model(generator, discriminator, opt): gan_input = Input(shape=[100]) H = generator(gan_input) gan_V = discriminator(H) GAN = Model(gan_input, gan_V) GAN.compile(loss='categorical_crossentropy', optimizer=opt) GAN.summary() return GAN # ---------- ディスクリミネータを1バッチ分トレーニングする ---------- def train_discriminator_1batch(discriminator, X_train, batch_size): image_batch = X_train[np.random.randint(0, X_train.shape[0], size=batch_size),:,:,:] noise_gen = np.random.uniform(0, 1, size=[batch_size, 100]) generated_images = generator.predict(noise_gen) X = np.concatenate((image_batch, generated_images)) y = np.zeros([2*batch_size, 2]) y[:batch_size, 1] = 1 # 本物データのときのディスクリミネータの期待出力は y=[0, 1] y[batch_size:, 0] = 1 # 模造データのときのディスクリミネータの期待出力は y=[1, 0] make_trainable(discriminator, True) d_loss = discriminator.train_on_batch(X, y) return d_loss # ---------- アドバーサリアルモデルを1バッチ分トレーニングする ---------- def train_GAN_1batch(discriminator, GAN, X_train, batch_size): noise_tr = np.random.uniform(0, 1, size=[batch_size, 100]) y = np.zeros([batch_size, 2]) y[:, 1] = 1 # 本物データと判断してほしいので期待出力は y=[0, 1] make_trainable(discriminator, False) # ディスクリミネータの重み係数更新は忘れずにオフ g_loss = GAN.train_on_batch(noise_tr, y) return g_loss # ---------- 損失をプロットする ---------- def plot_loss(losses, filename='loss.png'): plt.figure(figsize=(10,8)) plt.plot(losses["d"], label='discriminitive loss') plt.plot(losses["g"], label='generative loss') plt.legend() plt.savefig(filename) plt.close('all') # ---------- ジェネレータが出力する模造データをプロットする ---------- def plot_gen(noise, generator, filename='result.png'): generated_images = generator.predict(noise) plt.figure(figsize=(10,10)) for i in range(generated_images.shape[0]): plt.subplot(4, 4, i+1) img = generated_images[i,:,:,:] img = np.reshape(img, [28, 28]) plt.imshow(img, cmap='gray') plt.axis('off') plt.tight_layout() plt.savefig(filename) plt.close('all') # ---------- ネットワークをトレーニングする(メイン) ---------- def train_GAN(generator, discriminator, GAN, X_train, batch_size=32, steps=50): # ディスクリミネータの事前学習 for i in range(5): train_discriminator_1batch(discriminator, X_train, batch_size=10) losses = {"d":[], "g":[]} # 損失の記録用 noise_for_plot = np.random.uniform(0, 1, size=[16, 100]) # 途中経過出力用のノイズ for i in tqdm(range(steps)): # ディスクリミネータの学習 d_loss = train_discriminator_1batch(discriminator, X_train, batch_size) losses["d"].append(d_loss) # アドバーサリアルモデルの学習 g_loss = train_GAN_1batch(discriminator, GAN, X_train, batch_size) losses["g"].append(g_loss) # プロット if i%25 == 25-1: plot_loss(losses) plot_gen(noise_for_plot, generator, "result_%d.png" % i) # ========================= メイン処理 ========================= if __name__ == '__main__': # データ読み込みとプレ処理 img_rows, img_cols = 28, 28 (X_train, y_train), (X_test, y_test) = mnist.load_data() X_train = X_train[np.where(y_train == 1)] # 「1」のみにしぼる X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1).astype('float32') X_train /= 255.0 shp = X_train.shape[1:] # 訓練設定 dropout_rate = 0.25 dopt = Adam(lr=1e-3) opt = Adam(lr=1e-4) # ネットワークの生成 generator = create_generator(opt) discriminator = create_discriminator(shp, dropout_rate, dopt) GAN = create_adversarial_model(generator, discriminator, opt) # ネットワークの訓練 train_GAN(generator, discriminator, GAN, X_train)
その他
tqdm という進捗バーを表示してくれるパッケージをはじめて知りました。
Pythonで進捗バーを表示する(tqdm) - naritoブログ