Generative Adversarial Networks


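In this post we build a generative adversarial network (GAN) trained on the MNIST dataset and use it to generate new handwritten digits. The material is adapted from Udacity. GANs were first proposed by Ian Goodfellow and colleagues in 2014 (the original paper is worth reading), and they have become very popular since.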

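The idea behind a GAN is that two networks compete with each other: a generator G and a discriminator D. The generator produces fake data and passes it to the discriminator. The discriminator receives both real and fake data and has to decide which is which. The generator is trained to fool the discriminator, so it learns to produce data that the discriminator cannot tell apart from real data. The discriminator, in turn, keeps training to get better at separating real from fake. Ideally, by the end of training the discriminator can no longer distinguish the generator's output from real data.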

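GAN diagram

The diagram above shows the general structure of a GAN, with MNIST images as the data. The generator takes a random noise vector as input and turns it into an image, and it is trained to fool the discriminator. The discriminator outputs a sigmoid value between 0 and 1: values near 0 mean it judges the image to be fake, values near 1 mean it judges the image to be real.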

%matplotlib inline

import pickle as pkl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting MNIST_data\train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz

Model inputs

First we need to create two inputs: inputs_z for the generator and inputs_real for the discriminator.

def model_inputs(real_dim, z_dim):
    inputs_real = tf.placeholder(tf.float32, (None, real_dim), name='input_real')
    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')
    
    return inputs_real, inputs_z

Generator network

GAN Network

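So that the network can approximate a wide range of functions, we give it one hidden layer. The hidden layer uses a leaky ReLU activation, which keeps a small non-zero gradient for negative inputs so that gradients can still flow back through the layer.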

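Variable scope

We wrap each network in tf.variable_scope so that all generator variable names start with generator and all discriminator variable names start with discriminator. This makes it easy to train the two networks separately later on.

tf.name_scope could also be used to set names, but we additionally want to reuse these networks with different inputs. tf.variable_scope takes a reuse keyword that tells TensorFlow to reuse the existing variables when the graph is built again, rather than creating new ones.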

with tf.variable_scope('scope_name', reuse=False):
    pass  # build the layers that belong to this scope here

See the TensorFlow documentation for tf.variable_scope for more detail.
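
To make the reuse behaviour concrete, here is a minimal sketch in plain Python. ToyVariableStore and its get_variable method are hypothetical names made up for this illustration; this is an analogy for the behaviour, not TensorFlow's implementation, and in TensorFlow the reuse flag is set on the scope rather than on each call.

```python
class ToyVariableStore:
    """Toy stand-in for a variable store; NOT TensorFlow's implementation."""

    def __init__(self):
        self._vars = {}

    def get_variable(self, scope, name, reuse, init=0.0):
        key = scope + "/" + name
        if reuse:
            # Reusing requires that the variable was created earlier.
            if key not in self._vars:
                raise ValueError(key + " does not exist, cannot reuse it")
            return self._vars[key]
        # Creating requires that the name is still free.
        if key in self._vars:
            raise ValueError(key + " already exists; pass reuse=True to share it")
        self._vars[key] = [init]  # a one-element list stands in for a tensor
        return self._vars[key]


store = ToyVariableStore()
w_first = store.get_variable("discriminator", "kernel", reuse=False)
w_second = store.get_variable("discriminator", "kernel", reuse=True)
print(w_first is w_second)  # both names refer to the same stored object
```

This is why the discriminator can later be applied to two different inputs while keeping a single set of weights.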

Leaky ReLU

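The TensorFlow version used here does not provide a leaky ReLU operation directly (later releases added tf.nn.leaky_relu), so we build our own. For positive inputs a leaky ReLU returns x unchanged; for negative inputs it returns alpha * x. In other words, f(x) = max(alpha * x, x) for a small positive alpha.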
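
As a quick check of that definition, the following NumPy sketch applies the same max(x, alpha * x) trick to a few hand-picked values. The function name leaky_relu and the sample inputs are chosen for illustration only.

```python
import numpy as np


def leaky_relu(x, alpha=0.01):
    """Return x where x > 0 and alpha * x elsewhere (assumes 0 <= alpha <= 1)."""
    return np.maximum(x, alpha * x)


x = np.array([-2.0, -0.5, 0.0, 1.5])
y = leaky_relu(x)
print(y)  # negative inputs are scaled by alpha, positive inputs pass through
```

The TensorFlow code in this post uses the same expression, tf.maximum(h1, h1 * alpha), on the hidden layer.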

Tanh Output

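The generator has been found to perform best with tanh as its output activation. Since tanh outputs lie between -1 and 1, the real MNIST images must be rescaled from [0, 1] to [-1, 1] before they are passed to the discriminator.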
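
The rescaling itself is a one-line affine map. The sketch below shows it together with its inverse; the helper names to_tanh_range and from_tanh_range are made up for this example and do not appear in the original code, which applies batch_images*2 - 1 inline.

```python
import numpy as np


def to_tanh_range(images):
    """Map pixel values from [0, 1] to [-1, 1]."""
    return images * 2.0 - 1.0


def from_tanh_range(images):
    """Inverse mapping, from [-1, 1] back to [0, 1] (for example, for display)."""
    return (images + 1.0) / 2.0


pixels = np.array([0.0, 0.25, 0.5, 1.0])
scaled = to_tanh_range(pixels)
restored = from_tanh_range(scaled)
print(scaled)
```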

def generator(z, out_dim, n_units=128, reuse=False,  alpha=0.01):
    ''' Build the generator network.
    
        Arguments
        ---------
        z : Input tensor for the generator
        out_dim : Shape of the generator output
        n_units : Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU
        
        Returns
        -------
        out : Generator output after tanh, with values in [-1, 1]
    '''
    with tf.variable_scope('generator', reuse=reuse):
        # Hidden layer
        h1 = tf.layers.dense(z, n_units, activation=None)
        # Leaky ReLU
        h1 = tf.maximum(h1, h1 * alpha)
        
        # Logits and tanh output
        logits = tf.layers.dense(h1, out_dim, activation=None)
        out = tf.tanh(logits)
        
        return out

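Discriminator

The discriminator has almost the same architecture as the generator; only the output layer differs. Because it has to classify its input as real (1) or fake (0), it has a single output unit with a sigmoid activation.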

def discriminator(x, n_units=128, reuse=False, alpha=0.01):
    ''' Build the discriminator network.
    
        Arguments
        ---------
        x : Input tensor for the discriminator
        n_units: Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU
        
        Returns
        -------
        out, logits : Sigmoid output and the raw logits it was computed from
    '''
    with tf.variable_scope('discriminator', reuse=reuse):
        # Hidden layer
        h1 = tf.layers.dense(x, n_units, activation=None)
        # Leaky ReLU
        h1 = tf.maximum(h1, h1 * alpha)
        
        logits = tf.layers.dense(h1, 1, activation=None)
        out = tf.sigmoid(logits)
        
        return out, logits
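
To see what shapes flow through the discriminator, here is a rough NumPy sketch of the same forward pass. The weights are random stand-ins (in the post they are created and learned by tf.layers.dense), and the helper names dense, leaky_relu and sigmoid are defined here for illustration only.

```python
import numpy as np


def dense(x, weights, bias):
    """Fully connected layer without activation."""
    return x @ weights + bias


def leaky_relu(x, alpha=0.01):
    return np.maximum(x, alpha * x)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


rng = np.random.default_rng(0)
batch, input_size, n_units = 4, 784, 128

# Random stand-in weights, only to illustrate the shapes involved.
w1 = rng.normal(scale=0.05, size=(input_size, n_units))
b1 = np.zeros(n_units)
w2 = rng.normal(scale=0.05, size=(n_units, 1))
b2 = np.zeros(1)

x = rng.uniform(-1, 1, size=(batch, input_size))  # fake "images" in [-1, 1]
h1 = leaky_relu(dense(x, w1, b1))                 # hidden layer
logits = dense(h1, w2, b2)                        # one logit per image
out = sigmoid(logits)                             # probability-like score per image
print(logits.shape, out.shape)
```

Each input image yields a single logit and a single sigmoid score, which is what the loss functions later in the post consume.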

Hyperparameters

# Size of input image to discriminator
input_size = 784 # 28x28 MNIST images flattened
# Size of latent vector to generator
z_size = 100
# Sizes of hidden layers in generator and discriminator
g_hidden_size = 128
d_hidden_size = 128
# Leak factor for leaky ReLU
alpha = 0.01
# Label smoothing 
smooth = 0.1

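Building the network

Now we can build the network from the functions above. We get the inputs input_real and input_z from model_inputs, then create the generator with generator(input_z, input_size). Next comes the discriminator, which we build twice: once on the real data and once on the generated data. Both copies must share the same weights, so the second call passes reuse=True.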

tf.reset_default_graph()
# Create our input placeholders
input_real, input_z = model_inputs(input_size, z_size)

# Generator network here
g_model = generator(input_z, input_size, n_units=g_hidden_size, alpha=alpha)
# g_model is the generator output

# Discriminator network here
d_model_real, d_logits_real = discriminator(input_real, n_units=d_hidden_size, alpha=alpha)
d_model_fake, d_logits_fake = discriminator(g_model, reuse=True, n_units=d_hidden_size, alpha=alpha)

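Generator and discriminator losses

We now need to compute the losses. The discriminator loss has two sources, one from the real data and one from the fake data: d_loss = d_loss_real + d_loss_fake. Each part is a sigmoid cross-entropy averaged over the batch: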

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

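Here logits are the logits output by the discriminator. For real data we want the labels to be 1, and for fake data we want them to be 0. To help the discriminator generalize better, we smooth the real labels to slightly below 1: labels = tf.ones_like(tensor) * (1 - smooth)

For the generator loss, the logits are again the discriminator's logits on the fake data (d_logits_fake), since fake data is all the generator produces. This time, however, the labels are all 1, because the generator wants the discriminator to classify its fake images as real.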
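
The three loss terms can be reproduced by hand in NumPy. The sketch below uses the numerically stable form max(x, 0) - x * z + log(1 + exp(-|x|)) for the sigmoid cross-entropy; the function here is a local re-implementation written for illustration, and the logit values are made up rather than taken from a trained model.

```python
import numpy as np


def sigmoid_cross_entropy_with_logits(logits, labels):
    """Element-wise sigmoid cross-entropy, numerically stable form."""
    x = np.asarray(logits, dtype=float)
    z = np.asarray(labels, dtype=float)
    return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))


smooth = 0.1
d_logits_real = np.array([2.0, 0.5, -1.0])  # made-up logits on real images
d_logits_fake = np.array([-1.5, 0.3, 1.0])  # made-up logits on generated images

# Discriminator: real images labelled (1 - smooth), generated images labelled 0.
d_loss_real = sigmoid_cross_entropy_with_logits(
    d_logits_real, np.ones_like(d_logits_real) * (1 - smooth)).mean()
d_loss_fake = sigmoid_cross_entropy_with_logits(
    d_logits_fake, np.zeros_like(d_logits_fake)).mean()
d_loss = d_loss_real + d_loss_fake

# Generator: same fake logits, but with the labels flipped to 1.
g_loss = sigmoid_cross_entropy_with_logits(
    d_logits_fake, np.ones_like(d_logits_fake)).mean()

print(d_loss, g_loss)
```

Note that d_loss_fake and g_loss are computed from the same logits with opposite labels, which is exactly the sense in which the two networks compete.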

# Calculate losses

d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = d_logits_real,
                                                                     labels = tf.ones_like(d_logits_real) * (1 - smooth)))

d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = d_logits_fake,
                                                                     labels = tf.zeros_like(d_logits_fake)))

d_loss = d_loss_real + d_loss_fake

g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = d_logits_fake,
                                                                     labels = tf.ones_like(d_logits_fake)))

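Optimization

The generator and discriminator parameters have to be updated separately, with each optimizer reducing only its own network's loss. Because the two networks were built inside their own variable scopes, we can split the trainable variables by name prefix and pass each list to minimize through var_list.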

# Optimizers
learning_rate = 0.002

# Get the trainable_variables, split into G and D parts
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if var.name.startswith('generator')]
d_vars = [var for var in t_vars if var.name.startswith('discriminator')]

d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
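
The prefix filter works because every variable created inside a scope carries the scope name at the start of its name. The sketch below applies the same list comprehension to plain strings; the names are written in the style TensorFlow 1.x typically gives to tf.layers.dense variables, but the exact strings are an assumption for illustration.

```python
# Illustrative variable names (assumed, not read from a real graph).
names = [
    "generator/dense/kernel:0",
    "generator/dense/bias:0",
    "generator/dense_1/kernel:0",
    "generator/dense_1/bias:0",
    "discriminator/dense/kernel:0",
    "discriminator/dense/bias:0",
    "discriminator/dense_1/kernel:0",
    "discriminator/dense_1/bias:0",
]

# Same filtering logic as in the post, applied to the names directly.
g_names = [n for n in names if n.startswith("generator")]
d_names = [n for n in names if n.startswith("discriminator")]

print(len(g_names), len(d_names))
```

Every trainable variable ends up in exactly one of the two lists, so each optimizer only touches its own network.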

Training

batch_size = 100
epochs = 100
samples = []
losses = []
# Only save generator variables
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for ii in range(mnist.train.num_examples//batch_size):
            batch = mnist.train.next_batch(batch_size)
            
            # Get images, reshape and rescale to pass to D
            batch_images = batch[0].reshape((batch_size, 784))
            batch_images = batch_images*2 - 1 #scale to -1,1
            
            # Sample random noise for G
            batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size))
            
            # Run optimizers
            _ = sess.run(d_train_opt, feed_dict={input_real: batch_images, input_z: batch_z})
            _ = sess.run(g_train_opt, feed_dict={input_z: batch_z})
        
        # At the end of each epoch, get the losses and print them out
        train_loss_d = sess.run(d_loss, {input_z: batch_z, input_real: batch_images})
        train_loss_g = g_loss.eval({input_z: batch_z})
            
        print("Epoch {}/{}...".format(e+1, epochs),
              "Discriminator Loss: {:.4f}...".format(train_loss_d),
              "Generator Loss: {:.4f}".format(train_loss_g))    
        # Save losses to view after training
        losses.append((train_loss_d, train_loss_g))
        
        # Sample from generator as we're training for viewing afterwards
        sample_z = np.random.uniform(-1, 1, size=(16, z_size))
        # Reuse the generator output already in the graph instead of building a new copy each epoch
        gen_samples = sess.run(g_model, feed_dict={input_z: sample_z})
        samples.append(gen_samples)
        saver.save(sess, './checkpoints/generator.ckpt')

# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)
Epoch 1/100... Discriminator Loss: 0.3665... Generator Loss: 3.5270
Epoch 2/100... Discriminator Loss: 0.3890... Generator Loss: 3.8497
Epoch 3/100... Discriminator Loss: 0.3980... Generator Loss: 4.3148
Epoch 4/100... Discriminator Loss: 2.5614... Generator Loss: 3.9421
Epoch 5/100... Discriminator Loss: 0.6711... Generator Loss: 4.3231
Epoch 6/100... Discriminator Loss: 1.6751... Generator Loss: 1.2985
Epoch 7/100... Discriminator Loss: 0.9973... Generator Loss: 1.7754
Epoch 8/100... Discriminator Loss: 2.1265... Generator Loss: 1.9723
Epoch 9/100... Discriminator Loss: 1.1451... Generator Loss: 1.9348
Epoch 10/100... Discriminator Loss: 1.7930... Generator Loss: 3.1384
Epoch 11/100... Discriminator Loss: 1.7409... Generator Loss: 1.8317
Epoch 12/100... Discriminator Loss: 2.7053... Generator Loss: 1.4636
Epoch 13/100... Discriminator Loss: 1.0681... Generator Loss: 3.5390
Epoch 14/100... Discriminator Loss: 1.6029... Generator Loss: 1.2599
Epoch 15/100... Discriminator Loss: 1.2468... Generator Loss: 1.6186
Epoch 16/100... Discriminator Loss: 1.0815... Generator Loss: 1.8498
Epoch 17/100... Discriminator Loss: 0.9844... Generator Loss: 1.2992
Epoch 18/100... Discriminator Loss: 0.9230... Generator Loss: 2.1901
Epoch 19/100... Discriminator Loss: 0.8258... Generator Loss: 3.2655
Epoch 20/100... Discriminator Loss: 0.8584... Generator Loss: 2.1915
Epoch 21/100... Discriminator Loss: 0.6526... Generator Loss: 2.5069
Epoch 22/100... Discriminator Loss: 0.8377... Generator Loss: 2.4397
Epoch 23/100... Discriminator Loss: 0.7968... Generator Loss: 2.3466
Epoch 24/100... Discriminator Loss: 0.9963... Generator Loss: 2.2792
Epoch 25/100... Discriminator Loss: 1.1942... Generator Loss: 1.7112
Epoch 26/100... Discriminator Loss: 0.8386... Generator Loss: 1.6728
Epoch 27/100... Discriminator Loss: 0.9020... Generator Loss: 2.4640
Epoch 28/100... Discriminator Loss: 0.9796... Generator Loss: 2.1753
Epoch 29/100... Discriminator Loss: 0.9199... Generator Loss: 2.5527
Epoch 30/100... Discriminator Loss: 1.0509... Generator Loss: 1.8442
Epoch 31/100... Discriminator Loss: 0.7627... Generator Loss: 3.0345
Epoch 32/100... Discriminator Loss: 1.1164... Generator Loss: 1.5489
Epoch 33/100... Discriminator Loss: 0.8950... Generator Loss: 2.3809
Epoch 34/100... Discriminator Loss: 1.1739... Generator Loss: 2.0350
Epoch 35/100... Discriminator Loss: 0.8261... Generator Loss: 2.5511
Epoch 36/100... Discriminator Loss: 1.1356... Generator Loss: 1.9123
Epoch 37/100... Discriminator Loss: 0.7850... Generator Loss: 2.0336
Epoch 38/100... Discriminator Loss: 1.3770... Generator Loss: 1.8081
Epoch 39/100... Discriminator Loss: 1.3912... Generator Loss: 1.3591
Epoch 40/100... Discriminator Loss: 1.1284... Generator Loss: 1.4507
Epoch 41/100... Discriminator Loss: 0.8900... Generator Loss: 1.8905
Epoch 42/100... Discriminator Loss: 0.9838... Generator Loss: 1.8828
Epoch 43/100... Discriminator Loss: 0.9963... Generator Loss: 1.9200
Epoch 44/100... Discriminator Loss: 1.4808... Generator Loss: 1.5501
Epoch 45/100... Discriminator Loss: 0.9672... Generator Loss: 1.8741
Epoch 46/100... Discriminator Loss: 1.0026... Generator Loss: 1.8940
Epoch 47/100... Discriminator Loss: 1.0286... Generator Loss: 1.5216
Epoch 48/100... Discriminator Loss: 1.2535... Generator Loss: 1.4766
Epoch 49/100... Discriminator Loss: 1.0187... Generator Loss: 1.3388
Epoch 50/100... Discriminator Loss: 1.0499... Generator Loss: 1.6826
Epoch 51/100... Discriminator Loss: 1.0148... Generator Loss: 1.9221
Epoch 52/100... Discriminator Loss: 1.1113... Generator Loss: 1.9696
Epoch 53/100... Discriminator Loss: 0.9164... Generator Loss: 1.9217
Epoch 54/100... Discriminator Loss: 1.0550... Generator Loss: 1.7317
Epoch 55/100... Discriminator Loss: 1.3500... Generator Loss: 1.1638
Epoch 56/100... Discriminator Loss: 0.9911... Generator Loss: 1.5880
Epoch 57/100... Discriminator Loss: 0.8094... Generator Loss: 2.0780
Epoch 58/100... Discriminator Loss: 0.9349... Generator Loss: 1.9257
Epoch 59/100... Discriminator Loss: 0.8955... Generator Loss: 2.5756
Epoch 60/100... Discriminator Loss: 0.8442... Generator Loss: 2.3068
Epoch 61/100... Discriminator Loss: 0.8275... Generator Loss: 2.0744
Epoch 62/100... Discriminator Loss: 0.9792... Generator Loss: 1.4697
Epoch 63/100... Discriminator Loss: 1.1776... Generator Loss: 1.3126
Epoch 64/100... Discriminator Loss: 0.8305... Generator Loss: 2.4137
Epoch 65/100... Discriminator Loss: 0.9313... Generator Loss: 2.0773
Epoch 66/100... Discriminator Loss: 0.8766... Generator Loss: 3.3415
Epoch 67/100... Discriminator Loss: 0.8594... Generator Loss: 2.1336
Epoch 68/100... Discriminator Loss: 1.1959... Generator Loss: 1.6913
Epoch 69/100... Discriminator Loss: 1.0146... Generator Loss: 2.2956
Epoch 70/100... Discriminator Loss: 0.9191... Generator Loss: 2.0312
Epoch 71/100... Discriminator Loss: 0.9018... Generator Loss: 2.0220
Epoch 72/100... Discriminator Loss: 1.1151... Generator Loss: 1.4409
Epoch 73/100... Discriminator Loss: 1.0630... Generator Loss: 1.5294
Epoch 74/100... Discriminator Loss: 0.9483... Generator Loss: 1.8939
Epoch 75/100... Discriminator Loss: 1.0796... Generator Loss: 1.4881
Epoch 76/100... Discriminator Loss: 1.0368... Generator Loss: 1.5860
Epoch 77/100... Discriminator Loss: 0.9743... Generator Loss: 1.7110
Epoch 78/100... Discriminator Loss: 0.9401... Generator Loss: 1.9215
Epoch 79/100... Discriminator Loss: 1.1534... Generator Loss: 1.6998
Epoch 80/100... Discriminator Loss: 0.9709... Generator Loss: 1.7046
Epoch 81/100... Discriminator Loss: 0.8844... Generator Loss: 1.7949
Epoch 82/100... Discriminator Loss: 1.0132... Generator Loss: 2.1153
Epoch 83/100... Discriminator Loss: 0.9493... Generator Loss: 1.6879
Epoch 84/100... Discriminator Loss: 0.9492... Generator Loss: 1.5327
Epoch 85/100... Discriminator Loss: 1.1477... Generator Loss: 1.4398
Epoch 86/100... Discriminator Loss: 0.9071... Generator Loss: 1.9815
Epoch 87/100... Discriminator Loss: 1.1913... Generator Loss: 1.4262
Epoch 88/100... Discriminator Loss: 1.0941... Generator Loss: 1.8383
Epoch 89/100... Discriminator Loss: 0.8030... Generator Loss: 1.9573
Epoch 90/100... Discriminator Loss: 1.0562... Generator Loss: 2.1593
Epoch 91/100... Discriminator Loss: 1.0917... Generator Loss: 1.8404
Epoch 92/100... Discriminator Loss: 1.0762... Generator Loss: 1.3095
Epoch 93/100... Discriminator Loss: 0.9516... Generator Loss: 1.8261
Epoch 94/100... Discriminator Loss: 0.9504... Generator Loss: 2.1033
Epoch 95/100... Discriminator Loss: 1.0374... Generator Loss: 1.8829
Epoch 96/100... Discriminator Loss: 0.8951... Generator Loss: 1.9297
Epoch 97/100... Discriminator Loss: 1.1448... Generator Loss: 2.1090
Epoch 98/100... Discriminator Loss: 0.9450... Generator Loss: 1.6973
Epoch 99/100... Discriminator Loss: 1.1092... Generator Loss: 1.5661
Epoch 100/100... Discriminator Loss: 0.9818... Generator Loss: 1.9050

Training loss

Here we look at the generator and discriminator losses over the course of training.

%matplotlib inline

import matplotlib.pyplot as plt
fig, ax = plt.subplots()
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator')
plt.plot(losses.T[1], label='Generator')
plt.title("Training Losses")
plt.legend()

[Figure: training loss curves for the discriminator and generator]

Viewing samples from training

def view_samples(epoch, samples):
    fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch]):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')
    
    return fig, axes

# Load samples from generator taken while training
with open('train_samples.pkl', 'rb') as f:
    samples = pkl.load(f)
_ = view_samples(-1, samples)

[Figure: 4x4 grid of generator samples from the final epoch]

The grid below shows samples the generator produced over the course of training, one row for every 10th epoch.

rows, cols = 10, 6
fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)

for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
    for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

[Figure: 10x6 grid of generator samples across training]

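At first the generator produces only noise; as training goes on the samples improve and start to look like digits.

Sampling from the generator

Using the checkpoint saved from the generator we just trained, we can feed in fresh random vectors and see what digits it produces.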
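
It may help to spell out which samples end up in the grid above. The sketch below replays the two slicing steps from the plotting code on placeholder data: each "image" is replaced by an (epoch index, sample index) pair so the selection can be inspected. The variable names epoch_step, sample_step and grid are introduced here for illustration.

```python
rows, cols = 10, 6
epochs, samples_per_epoch = 100, 16

# Stand-ins for the saved samples: (epoch index, sample index) pairs instead of images.
samples = [[(e, i) for i in range(samples_per_epoch)] for e in range(epochs)]

epoch_step = int(len(samples) / rows)      # 100 / 10 gives a step of 10 epochs
sample_step = int(len(samples[0]) / cols)  # 16 / 6 is truncated to a step of 2

grid = []
for sample, _row in zip(samples[::epoch_step], range(rows)):
    picked = sample[::sample_step]  # 8 candidates per epoch
    # zip stops at the shorter sequence, so only the first `cols` are kept
    grid.append([img for img, _col in zip(picked, range(cols))])

shown_epochs = [row[0][0] for row in grid]
shown_indices = [i for _e, i in grid[0]]
print(shown_epochs)
print(shown_indices)
```

So each row comes from every 10th saved epoch (indices 0, 10, ..., 90), and each row shows six of that epoch's 16 samples.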

saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    sample_z = np.random.uniform(-1, 1, size=(16, z_size))
    gen_samples = sess.run(
                   generator(input_z, input_size, n_units=g_hidden_size, reuse=True, alpha=alpha),
                   feed_dict={input_z: sample_z})
view_samples(0, [gen_samples])

[Figure: 4x4 grid of digits sampled from the restored generator]