元の記事で使われている Keras は手元のバージョンより古く、一部の機能(BatchNormalization の mode=2)は現在のバージョンではサポートされていません。そのまま実行すると次のエラーになります。
TypeError: The `mode` argument of `BatchNormalization` no longer exists. `mode=1` and `mode=2` are no longer supported.
Normalization Layers - Keras 1.2.2 Documentation
元の記事と比べて、処理をいくつかの関数に切り出しています。また、TensorFlow バックエンドの Keras 2.0 で動くように全体的に変更してあります。
import numpy as np
from keras.models import Model
from keras.layers import Input
from keras.layers.core import Reshape, Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D, UpSampling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.normalization import *
from keras.activations import *
from keras.optimizers import *
from keras.datasets import mnist
import matplotlib.pyplot as plt
from tqdm import tqdm
def make_trainable(net, trainable=False):
    """Set the ``trainable`` flag on a model and on each of its layers.

    Args:
        net: Object exposing a ``layers`` iterable (e.g. a Keras model).
        trainable: Value assigned to ``net.trainable`` and to the
            ``trainable`` attribute of every entry in ``net.layers``.
    """
    # The container is flagged first, then each contained layer.
    net.trainable = trainable
    for layer in net.layers:
        layer.trainable = trainable
def create_generator(opt):
    """Build and compile the generator network.

    The network maps a 100-dimensional noise vector to a single-channel
    28x28 image: a dense layer reshaped to 14x14x200, one 2x upsampling
    step, and three convolutions, the last of which is followed by a
    sigmoid so outputs lie in (0, 1).

    Args:
        opt: Keras optimizer used to compile the generator.

    Returns:
        The compiled generator ``Model``.
    """
    g_input = Input(shape=[100])

    # Keras 2 names the weight-initializer argument `kernel_initializer`;
    # `init` is the Keras 1 spelling.
    H = Dense(14*14*200, kernel_initializer='glorot_normal')(g_input)
    H = BatchNormalization()(H)
    H = Activation('relu')(H)
    H = Reshape([14, 14, 200])(H)

    # 14x14 -> 28x28
    H = UpSampling2D(size=(2, 2))(H)

    H = Conv2D(100, (3, 3), padding='same',
               kernel_initializer='glorot_uniform')(H)
    H = BatchNormalization()(H)
    H = Activation('relu')(H)

    H = Conv2D(50, (3, 3), padding='same',
               kernel_initializer='glorot_uniform')(H)
    H = BatchNormalization()(H)
    H = Activation('relu')(H)

    # 1x1 convolution collapses the feature maps to a single channel.
    H = Conv2D(1, (1, 1), padding='same',
               kernel_initializer='glorot_uniform')(H)
    g_V = Activation('sigmoid')(H)

    generator = Model(g_input, g_V)
    generator.compile(loss='binary_crossentropy', optimizer=opt)
    generator.summary()
    return generator
def create_discriminator(shp, dropout_rate, dopt):
    """Build and compile the discriminator network.

    Two stride-2 5x5 convolutions and one 256-unit dense layer, each
    followed by ``LeakyReLU(0.2)`` and dropout, ending in a 2-way softmax.

    Args:
        shp: Shape of one input image (without the batch axis).
        dropout_rate: Rate passed to every ``Dropout`` layer.
        dopt: Keras optimizer used to compile the discriminator.

    Returns:
        The compiled discriminator ``Model``.
    """
    d_input = Input(shape=shp)

    # Keras 2 names the stride argument `strides`; `subsample` is the
    # Keras 1 spelling.
    # NOTE(review): the convolutions already apply 'relu', whose output is
    # non-negative, so the LeakyReLU that follows each of them passes its
    # input through unchanged. Kept as-is to preserve the model; confirm
    # whether `activation='relu'` was intended.
    H = Conv2D(256, (5, 5), strides=(2, 2), padding='same',
               activation='relu')(d_input)
    H = LeakyReLU(0.2)(H)
    H = Dropout(dropout_rate)(H)

    H = Conv2D(512, (5, 5), strides=(2, 2), padding='same',
               activation='relu')(H)
    H = LeakyReLU(0.2)(H)
    H = Dropout(dropout_rate)(H)

    H = Flatten()(H)
    H = Dense(256)(H)
    H = LeakyReLU(0.2)(H)
    H = Dropout(dropout_rate)(H)

    d_V = Dense(2, activation='softmax')(H)

    discriminator = Model(d_input, d_V)
    discriminator.compile(loss='categorical_crossentropy', optimizer=dopt)
    discriminator.summary()
    return discriminator
def create_adversarial_model(generator, discriminator, opt):
    """Stack generator and discriminator into one compiled model.

    The stacked model takes a 100-dimensional noise vector, feeds it through
    ``generator`` and then through ``discriminator``.

    The discriminator and its layers are marked non-trainable while the
    stacked model is compiled, and their previous ``trainable`` flags are
    restored afterwards. Keras documents that a change to ``trainable`` only
    takes effect for a model once ``compile()`` is called on it, so freezing
    at compile time is what keeps generator updates made through this model
    from also moving the discriminator's weights.

    Args:
        generator: Generator model.
        discriminator: Discriminator model.
        opt: Keras optimizer used to compile the stacked model.

    Returns:
        The compiled stacked ``Model``.
    """
    # Model first, then its layers, so restoring the model-level flag
    # cannot overwrite an individual layer's restored flag.
    frozen = [discriminator] + list(discriminator.layers)
    previous_flags = [item.trainable for item in frozen]
    for item in frozen:
        item.trainable = False
    try:
        gan_input = Input(shape=[100])
        H = generator(gan_input)
        gan_V = discriminator(H)
        GAN = Model(gan_input, gan_V)
        GAN.compile(loss='categorical_crossentropy', optimizer=opt)
        GAN.summary()
    finally:
        # Leave the discriminator's flags exactly as the caller had them.
        for item, flag in zip(frozen, previous_flags):
            item.trainable = flag
    return GAN
def train_discriminator_1batch(discriminator, X_train, batch_size,
                               generator=None):
    """Run one discriminator update on a half-real, half-generated batch.

    ``batch_size`` real images are sampled (with replacement) from
    ``X_train`` and the same number of images are generated from uniform
    noise. Real images are labelled ``[0, 1]`` and generated ones ``[1, 0]``.

    Args:
        discriminator: Compiled discriminator model.
        X_train: Array of real images, indexed as ``X_train[i, :, :, :]``.
        batch_size: Number of real images (and of generated images).
        generator: Model whose ``predict`` turns noise into images. When
            omitted, the module-level name ``generator`` is used, which is
            what this function always did before the parameter existed.

    Returns:
        The value returned by ``discriminator.train_on_batch``.

    Raises:
        NameError: If ``generator`` is omitted and no module-level
            ``generator`` exists.
    """
    if generator is None:
        # Backward-compatible fallback to the implicit global.
        try:
            generator = globals()['generator']
        except KeyError:
            raise NameError("name 'generator' is not defined")

    real_idx = np.random.randint(0, X_train.shape[0], size=batch_size)
    image_batch = X_train[real_idx, :, :, :]

    noise_gen = np.random.uniform(0, 1, size=[batch_size, 100])
    generated_images = generator.predict(noise_gen)

    # Real images first, generated images second; labels follow that order.
    X = np.concatenate((image_batch, generated_images))
    y = np.zeros([2*batch_size, 2])
    y[:batch_size, 1] = 1
    y[batch_size:, 0] = 1

    make_trainable(discriminator, True)
    d_loss = discriminator.train_on_batch(X, y)
    return d_loss
def train_GAN_1batch(discriminator, GAN, X_train, batch_size):
    """Run one update of the stacked model on a fresh noise batch.

    Args:
        discriminator: Discriminator model; flagged non-trainable before
            the update.
        GAN: Compiled stacked generator+discriminator model.
        X_train: Not used by this function.
        batch_size: Number of noise vectors to draw.

    Returns:
        The value returned by ``GAN.train_on_batch``.
    """
    latent = np.random.uniform(0, 1, size=[batch_size, 100])

    # Every sample gets the one-hot target [0, 1].
    targets = np.zeros([batch_size, 2])
    targets[:, 1] = 1

    make_trainable(discriminator, False)
    return GAN.train_on_batch(latent, targets)
def plot_loss(losses, filename='loss.png'):
    """Plot both loss curves on one figure and save it to ``filename``.

    Args:
        losses: Mapping with keys ``"d"`` and ``"g"`` holding the loss
            sequences to plot.
        filename: Path the figure is saved to.
    """
    # (key in `losses`, legend label), drawn in this order.
    curves = (
        ("d", 'discriminitive loss'),
        ("g", 'generative loss'),
    )
    plt.figure(figsize=(10, 8))
    for key, label in curves:
        plt.plot(losses[key], label=label)
    plt.legend()
    plt.savefig(filename)
    # Close every open figure so repeated calls do not accumulate them.
    plt.close('all')
def plot_gen(noise, generator, filename='result.png'):
    """Generate images from ``noise`` and save them as a grid.

    Images are laid out on a square grid of at least 4x4 cells; the grid
    grows when more than 16 images are generated, so every image gets a
    cell.

    Args:
        noise: Input passed to ``generator.predict``; one row per image.
        generator: Model whose ``predict`` returns a batch of images with a
            trailing single-channel axis.
        filename: Path the figure is saved to.
    """
    generated_images = generator.predict(noise)
    n_images = generated_images.shape[0]

    # Smallest square grid, never below 4x4, that holds all images.
    side = 4
    while side * side < n_images:
        side += 1

    plt.figure(figsize=(10, 10))
    for i in range(n_images):
        plt.subplot(side, side, i+1)
        img = generated_images[i]
        # Drop the channel axis using the image's own height and width
        # rather than a hard-coded 28x28.
        img = img.reshape(img.shape[:2])
        plt.imshow(img, cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(filename)
    plt.close('all')
def train_GAN(generator, discriminator, GAN, X_train, batch_size=32, steps=50,
              pretrain_steps=5, pretrain_batch_size=10, plot_interval=25):
    """Alternate discriminator and generator updates, plotting periodically.

    The discriminator is first warmed up for ``pretrain_steps`` batches.
    Then, for each of ``steps`` iterations, one discriminator update is
    followed by one update of the stacked model. Every ``plot_interval``
    iterations the loss curves and a grid of generated images (from a fixed
    noise sample) are saved to disk.

    Args:
        generator: Generator model, used here for the sample image grid.
        discriminator: Compiled discriminator model.
        GAN: Compiled stacked generator+discriminator model.
        X_train: Array of real training images.
        batch_size: Batch size for the main loop.
        steps: Number of main-loop iterations.
        pretrain_steps: Discriminator-only warm-up updates before the loop.
        pretrain_batch_size: Batch size used during warm-up.
        plot_interval: Save plots every this many iterations; a falsy
            value disables plotting.

    Returns:
        Dict with keys ``"d"`` and ``"g"`` holding the per-iteration
        discriminator and generator losses.
    """
    # NOTE(review): train_discriminator_1batch is called without a generator
    # argument, so it samples from the module-level `generator`, which is
    # not necessarily the `generator` passed to this function.
    for _ in range(pretrain_steps):
        train_discriminator_1batch(discriminator, X_train,
                                   batch_size=pretrain_batch_size)

    losses = {"d": [], "g": []}
    # Fixed noise so successive image grids are comparable.
    noise_for_plot = np.random.uniform(0, 1, size=[16, 100])

    for i in tqdm(range(steps)):
        d_loss = train_discriminator_1batch(discriminator, X_train, batch_size)
        losses["d"].append(d_loss)

        g_loss = train_GAN_1batch(discriminator, GAN, X_train, batch_size)
        losses["g"].append(g_loss)

        if plot_interval and (i + 1) % plot_interval == 0:
            plot_loss(losses)
            plot_gen(noise_for_plot, generator, "result_%d.png" % i)

    return losses
if __name__ == '__main__':
    # Script entry point: train the GAN on the MNIST images labelled 1.
    img_rows, img_cols = 28, 28

    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    # Keep only the training images whose label is 1, add a channel axis
    # and scale pixel values to [0, 1] as float32.
    X_train = X_train[y_train == 1]
    X_train = X_train.reshape(
        (X_train.shape[0], img_rows, img_cols, 1)).astype('float32') / 255.0

    shp = X_train.shape[1:]
    dropout_rate = 0.25

    # The discriminator gets a larger learning rate than the generator
    # and the stacked model.
    dopt = Adam(lr=1e-3)
    opt = Adam(lr=1e-4)

    # `generator` must keep this module-level name:
    # train_discriminator_1batch looks it up as a global.
    generator = create_generator(opt)
    discriminator = create_discriminator(shp, dropout_rate, dopt)
    GAN = create_adversarial_model(generator, discriminator, opt)

    train_GAN(generator, discriminator, GAN, X_train)