使用 Keras 构建生成对抗网络


本文演示如何使用 Keras 库构建生成对抗网络(GAN)。所用数据集是 CIFAR-10 图像数据集,可通过 Keras 内置的 `keras.datasets.cifar10` 模块直接下载并加载。有关该数据集的详细信息,请参阅 CIFAR-10 官方主页(https://www.cs.toronto.edu/~kriz/cifar.html)。

第 1 步:导入所需的库

import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam,SGD

第 2 步:加载数据

# Load the CIFAR-10 training split; the test split is discarded.
(X, y), (_, _) = keras.datasets.cifar10.load_data()
# Keep only the images of a single class so the GAN learns one kind of object.
# CIFAR-10 labels are the integers 0-9 (not 1-10); any of them can be chosen.
# Index 8 is "ship" in the official CIFAR-10 class list.
selected_class = 8
X = X[y.flatten() == selected_class]

第 3 步:定义要在后续流程中使用的参数

# Every CIFAR-10 image is 32x32 pixels with three (RGB) colour channels.
image_shape = 32, 32, 3
# Length of the random noise vector that the generator turns into an image.
latent_dimensions = 100

第 4 步:定义一个实用函数来构建生成器

def build_generator():
    """Build the generator network.

    The generator maps a noise vector of length ``latent_dimensions`` to an
    image of shape (32, 32, 3).

    Returns:
        A Keras ``Model`` whose input has shape ``(latent_dimensions,)``.
        Its output has shape ``(32, 32, 3)``, with pixel values in [-1, 1]
        from the final ``tanh``. This matches the [-1, 1] scaling applied to
        the training images.
    """
    model = Sequential()
    # Project the noise vector to an 8x8x128 feature map. ``input_dim`` tells
    # the first layer how long the noise vector is.
    model.add(Dense(128 * 8 * 8, activation="relu",
                    input_dim=latent_dimensions))
    model.add(Reshape((8, 8, 128)))

    # Upsample 8x8 -> 16x16.
    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=3, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Activation("relu"))

    # Upsample 16x16 -> 32x32 so the output matches the discriminator's
    # ``image_shape``.
    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=3, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Activation("relu"))

    # Three output channels (RGB). tanh bounds pixel values to [-1, 1].
    model.add(Conv2D(3, kernel_size=3, padding="same"))
    model.add(Activation("tanh"))

    # Wrap the Sequential stack in a functional Model with an explicit input.
    noise = Input(shape=(latent_dimensions,))
    image = model(noise)
    return Model(noise, image)

第 5 步:定义一个实用函数来构建鉴别器

def build_discriminator():
    """Build the discriminator network.

    The discriminator scores how likely an image is to be real rather than
    generated.

    Returns:
        A Keras ``Model`` that takes images of shape ``image_shape`` and
        returns one sigmoid probability per image, with shape ``(batch, 1)``.
    """
    model = Sequential()
    # Each convolution is followed by LeakyReLU; without an activation, the
    # stacked convolutions would reduce to a single linear map. Dropout
    # regularizes the discriminator.
    # Strided convolutions downsample 32x32 -> 16x16 -> 8x8 -> 4x4.
    model.add(Conv2D(32, kernel_size=3, strides=2,
                     input_shape=image_shape, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    # Flatten the spatial feature map so that Dense produces a single score
    # per image. Without this step Dense acts per spatial position, and the
    # output shape no longer matches the (batch_size, 1) training labels.
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))

    image = Input(shape=image_shape)
    validity = model(image)
    return Model(image, validity)

第 6 步:定义实用程序函数以显示生成的图像

def display_images():
    """Plot a 4x4 grid of images sampled from the current generator.

    Reads the module-level ``generator`` and ``latent_dimensions``. The
    figure is shown and then closed, so repeated calls during training do
    not accumulate open figures.
    """
    r, c = 4, 4
    noise = np.random.normal(0, 1, (r * c, latent_dimensions))
    generated_images = generator.predict(noise)
    # The training data is scaled to [-1, 1]. Map it back to [0, 1] for
    # imshow, and clip any values that fall slightly outside that range.
    generated_images = np.clip(0.5 * generated_images + 0.5, 0.0, 1.0)
    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, matching the original i/j loop.
    for count, ax in enumerate(axs.flat):
        ax.imshow(generated_images[count])
        ax.axis('off')
    plt.show()
    plt.close(fig)

第 7 步:构建生成对抗网络

# Build and compile the discriminator. Keras fixes a model's trainable state
# when it is compiled, so compiling here (before the flag is turned off below)
# keeps the discriminator's own train_on_batch calls updating its weights.
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(0.0002, 0.5),  # learning rate, beta_1
                      metrics=['accuracy'])
# Freeze the discriminator before compiling the combined model. Training the
# combined network then updates only the generator's weights.
discriminator.trainable = False
# Build the generator.
generator = build_generator()
# Chain generator -> discriminator: noise goes in, a realness score comes out.
z = Input(shape=(latent_dimensions,))
image = generator(z)
valid = discriminator(image)
# The combined model trains the generator to make the frozen discriminator
# classify generated images as real.
combined_network = Model(z, valid)
combined_network.compile(loss='binary_crossentropy',
                         optimizer=Adam(0.0002, 0.5))

第 8 步:训练网络

# Training hyper-parameters. Each "epoch" below is a single mini-batch update
# of the discriminator and the generator, not a full pass over X.
num_epochs = 15000
batch_size = 32
display_interval = 2500

# Scale pixel values from [0, 255] to [-1, 1].
X = (X / 127.5) - 1.
# Adversarial ground truths with a little label noise. Real labels are drawn
# from [0.95, 1.0] and fake labels from [0.0, 0.05], so every target stays a
# valid probability for binary cross-entropy. Adding noise to the real labels
# would push them above 1.0.
valid = 1.0 - 0.05 * np.random.random((batch_size, 1))
fake = 0.05 * np.random.random((batch_size, 1))
for epoch in range(num_epochs):
    # --- Train the discriminator ---
    # Sample a random mini-batch of batch_size real images (with replacement).
    index = np.random.randint(0, X.shape[0], batch_size)
    images = X[index]
    # Sample noise and generate a batch of fake images.
    noise = np.random.normal(0, 1, (batch_size, latent_dimensions))
    generated_images = generator.predict(noise)
    # Teach the discriminator to tell real images from generated ones.
    discm_loss_real = discriminator.train_on_batch(images, valid)
    discm_loss_fake = discriminator.train_on_batch(generated_images, fake)
    discm_loss = 0.5 * np.add(discm_loss_real, discm_loss_fake)
    # --- Train the generator ---
    # Through the combined model (discriminator frozen), push the generator
    # to produce images the discriminator labels as real.
    genr_loss = combined_network.train_on_batch(noise, valid)
    # Periodically report the losses and show sample images.
    if epoch % display_interval == 0:
        print(f"Epoch {epoch}: D loss {discm_loss}, G loss {genr_loss}")
        display_images()

第 0 轮(epoch 0)生成的图像:

第 2500 轮(epoch 2500)生成的图像:




第 12500 轮(epoch 12500)生成的图像:


第 9 步:评估性能

a) 绘制部分原始训练图像

# Plot 40 of the original training images in a 5x8 grid.
# X was scaled to [-1, 1] for training; 0.5 * x + 0.5 maps it back to [0, 1]
# for imshow.
s = X[:40]
s = 0.5 * s + 0.5
f, ax = plt.subplots(5, 8, figsize=(16, 10))
for i, image in enumerate(s):
    ax[i // 8, i % 8].imshow(image)
    ax[i // 8, i % 8].axis('off')

b) 绘制训练结束后生成器生成的图像

# Sample 40 fresh noise vectors and plot, in a 5x8 grid, the images the
# trained generator produces for them.
noise = np.random.normal(size=(40, latent_dimensions))
# 0.5 * x + 0.5 maps [-1, 1] onto [0, 1] for imshow.
generated_images = 0.5 * generator.predict(noise) + 0.5
f, ax = plt.subplots(5, 8, figsize=(16, 10))
for i, image in enumerate(generated_images):
    # divmod(i, 8) gives the (row, column) of the i-th cell.
    ax[divmod(i, 8)].imshow(image)
    ax[divmod(i, 8)].axis('off')
