Using Moments to Stabilize Generative Adversarial Network (GAN) Learning
Generative Adversarial Networks (GANs), introduced in 2014 by Ian Goodfellow, are an extremely promising method for producing fake data that is indistinguishable from the real thing by pitting two neural networks against one another. According to OpenAI, GANs currently produce the sharpest generated images when compared with the other popular approaches, Variational Autoencoders and autoregressive models. This benefit, however, comes at a cost: GANs are difficult to optimize because of unstable training dynamics (Karpathy, June 16, 2016). A GAN also contains two neural networks whose training must be kept well synchronized; otherwise the generator can collapse around a single successful instance (a generated sample that fools the discriminator) instead of approximating the true distribution of the real dataset (Goodfellow, June 10, 2014). Finally, GANs can be very sensitive to the initial values of the weights and may fail to train; batch normalization is recommended to help overcome this issue (Udacity, May 5, 2017). However, there may be another way. What if we adjusted the generator's loss function to penalize it whenever it doesn't produce a distribution similar to that of the real data?
Dependencies:
Python 3.5.1
TensorFlow 1.0.1
NumPy
Matplotlib
pickle
pandas
Tests:
In all tests, we optimize with AdamOptimizer using its default learning rate (0.001). The generator builds our “fake MNIST dataset” from a 100-dimensional noise input drawn from a uniform distribution between -1 and 1. Throughout testing we use a batch size of 100 and train for 100 epochs. The noise for each batch is sampled with NumPy as follows:
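```python
import numpy as np
z_size = 100      # dimensionality of the generator's noise input
batch_size = 100  # samples per training batch
# one batch of noise vectors, each entry drawn uniformly from [-1, 1)
batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size))
```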
Baseline Test: One Hidden Layer Generator without Batch Normalization
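We use a generator with one hidden layer of 128 units (n_units) and a leaky ReLU activation function, which is designed to fix the “dying ReLU problem”: a standard ReLU outputs zero, and passes back zero gradient, for every negative input, so a unit stuck in that regime can stop learning, while a leaky ReLU keeps a small slope for negative inputs (more information can be found at http://cs231n.github.io/neural-networks-1/).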
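To make the architecture concrete, here is a minimal NumPy sketch of a generator forward pass with leaky ReLU hidden layers. It is not the author's TensorFlow code: the slope `alpha=0.01`, the 784-unit tanh output (a flattened 28×28 MNIST image assumed to be scaled to [-1, 1]), the Gaussian weight initialization, and the helper names `leaky_relu`, `init_generator`, and `generator_forward` are all assumptions for illustration. `hidden_sizes=[128]` corresponds to this baseline, and `[128, 384]` to the two hidden layer generator tested next.

```python
import numpy as np

def leaky_relu(x, alpha=0.01):
    # Keep a small slope (alpha) for negative inputs so their gradient is never exactly zero
    return np.maximum(alpha * x, x)

def init_generator(hidden_sizes, z_size=100, out_size=784, seed=0):
    # Hypothetical initialization: small Gaussian weights, zero biases
    rng = np.random.RandomState(seed)
    sizes = [z_size] + list(hidden_sizes) + [out_size]
    return [(rng.normal(0.0, 0.02, size=(n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def generator_forward(params, z, alpha=0.01):
    h = z
    # Hidden layers: affine transform followed by leaky ReLU
    for w, b in params[:-1]:
        h = leaky_relu(h @ w + b, alpha)
    # Output layer: tanh squashes each pixel into (-1, 1)
    w, b = params[-1]
    return np.tanh(h @ w + b)

batch_z = np.random.uniform(-1, 1, size=(100, 100))
baseline = init_generator([128])        # one hidden layer of 128 units
two_layer = init_generator([128, 384])  # two hidden layers: 128 and 384 units
print(generator_forward(baseline, batch_z).shape)   # (100, 784)
print(generator_forward(two_layer, batch_z).shape)  # (100, 784)
```

Making the hidden sizes a list keeps every test's generator in one function; only the list changes between experiments.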
Results from the baseline GAN (samples from every 10th epoch; all other visualizations follow the same format):
Discriminator Loss vs Generator Loss (Baseline)
We can see that the generator learns to produce numeric-like figures; however, they don’t look like real handwritten digits. We could probably get better results from a deeper network, but there’s a problem: as mentioned earlier, without a technique such as batch normalization, the deeper network won't train well.
Test #2: Two Hidden Layer Generator without Batch Normalization
We use a generator with two hidden layers: the first has 128 units and the second has 384 units. After 100 epochs, the model starts to produce figures that somewhat resemble numerals, but the results are clearly worse than those of the single hidden layer model. Better results might be attained by training past 100 epochs, but we train every model for the same number of epochs to keep the comparison consistent.
Results from Two Hidden Layer GAN (no batch normalization)
Discriminator Loss vs Generator Loss (Two Hidden Layers)
Test #3: Two Hidden Layer Generator with Batch Normalization
We now introduce batch normalization to try to get better results from the deeper network. The only difference between this test and the previous one is the addition of batch normalization on each of the hidden layers. Batch normalization clearly improves the generator’s performance over the plain two-layer generator, and the model seems to learn how to create numeric-like figures, but they are nothing to write home about. Before turning to the moments (mean and variance) of the real data, we’ll try one more test: extending the network to four hidden layers.
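Batch normalization, as applied to each hidden layer here, standardizes every unit using the mean and variance of the current batch and then applies a learned scale and shift. The NumPy sketch below shows only the training-time computation; the names `batch_norm_train`, `gamma`, `beta`, and `eps=1e-5` are illustrative assumptions, it assumes normalization is applied to a layer's pre-activations, and the running averages used at inference time are omitted.

```python
import numpy as np

def batch_norm_train(h, gamma, beta, eps=1e-5):
    # Per-unit statistics over the batch dimension (axis 0)
    mean = h.mean(axis=0)
    var = h.var(axis=0)
    # Standardize each unit, then apply the learned scale (gamma) and shift (beta)
    h_hat = (h - mean) / np.sqrt(var + eps)
    return gamma * h_hat + beta

rng = np.random.RandomState(0)
h = rng.normal(5.0, 3.0, size=(100, 128))  # pre-activations of a 128-unit hidden layer
out = batch_norm_train(h, gamma=np.ones(128), beta=np.zeros(128))
print(out.mean(axis=0)[:3])  # approximately 0 for every unit
print(out.std(axis=0)[:3])   # approximately 1 for every unit
```

With `gamma=1` and `beta=0` every unit ends up with zero mean and unit variance over the batch; training then lets the network learn whatever scale and offset suit each layer.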
Results from Two Hidden Layer GAN (with Batch Normalization)
Discriminator Loss vs Generator Loss (Two Hidden Layers/Batch Normalization)
Test #4: Four Hidden Layer Generator with Batch Normalization
The four hidden layer generator with batch normalization does not perform well. It seems to be getting better at producing numeric-like figures by epoch 100, but again, we train every model for the same number of epochs to keep the comparison consistent.
Results from Four Hidden Layer GAN (with Batch Normalization)
Discriminator Loss vs Generator Loss (Four Hidden Layers/Batch Normalization)
A Solution? Stabilizing with Moments from the Real Data’s Distribution
Indeed, GANs seem to be hard to train. GitHub user Soumith Chintala has compiled a set of “hacks” to help train them, taken from NIPS 2016:
https://github.com/soumith/ganhacks. However, there may be another way. If the goal of the generator is to reproduce the distribution of the real data, why not add a term to the loss function that penalizes the generator for collapsing or for not conforming to the real data’s distribution? A generator that collapses onto a few samples produces batches with too little variance, so comparing its batch statistics against the real data’s should expose it. With TensorFlow’s moments function (tf.nn.moments), we can measure the difference between the generator’s mean and variance and the real data’s mean and variance for each feature, batch by batch. We can do so as follows:
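```python
import tensorflow as tf  # written against TensorFlow 1.x
# Per-feature (per-pixel) mean and variance over the batch axis; both tensors are (batch_size, 784)
g_mean, g_var = tf.nn.moments(g_model, axes=[0])     # g_model: generator output
d_mean, d_var = tf.nn.moments(input_real, axes=[0])  # input_real: batch of real images
mean_diff = 0.1 * tf.reduce_sum(tf.abs(g_mean - d_mean))
# note: despite its name, std_diff compares variances, not standard deviations
std_diff = 0.1 * tf.reduce_sum(tf.abs(g_var - d_var))
```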
We scale by 0.1 to keep mean_diff and std_diff comparable in magnitude to the generator loss; we don’t want these terms to be so much larger than the original adversarial loss that the model effectively “ignores” it.
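For readers who want to experiment outside TensorFlow 1.x, here is a NumPy sketch of the same penalty. Taking `mean` and `var` over axis 0 mirrors what `tf.nn.moments(x, axes=[0])` returns (both give the population variance, dividing by the batch size). The function name `moment_penalty` and its `scale` argument are illustrative, with `scale=0.1` matching the factor above; the uniform "real" batch is a stand-in, not MNIST.

```python
import numpy as np

def moment_penalty(fake, real, scale=0.1):
    # Per-feature mean and population variance over the batch, like tf.nn.moments(x, axes=[0])
    g_mean, g_var = fake.mean(axis=0), fake.var(axis=0)
    d_mean, d_var = real.mean(axis=0), real.var(axis=0)
    mean_diff = scale * np.sum(np.abs(g_mean - d_mean))
    var_diff = scale * np.sum(np.abs(g_var - d_var))  # called std_diff in the TensorFlow code
    return mean_diff, var_diff

rng = np.random.RandomState(0)
real = rng.uniform(-1, 1, size=(100, 784))     # stand-in for a batch of real images
collapsed = np.tile(real[:1], (100, 1))        # generator stuck repeating one sample
diverse = rng.uniform(-1, 1, size=(100, 784))  # generator matching the data's spread
print(moment_penalty(collapsed, real))
print(moment_penalty(diverse, real))
```

The collapsed batch has zero variance in every feature, so it receives a much larger penalty than the diverse batch, which is exactly the failure mode the extra term is meant to discourage.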
The generator loss goes from:
```python
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)))
```
to:
```python
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.ones_like(d_logits_fake))) + (std_diff + mean_diff)
```
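The combined objective can also be sketched in NumPy. The cross-entropy below uses the numerically stable form documented for `tf.nn.sigmoid_cross_entropy_with_logits`, max(x, 0) - x*z + log(1 + exp(-|x|)); with every label z set to 1 (the generator wants its fakes classified as real) it reduces to log(1 + exp(-x)). The function names and the sample logits are hypothetical stand-ins for the tensors in the TensorFlow graph.

```python
import numpy as np

def sigmoid_xent_with_logits(logits, labels):
    # Numerically stable sigmoid cross-entropy: max(x, 0) - x*z + log(1 + exp(-|x|))
    x, z = logits, labels
    return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))

def generator_loss(d_logits_fake, mean_diff=0.0, std_diff=0.0):
    # Adversarial term: the generator wants the discriminator to label its fakes as real (1)
    adv = np.mean(sigmoid_xent_with_logits(d_logits_fake, np.ones_like(d_logits_fake)))
    # Moment-stabilized loss: add the scaled mean and variance penalties
    return adv + (std_diff + mean_diff)

d_logits_fake = np.array([-2.0, 0.0, 3.0])  # hypothetical discriminator logits on fakes
print(generator_loss(d_logits_fake))                                # original loss
print(generator_loss(d_logits_fake, mean_diff=0.5, std_diff=0.25))  # original loss + 0.75
```

Because the penalties are simply added, their gradients flow only into the generator's output statistics; the adversarial term is unchanged.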
Test #5: Two Hidden Layer Generator with Moment Stabilization
As we can see, the generator now produces much better results. It is clearly producing numeric-like figures, much crisper and clearer than with batch normalization.
Results from Two Hidden Layer GAN (with Moment Stabilization)
Discriminator Loss vs Generator Loss (Two Hidden Layers/Moment Stabilization)
Test #6: Four Hidden Layer Generator with Moment Stabilization
By stabilizing with the moments (mean and variance) of the real data, we are now able to successfully train deep generator networks.
Results from Four Hidden Layer GAN (with Moment Stabilization)
Discriminator Loss vs Generator Loss (Four Hidden Layers/Moment Stabilization)
Test #7: Four Hidden Layer Generator with Moment Stabilization, Dropout & Batch Normalization
As a final test, we see that moment stabilization still pays off when combined with other regularization methods. This model is trained with dropout in the first layer and with batch normalization.
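Dropout randomly zeroes units during training. A common formulation, inverted dropout, rescales the surviving units by 1/keep_prob so the expected activation is unchanged and nothing needs rescaling at inference time; TensorFlow 1.x's `tf.nn.dropout(x, keep_prob)` behaves this way. The NumPy sketch below illustrates it; `keep_prob=0.5` and the helper name `dropout` are assumptions, since the keep probability used in this test isn't stated.

```python
import numpy as np

def dropout(h, keep_prob=0.5, rng=None, training=True):
    # At inference time dropout is a no-op
    if not training:
        return h
    rng = np.random if rng is None else rng
    # Keep each unit with probability keep_prob, zero it otherwise
    mask = rng.uniform(size=h.shape) < keep_prob
    # Scale survivors by 1/keep_prob so the expected activation is unchanged
    return h * mask / keep_prob

h = np.ones((100, 128))          # e.g. first hidden layer activations
out = dropout(h, keep_prob=0.5)
print((out == 0).mean())         # fraction of dropped units, close to 0.5
print(out.mean())                # close to 1.0, the mean of h
```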
Results from Four Hidden Layer GAN (with Moment Stabilization, Dropout & Batch Normalization)
Discriminator Loss vs Generator Loss (Four Hidden Layers/Moment Stabilization/Dropout/Batch Normalization)
Benefits and Concerns (initial):
1. The model seems to learn faster and more effectively with "moment stabilization".
2. How will the method transfer to datasets that are clustered (multimodal) and don't lend themselves well to being described by a single mean and variance per feature?
3. The method adds computation, since the mean and variance are calculated for each dimension of the dataset on every batch (it adds no trainable parameters, and its cost grows linearly with the number of features). This doesn't pose much of an issue for datasets such as MNIST, with only 784 features, but what about datasets with 10,000, 1,000,000, or more features?
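Concern 2 can be made concrete with a tiny example: two very different one-feature batches can share the same mean and variance, in which case the moment penalty cannot tell them apart. The values below are hand-picked for illustration. Because the penalty is computed feature by feature, it also ignores correlations between features (for example, between neighboring pixels).

```python
import numpy as np

# Two one-feature "batches" of four samples each
bimodal = np.array([-1.0, 1.0, -1.0, 1.0])                  # two clusters at -1 and +1
spread = np.array([-np.sqrt(2.0), 0.0, 0.0, np.sqrt(2.0)])  # mass at 0 plus two tails

for name, x in [("bimodal", bimodal), ("spread", spread)]:
    # Both have mean 0 and population variance 1 (up to floating-point rounding)
    print(name, x.mean(), x.var())

# Per-feature moment differences, as in the penalty (scale factor omitted)
mean_gap = abs(bimodal.mean() - spread.mean())
var_gap = abs(bimodal.var() - spread.var())
print(mean_gap, var_gap)  # both essentially zero
```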
Next Steps:
1. Try the method on another dataset
2. Try the method with convolutional networks