使用 Keras 和 TensorFlow 构建辅助分类器 GAN(AC-GAN)
先决条件:生成对抗网络
本文将演示如何使用 Keras 和 TensorFlow 库构建辅助分类器生成对抗网络(AC-GAN)。使用的数据集是 Keras 中预加载的 MNIST 图像数据集。
第 1 步:设置环境
Step 1 : Open Anaconda prompt in Administrator mode.
Step 2 : Create a virtual environment using the command : conda create --name acgan python=3.7
Step 3 : Then, activate the environment using the command : conda activate acgan
Step 4 : Install the following libraries -
4.1 - Tensorflow --> pip install tensorflow==2.1
4.2 - Keras --> pip install keras==2.3.1
第 2 步:导入所需的库
Python3
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.models import Sequential, Model
from keras.layers.advanced_activations import LeakyReLU
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
Python3
# Defining the Input shape
# MNIST images: 28x28 pixels, one grayscale channel.
image_shape = (28, 28, 1)
# Number of digit classes (0-9).
classes = 10
# Dimensionality of the generator's noise vector.
latent_dim = 100
# Defining the optimizer and the losses
# Adam with lr=0.0002 and beta_1=0.5 — the settings popularized by the
# DCGAN paper for stable GAN training.
optimizer = Adam(0.0002, 0.5)
# Binary cross-entropy for the real/fake validity output and sparse
# categorical cross-entropy for the auxiliary class output.
losses = ['binary_crossentropy','sparse_categorical_crossentropy']
Python3
def build_generator():
    """Build the AC-GAN generator model.

    Takes a latent noise vector and an integer class label, embeds the
    label into the latent space, multiplies it element-wise with the
    noise, and upsamples the result through a convolutional stack into
    a 28x28x1 image with tanh output (values in [-1, 1]).
    """
    # Convolutional upsampling stack: 7x7 -> 14x14 -> 28x28.
    stack = [
        Dense(128 * 7 * 7, activation="relu", input_dim=latent_dim),
        Reshape((7, 7, 128)),
        BatchNormalization(momentum=0.82),
        UpSampling2D(),
        Conv2D(128, (3, 3), padding="same"),
        BatchNormalization(momentum=0.82),
        Activation("relu"),
        UpSampling2D(),
        Conv2D(64, (3, 3), padding="same"),
        BatchNormalization(momentum=0.82),
        Activation("relu"),
        Conv2D(1, (3, 3), padding="same"),
        Activation("tanh"),
    ]
    body = Sequential(stack)

    # Condition the noise vector on the requested digit class.
    noise_in = Input(shape=(latent_dim,))
    label_in = Input(shape=(1,), dtype='int32')
    label_embedding = Flatten()(Embedding(classes, latent_dim)(label_in))
    conditioned = multiply([noise_in, label_embedding])
    generated_image = body(conditioned)
    return Model([noise_in, label_in], generated_image)
Python3
def build_discriminator():
    """Build the AC-GAN discriminator model.

    Takes a 28x28x1 image and produces two outputs: a sigmoid validity
    score (real vs. fake) and a softmax distribution over the digit
    classes, both computed from a shared convolutional feature stack.
    """
    # Shared feature extractor: strided convolutions downsample the image.
    feature_stack = [
        Conv2D(16, (3, 3), strides=2, input_shape=image_shape, padding="same"),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        Conv2D(32, (3, 3), strides=2, padding="same"),
        ZeroPadding2D(padding=((0, 1), (0, 1))),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        BatchNormalization(momentum=0.8),
        Conv2D(64, (3, 3), strides=2, padding="same"),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        BatchNormalization(momentum=0.8),
        Conv2D(128, (3, 3), strides=1, padding="same"),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        Flatten(),
    ]
    extractor = Sequential(feature_stack)

    image_in = Input(shape=image_shape)
    features = extractor(image_in)
    # Two heads on the shared features: validity and class prediction.
    validity_out = Dense(1, activation="sigmoid")(features)
    class_out = Dense(classes, activation="softmax")(features)
    return Model(image_in, [validity_out, class_out])
Python3
def display_images():
    """Show a 10x10 grid of generator samples (digit class fixed per column).

    Uses the module-level `generator`; labels cycle 0-9 across each row.
    """
    rows, cols = 10, 10
    noise = np.random.normal(0, 1, (rows * cols, latent_dim))
    sampled_labels = np.array([digit for _ in range(rows) for digit in range(cols)])
    images = generator.predict([noise, sampled_labels])
    # Map the tanh output range [-1, 1] into [0, 1] for display.
    images = 0.5 * images + 0.5
    fig, axes = plt.subplots(rows, cols)
    for idx in range(rows * cols):
        ax = axes[idx // cols, idx % cols]
        ax.imshow(images[idx, :, :, 0], cmap='gray')
        ax.axis('off')
    plt.show()
    plt.close()
Python3
def train_acgan(epochs, batch_size=128, sample_interval=50):
    """Train the AC-GAN on MNIST.

    Args:
        epochs: number of training iterations (one minibatch each).
        batch_size: images per discriminator/generator update.
        sample_interval: show a grid of generated samples every this
            many iterations.

    Relies on the module-level `generator`, `discriminator` and
    `combined` models being built and compiled beforehand.
    """
    # Load the dataset (training split only; the test split is unused).
    (X, y), (_, _) = mnist.load_data()
    # Scale pixels from [0, 255] to [-1, 1] to match the generator's
    # tanh output, and add a trailing channel axis.
    X = X.astype(np.float32)
    X = (X - 127.5) / 127.5
    X = np.expand_dims(X, axis=3)
    y = y.reshape(-1, 1)
    # Adversarial ground truths: 1 = real, 0 = fake.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    for epoch in range(epochs):
        # Select a random batch of real images and their true labels.
        index = np.random.randint(0, X.shape[0], batch_size)
        images = X[index]
        image_labels = y[index]
        # Sample noise and target digit classes for the generator.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        new_labels = np.random.randint(0, 10, (batch_size, 1))
        gen_images = generator.predict([noise, new_labels])
        # Train the discriminator on real and fake batches separately.
        disc_loss_real = discriminator.train_on_batch(
            images, [valid, image_labels])
        disc_loss_fake = discriminator.train_on_batch(
            gen_images, [fake, new_labels])
        disc_loss = 0.5 * np.add(disc_loss_real, disc_loss_fake)
        # Train the generator (via the combined model) to make the
        # discriminator call its fakes valid with the requested class.
        # The return value was previously stored in an unused variable
        # (`gen_loss`); the call alone performs the weight update.
        combined.train_on_batch([noise, new_labels], [valid, new_labels])
        # Indices 3 and 4 are assumed to be the validity and class
        # accuracies from the compiled metrics — TODO confirm against
        # the metric ordering of the installed Keras version.
        print ("%d [acc.: %.2f%%, op_acc: %.2f%%]" % (
            epoch, 100 * disc_loss[3], 100 * disc_loss[4]))
        # display at every defined epoch interval
        if epoch % sample_interval == 0:
            display_images()
Python3
# Build and compile the discriminator.
# It trains with two losses: validity (binary cross-entropy) and
# digit class (sparse categorical cross-entropy) — see `losses`.
discriminator = build_discriminator()
discriminator.compile(loss=losses,
                      optimizer=optimizer,
                      metrics=['accuracy'])
# Build the generator
generator = build_generator()
# Defining the input for the generator
# and generating the images
noise = Input(shape=(latent_dim,))
label = Input(shape=(1,))
image = generator([noise, label])
# Disable the Discriminator
# For the combined model we will only train the generator.
# NOTE: `trainable = False` only affects models compiled afterwards
# (the `combined` model below); the already-compiled discriminator
# still updates its weights in its own train_on_batch calls.
discriminator.trainable = False
# The discriminator takes in the generated image
# as input and determines validity
# and the label of that image
valid, target_label = discriminator(image)
# The combined model (both generator and discriminator)
# Training the generator to fool the discriminator
combined = Model([noise, label], [valid, target_label])
combined.compile(loss=losses, optimizer=optimizer)
# Train for 14000 iterations, sampling images every 2000 iterations.
train_acgan(epochs=14000, batch_size=32, sample_interval=2000)
第 3 步:定义要在后续流程中使用的参数
Python3
# Defining the Input shape
# MNIST images: 28x28 pixels, one grayscale channel.
image_shape = (28, 28, 1)
# Number of digit classes (0-9).
classes = 10
# Dimensionality of the generator's noise vector.
latent_dim = 100
# Defining the optimizer and the losses
# Adam with lr=0.0002 and beta_1=0.5 — the settings popularized by the
# DCGAN paper for stable GAN training.
optimizer = Adam(0.0002, 0.5)
# Binary cross-entropy for the real/fake validity output and sparse
# categorical cross-entropy for the auxiliary class output.
losses = ['binary_crossentropy','sparse_categorical_crossentropy']
第 4 步:定义一个效用函数来构建生成器
Python3
def build_generator():
    """Build the AC-GAN generator model.

    Takes a latent noise vector and an integer class label, embeds the
    label into the latent space, multiplies it element-wise with the
    noise, and upsamples the result through a convolutional stack into
    a 28x28x1 image with tanh output (values in [-1, 1]).
    """
    # Convolutional upsampling stack: 7x7 -> 14x14 -> 28x28.
    stack = [
        Dense(128 * 7 * 7, activation="relu", input_dim=latent_dim),
        Reshape((7, 7, 128)),
        BatchNormalization(momentum=0.82),
        UpSampling2D(),
        Conv2D(128, (3, 3), padding="same"),
        BatchNormalization(momentum=0.82),
        Activation("relu"),
        UpSampling2D(),
        Conv2D(64, (3, 3), padding="same"),
        BatchNormalization(momentum=0.82),
        Activation("relu"),
        Conv2D(1, (3, 3), padding="same"),
        Activation("tanh"),
    ]
    body = Sequential(stack)

    # Condition the noise vector on the requested digit class.
    noise_in = Input(shape=(latent_dim,))
    label_in = Input(shape=(1,), dtype='int32')
    label_embedding = Flatten()(Embedding(classes, latent_dim)(label_in))
    conditioned = multiply([noise_in, label_embedding])
    generated_image = body(conditioned)
    return Model([noise_in, label_in], generated_image)
第 5 步:定义一个效用函数来构建鉴别器
Python3
def build_discriminator():
    """Build the AC-GAN discriminator model.

    Takes a 28x28x1 image and produces two outputs: a sigmoid validity
    score (real vs. fake) and a softmax distribution over the digit
    classes, both computed from a shared convolutional feature stack.
    """
    # Shared feature extractor: strided convolutions downsample the image.
    feature_stack = [
        Conv2D(16, (3, 3), strides=2, input_shape=image_shape, padding="same"),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        Conv2D(32, (3, 3), strides=2, padding="same"),
        ZeroPadding2D(padding=((0, 1), (0, 1))),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        BatchNormalization(momentum=0.8),
        Conv2D(64, (3, 3), strides=2, padding="same"),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        BatchNormalization(momentum=0.8),
        Conv2D(128, (3, 3), strides=1, padding="same"),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        Flatten(),
    ]
    extractor = Sequential(feature_stack)

    image_in = Input(shape=image_shape)
    features = extractor(image_in)
    # Two heads on the shared features: validity and class prediction.
    validity_out = Dense(1, activation="sigmoid")(features)
    class_out = Dense(classes, activation="softmax")(features)
    return Model(image_in, [validity_out, class_out])
第 6 步:定义一个效用函数来显示生成的图像
Python3
def display_images():
    """Show a 10x10 grid of generator samples (digit class fixed per column).

    Uses the module-level `generator`; labels cycle 0-9 across each row.
    """
    rows, cols = 10, 10
    noise = np.random.normal(0, 1, (rows * cols, latent_dim))
    sampled_labels = np.array([digit for _ in range(rows) for digit in range(cols)])
    images = generator.predict([noise, sampled_labels])
    # Map the tanh output range [-1, 1] into [0, 1] for display.
    images = 0.5 * images + 0.5
    fig, axes = plt.subplots(rows, cols)
    for idx in range(rows * cols):
        ax = axes[idx // cols, idx % cols]
        ax.imshow(images[idx, :, :, 0], cmap='gray')
        ax.axis('off')
    plt.show()
    plt.close()
第 7 步:构建和训练 AC-GAN
Python3
def train_acgan(epochs, batch_size=128, sample_interval=50):
    """Train the AC-GAN on MNIST.

    Args:
        epochs: number of training iterations (one minibatch each).
        batch_size: images per discriminator/generator update.
        sample_interval: show a grid of generated samples every this
            many iterations.

    Relies on the module-level `generator`, `discriminator` and
    `combined` models being built and compiled beforehand.
    """
    # Load the dataset (training split only; the test split is unused).
    (X, y), (_, _) = mnist.load_data()
    # Scale pixels from [0, 255] to [-1, 1] to match the generator's
    # tanh output, and add a trailing channel axis.
    X = X.astype(np.float32)
    X = (X - 127.5) / 127.5
    X = np.expand_dims(X, axis=3)
    y = y.reshape(-1, 1)
    # Adversarial ground truths: 1 = real, 0 = fake.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    for epoch in range(epochs):
        # Select a random batch of real images and their true labels.
        index = np.random.randint(0, X.shape[0], batch_size)
        images = X[index]
        image_labels = y[index]
        # Sample noise and target digit classes for the generator.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        new_labels = np.random.randint(0, 10, (batch_size, 1))
        gen_images = generator.predict([noise, new_labels])
        # Train the discriminator on real and fake batches separately.
        disc_loss_real = discriminator.train_on_batch(
            images, [valid, image_labels])
        disc_loss_fake = discriminator.train_on_batch(
            gen_images, [fake, new_labels])
        disc_loss = 0.5 * np.add(disc_loss_real, disc_loss_fake)
        # Train the generator (via the combined model) to make the
        # discriminator call its fakes valid with the requested class.
        # The return value was previously stored in an unused variable
        # (`gen_loss`); the call alone performs the weight update.
        combined.train_on_batch([noise, new_labels], [valid, new_labels])
        # Indices 3 and 4 are assumed to be the validity and class
        # accuracies from the compiled metrics — TODO confirm against
        # the metric ordering of the installed Keras version.
        print ("%d [acc.: %.2f%%, op_acc: %.2f%%]" % (
            epoch, 100 * disc_loss[3], 100 * disc_loss[4]))
        # display at every defined epoch interval
        if epoch % sample_interval == 0:
            display_images()
第 8 步:构建生成对抗网络
Python3
# Build and compile the discriminator.
# It trains with two losses: validity (binary cross-entropy) and
# digit class (sparse categorical cross-entropy) — see `losses`.
discriminator = build_discriminator()
discriminator.compile(loss=losses,
                      optimizer=optimizer,
                      metrics=['accuracy'])
# Build the generator
generator = build_generator()
# Defining the input for the generator
# and generating the images
noise = Input(shape=(latent_dim,))
label = Input(shape=(1,))
image = generator([noise, label])
# Disable the Discriminator
# For the combined model we will only train the generator.
# NOTE: `trainable = False` only affects models compiled afterwards
# (the `combined` model below); the already-compiled discriminator
# still updates its weights in its own train_on_batch calls.
discriminator.trainable = False
# The discriminator takes in the generated image
# as input and determines validity
# and the label of that image
valid, target_label = discriminator(image)
# The combined model (both generator and discriminator)
# Training the generator to fool the discriminator
combined = Model([noise, label], [valid, target_label])
combined.compile(loss=losses, optimizer=optimizer)
# Train for 14000 iterations, sampling images every 2000 iterations.
train_acgan(epochs=14000, batch_size=32, sample_interval=2000)
输出(每 2000 个纪元间隔):
纪元 0
纪元 2000
纪元 4000
纪元 6000
纪元 8000
纪元 10000
纪元 12000
纪元 14000
通过视觉观察生成图像的进展,可以得出结论:网络正在以可接受的水平工作。可以通过训练网络更长时间或调整网络参数来提高图像质量。如有任何问题或疑虑,请在下方评论。