📜  生成对抗网络(GAN)

📅  最后修改于: 2021-04-16 08:17:49             🧑  作者: Mango

生成对抗网络(GAN)是一类功能强大的神经网络,可用于无监督学习。它是由Ian J. Goodfellow在2014年开发和引入的。GAN基本上是由两个相互竞争的神经网络模型组成的系统,彼此竞争,并且能够分析,捕获和复制数据集中的变化。

为什么首先开发GAN?
已经注意到,大多数主流神经网络都可以通过仅向原始数据中添加少量噪声而轻易地将其分类为错误的事物。出人意料的是,添加噪声后的模型比正确预测的模型对错误预测的置信度更高。造成这种敌意的原因是,大多数机器学习模型都是从有限的数据中学习的,这是一个巨大的缺点,因为它很容易过度拟合。同样,输入和输出之间的映射几乎是线性的。尽管看起来各个类之间的分隔边界似乎是线性的,但实际上,它们是由线性组成的,即使特征空间中某个点的微小变化也可能导致数据分类错误。

GAN如何工作?

生成对抗网络(GAN)可以分为三个部分:

  • 生成式:学习生成式模型,该模型描述了如何根据概率模型生成数据。
  • 对抗性:模型的训练是在对抗性环境中进行的。
  • 网络:使用深度神经网络作为用于训练目的的人工智能(AI)算法。

在GAN中,有一个生成器和一个鉴别器。生成器生成伪造的数据样本(例如图像,音频等),并试图欺骗鉴别器。另一方面,鉴别器试图区分真实样本和假样本。生成器和鉴别器都是神经网络,在训练阶段它们都相互竞争。这些步骤被重复了几次,在这种情况下,每次重复后,生成器和鉴别器在各自的工作中会变得越来越好。可以通过以下图表可视化工作:
gansgfg

在此,生成模型捕获数据的分布,并以某种方式对其进行训练,使其试图使鉴别器犯错的可能性最大化。另一方面,鉴别器基于一个模型,该模型估计从训练数据而不是从生成器接收到的样本的概率。
GAN被定义为一个最小极大游戏,其中鉴别者试图使奖励V(D,G)最小化,而生成者试图使鉴别者的奖励最小化,换句话说,使损失最大化。可以通过以下公式在数学上进行描述:

甘草

在哪里,
G =发电机
D =鉴别符
Pdata(x)=实际数据的分布
P(z)=发电机的分布
x =来自Pdata(x)的样本
z =来自P(z)的样本
D(x)=鉴别器网络
G(z)=发电机网络

因此,基本上,训练GAN包含两个部分:

  • 第1部分:鉴别器在发电机空闲时接受训练。在此阶段,仅对网络进行正向传播,而不会进行反向传播。鉴别器针对n个历元的真实数据进行了训练,看看它是否可以正确地将其预测为真实数。同样,在此阶段,还对鉴别器进行了有关生成器生成的伪造数据的培训,以查看其是否可以正确地将其预测为伪造。
  • 第2部分:鉴别器空闲时训练生成器。在使用生成器生成的伪造数据对鉴别器进行训练之后,我们可以获取其预测,并使用结果训练生成器,并从先前的状态中获得更好的结果来欺骗鉴别器。

    重复上述方法几个纪元,然后手动检查虚假数据,看似真实。如果看起来可以接受,则停止训练,否则,允许训练再继续几个纪元。

    GAN的不同类型:
    GAN现在是一个非常活跃的研究主题,并且有许多不同类型的GAN实施。下面介绍了当前正在积极使用的一些重要功能:

    1. Vanilla GAN:这是最简单的GAN类型。在这里,生成器和鉴别器是简单的多层感知器。在普通GAN中,该算法非常简单,它尝试使用随机梯度下降来优化数学方程。
    2. 条件GAN(CGAN): CGAN可以描述为一种深度学习方法,其中将一些条件参数放在适当的位置。在CGAN中,附加参数y被添加到Generator中以生成相应的数据。标签也被放入鉴别器的输入中,以使鉴别器帮助区分真实数据和伪造数据。
    3. 深度卷积GAN(DCGAN): DCGAN是GAN中最受欢迎,最成功的实现之一。它由ConvNets代替多层感知器组成。 ConvNets是在没有最大池化的情况下实现的,实际上已被卷积跨度所取代。此外,各层未完全连接。
    4. 拉普拉斯金字塔GAN(LAPGAN):拉普拉斯金字塔是一种线性可逆图像表示,由一组以八度为间隔的带通图像以及一个低频残差组成。这种方法使用多个生成器和鉴别器网络以及不同级别的拉普拉斯金字塔。主要使用此方法,因为它会产生非常高质量的图像。首先在金字塔的每一层对图像进行降采样,然后在向后遍历中在每一层再次进行放大,在此过程中,图像会从这些层的条件GAN中获取一些噪声,直到达到其原始大小为止。
    5. 超分辨率GAN(SRGAN):顾名思义,SRGAN是一种设计GAN的方法,其中将深层神经网络与对抗网络一起使用,以生成更高分辨率的图像。这种类型的GAN在以最佳方式放大原始低分辨率图像时特别有用,以增强其细节,从而最大程度地减少错误。

    实现通用对抗网络的示例Python代码:
    GAN在计算上非常昂贵。他们需要高性能的GPU和大量时间(大量时间)才能产生良好的结果。对于我们的示例,我们将使用著名的MNIST数据集并将其用于生成随机数字的克隆。

    # importing the necessary libraries and the MNIST dataset
    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    from tensorflow.examples.tutorials.mnist import input_data
      
    mnist = input_data.read_data_sets("MNIST_data")
      
    # defining functions for the two networks.
    # Both the networks have two hidden layers
    # and an output layer which are densely or 
    # fully connected layers defining the 
    # Generator network function
    def generator(z, reuse = None):
        with tf.variable_scope('gen', reuse = reuse):
            hidden1 = tf.layers.dense(inputs = z, units = 128, 
                                activation = tf.nn.leaky_relu)
                                  
            hidden2 = tf.layers.dense(inputs = hidden1,
               units = 128, activation = tf.nn.leaky_relu)
                 
            output = tf.layers.dense(inputs = hidden2, 
                 units = 784, activation = tf.nn.tanh)
              
            return output
      
    # defining the Discriminator network function 
    def discriminator(X, reuse = None):
        with tf.variable_scope('dis', reuse = reuse):
            hidden1 = tf.layers.dense(inputs = X, units = 128,
                                activation = tf.nn.leaky_relu)
                                  
            hidden2 = tf.layers.dense(inputs = hidden1,
                   units = 128, activation = tf.nn.leaky_relu)
                     
            logits = tf.layers.dense(hidden2, units = 1)
            output = tf.sigmoid(logits)
              
            return output, logits
      
    # creating placeholders for the outputs
    tf.reset_default_graph()
      
    real_images = tf.placeholder(tf.float32, shape =[None, 784])
    z = tf.placeholder(tf.float32, shape =[None, 100])
      
    G = generator(z)
    D_output_real, D_logits_real = discriminator(real_images)
    D_output_fake, D_logits_fake = discriminator(G, reuse = True)
      
    # defining the loss function
    def loss_func(logits_in, labels_in):
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                              logits = logits_in, labels = labels_in))
      
     # Smoothing for generalization
    D_real_loss = loss_func(D_logits_real, tf.ones_like(D_logits_real)*0.9)
    D_fake_loss = loss_func(D_logits_fake, tf.zeros_like(D_logits_real))
    D_loss = D_real_loss + D_fake_loss
      
    G_loss = loss_func(D_logits_fake, tf.ones_like(D_logits_fake))
      
    # defining the learning rate, batch size,
    # number of epochs and using the Adam optimizer
    lr = 0.001 # learning rate
      
    # Do this when multiple networks
    # interact with each other
      
    # returns all variables created(the two
    # variable scopes) and makes trainable true
    tvars = tf.trainable_variables() 
    d_vars =[var for var in tvars if 'dis' in var.name]
    g_vars =[var for var in tvars if 'gen' in var.name]
      
    D_trainer = tf.train.AdamOptimizer(lr).minimize(D_loss, var_list = d_vars)
    G_trainer = tf.train.AdamOptimizer(lr).minimize(G_loss, var_list = g_vars)
      
    batch_size = 100 # batch size
    epochs = 500 # number of epochs. The higher the better the result
    init = tf.global_variables_initializer()
      
    # creating a session to train the networks
    samples =[] # generator examples
      
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(epochs):
            num_batches = mnist.train.num_examples//batch_size
              
            for i in range(num_batches):
                batch = mnist.train.next_batch(batch_size)
                batch_images = batch[0].reshape((batch_size, 784))
                batch_images = batch_images * 2-1
                batch_z = np.random.uniform(-1, 1, size =(batch_size, 100))
                _= sess.run(D_trainer, feed_dict ={real_images:batch_images, z:batch_z})
                _= sess.run(G_trainer, feed_dict ={z:batch_z})
                  
            print("on epoch{}".format(epoch))
              
            sample_z = np.random.uniform(-1, 1, size =(1, 100))
            gen_sample = sess.run(generator(z, reuse = True),
                                     feed_dict ={z:sample_z})
              
            samples.append(gen_sample)
      
    # result after 0th epoch
    plt.imshow(samples[0].reshape(28, 28))
      
    # result after 499th epoch
    plt.imshow(samples[49].reshape(28, 28))
    

    输出:

    on epoch0
    on epoch1
    ...
    ...
    ...
    on epoch498
    on epoch499
    

    第0个时期后的结果:
    epoch_zero
    第499个时期后的恢复:
    时代

    因此,从上面的示例中,我们可以看到在第0个历元之后的第一个图像中,像素散布在整个位置上,我们无法从中找出任何东西。
    但是从第二张图像中,我们可以看到像素被系统地组织了起来,并且可以发现代码随机选择的是数字“ 7”,并且网络已尝试对其进行克隆。在我们的示例中,我们以500个为纪元。但是您可以增加该数字以进一步优化结果。