Generative Adversarial Network

I wrestled with this for a while and got some results, but nowhere near the great performance for which GANs are famous. Ultimately, GANs are absurdly sensitive to hyperparameters and initialization, and unless you imitate the published settings exactly, you are unlikely to get good results.

I figure I should post my attempt -- maybe the community can find a set of parameters that works. Mine sort of trained, but it suffered from mode collapse and often converged on a blob. For training sets consisting of a single digit it did seem to work okay, although that is a much easier task and not really one where GANs have any advantage.

I tried to implement the Wasserstein GAN (paper: arXiv:1701.07875) to generate MNIST digits. The prescribed training procedure is to update the discriminator on 5 batches for every 1 batch you show the generator. Because Mathematica doesn't yet allow optimizer state to be preserved between calls to NetTrain, I couldn't get that to work. Instead, I trained the networks jointly, as suggested by Taliesin Beynon, setting the learning-rate multiplier on the generator to -1/5, because that seemed like a plausible approximation.

The paper also used RMSProp as the optimizer. Mathematica has an RMSProp option, but on the net defined here it immediately diverged no matter what learning rate I chose, so I used ADAM instead.

To begin, let's get a big batch of MNIST digits.

mnist = ResourceData["MNIST"]; (* a list of rules: image -> digit label *)
mnistDigits = First /@ mnist; (* keep just the images *)

Let's give the generator a 10-dimensional noise input, and define the generator and discriminator. Notice that the discriminator has no activation on its output -- this is specific to the WGAN; a normal GAN would end with a LogisticSigmoid or similar.

randomDim = 10;
generator = 
 NetChain[{128, Ramp, 128, Ramp, 28*28, LogisticSigmoid, 
   ReshapeLayer[{1, 28, 28}]}, "Input" -> randomDim]
discriminator = 
 NetChain[{128, Ramp, 128, Ramp, 128, Ramp, 1}, 
  "Input" -> {1, 28, 28}]

Now the tricky part. We'll feed noise into the generator to produce a fake image, and also accept a real image as input. We want to apply the discriminator to both images with a single set of weights, so we concatenate them and use NetMapOperator. The loss should then drive the two scores apart, so we negate one score and add the two; NetTrain minimizes this quantity for the discriminator, while the generator's negative learning-rate multiplier makes it push in the opposite direction.

wganNet = 
 NetInitialize[
  NetGraph[<|
    "gen" -> generator,
    (* one shared set of discriminator weights, mapped over both images *)
    "discrimop" -> NetMapOperator[discriminator],
    "cat" -> CatenateLayer[],
    "reshape" -> ReshapeLayer[{2, 1, 28, 28}],
    "flat" -> FlattenLayer[],
    "total" -> SummationLayer[],
    (* fixed +/- signs combine the two scores into one scalar loss *)
    "scale" -> ConstantTimesLayer["Scaling" -> {-1, 1}]|>,
   {NetPort["random"] -> "gen" -> "cat", NetPort["Input"] -> "cat",
    "cat" -> "reshape" -> "discrimop" -> "flat" -> "scale" -> "total"},
   "Input" -> {1, 28, 28}]]

One of the strengths of Mathematica's neural networks framework is that it's really easy to watch networks train. We'll pass the trainer a progress function built from 4 fixed random inputs; at each update it shows the generator's output on those inputs, so we can watch the generator evolve over time.

ClearAll[progressFuncCreator]
(* Given a list of fixed noise vectors, return a callback for
   TrainingProgressReporting: it extracts the current generator from the
   partially trained net (#Net) and renders one image per noise vector. *)
progressFuncCreator[rands_List] := 
 Function[{rand}, 
    ImageResize[
     NetDecoder[{"Image", "Grayscale"}][
      NetExtract[#Net, "gen"][rand]], 50]] /@ rands &
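To preview what the trainer will display, you can call the resulting function by hand, passing an association with a "Net" key the way NetTrain does:

progressFuncCreator[Table[RandomReal[{-1, 1}, {randomDim}], 2]][<|"Net" -> wganNet|>]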

Finally, create the training data:

trainingData = <|"random" -> RandomReal[{-1, 1}, {randomDim}], 
     "Input" -> ArrayReshape[ImageData[#], {1, 28, 28}]|> & /@ 
   mnistDigits;
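Each example pairs a length-10 noise vector with a 1x28x28 image array:

Dimensions /@ trainingData[[1]]
(* <|"random" -> {10}, "Input" -> {1, 28, 28}|> *)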

And train, watching the generator make a bunch of vaguely number-shaped blobs. Notice the "WeightClipping" option on the discriminator -- this is the "secret sauce" in Wasserstein GANs that makes them learn an approximation of the Wasserstein/Earth-Mover's distance, as opposed to the Jensen-Shannon divergence, as explained in the paper.
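For context, the dual form from the paper: clipping the critic's weights crudely constrains it to be approximately Lipschitz, which is what justifies reading the difference of its average scores as an estimate of the Earth-Mover distance via the Kantorovich-Rubinstein duality:

$$ W(\mathbb{P}_r, \mathbb{P}_g) = \sup_{\|f\|_L \le 1} \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g}[f(x)] $$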

NetTrain[wganNet, trainingData, "Output", (* treat the net's output port as the loss *)
 Method -> {"ADAM", "Beta1" -> 0.5, "LearningRate" -> 0.00005, 
   "WeightClipping" -> {"discrimop" -> 0.01}}, (* the WGAN Lipschitz trick *)
 TrainingProgressReporting -> 
  progressFuncCreator[Table[RandomReal[{-1, 1}, {randomDim}], 4]], 
 (* freeze the fixed sign layer; the negative multiplier makes the generator ascend the loss *)
 LearningRateMultipliers -> {"scale" -> 0, "gen" -> -0.2}, 
 TargetDevice -> "GPU", BatchSize -> 64]
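If you capture NetTrain's result (say, trained = NetTrain[...]), you can sample new digits afterwards by extracting the generator and decoding its output:

trainedGen = NetExtract[trained, "gen"];
NetDecoder[{"Image", "Grayscale"}][trainedGen[RandomReal[{-1, 1}, randomDim]]]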

Overall, my impression of the neural networks framework is very good. It's extremely flexible, coherently designed, and also extremely pretty. Crucially, it's easier to watch your net train than in any other framework. However, due to difficulties with staged training and saving optimizer state, it's not yet possible to replicate (in the sense of replicating a scientific experiment) some published results, like GANs, whose training procedures are unusual.


Yes, it is possible. You can do alternating training manually by literally following the algorithm, with a Do loop whose body contains two calls to NetTrain, but that suffers from overhead at each alternation (this could be overcome with clever caching, which we haven't done yet). An approximation of this is to build a single network and optimize the D and G losses simultaneously by using a negative learning rate for the generator.
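The manual alternation would look something like the following sketch, reusing the wganNet from the question (untested as written; note that each NetTrain call rebuilds the optimizer state, which is exactly the overhead mentioned):

net = wganNet;
Do[
 (* critic phase: freeze the generator, train on 5 batches' worth of data *)
 net = NetTrain[net, RandomSample[trainingData, 5*64], "Output",
   LearningRateMultipliers -> {"gen" -> 0, "scale" -> 0},
   MaxTrainingRounds -> 1, BatchSize -> 64];
 (* generator phase: freeze the critic; the negative multiplier ascends the loss *)
 net = NetTrain[net, RandomSample[trainingData, 64], "Output",
   LearningRateMultipliers -> {"discrimop" -> 0, "scale" -> 0, "gen" -> -1},
   MaxTrainingRounds -> 1, BatchSize -> 64],
 {100}]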

I have prototyped this, but only on a toy example.

I encourage you to try it yourself -- it didn't take us more than a few hours of playing around to make a simple GAN in which the data distribution is a Gaussian, the discriminator is an MLP, and the generator is a single EmbeddingLayer (just a fixed set of samples that can be moved around by gradient updates).
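As a rough sketch of that toy setup (the sizes and target distribution here are arbitrary choices, and this mirrors the wganNet construction from the question rather than our actual prototype):

n = 64; (* number of movable 1D samples held by the generator *)
toyNet = NetInitialize@NetGraph[<|
    "gen" -> EmbeddingLayer[1, n], (* maps an index to one movable sample *)
    "discrimop" -> NetMapOperator[NetChain[{16, Tanh, 16, Tanh, 1}, "Input" -> 1]],
    "cat" -> CatenateLayer[], "reshape" -> ReshapeLayer[{2, 1}],
    "flat" -> FlattenLayer[], "total" -> SummationLayer[],
    "scale" -> ConstantTimesLayer["Scaling" -> {-1, 1}]|>,
   {NetPort["index"] -> "gen" -> "cat", NetPort["Input"] -> "cat",
    "cat" -> "reshape" -> "discrimop" -> "flat" -> "scale" -> "total"},
   "Input" -> 1];
toyData = <|"index" -> RandomInteger[{1, n}],
     "Input" -> RandomVariate[NormalDistribution[2, 0.5], 1]|> & /@ Range[5000];
toyTrained = NetTrain[toyNet, toyData, "Output",
   Method -> {"ADAM", "WeightClipping" -> {"discrimop" -> 0.01}},
   LearningRateMultipliers -> {"scale" -> 0, "gen" -> -0.2}];
(* the embedding rows are the learned samples; they should cluster near the data *)
Histogram[Flatten@Normal@NetExtract[toyTrained, {"gen", "Weights"}]]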


I largely see the generation of images from classification models as a trivial idea, because processes and algorithms of this kind are standard in natural language processing.

More to the point of the question, the following posts show generation of images with classification-derived bases:

  • an answer to "Code that generates a mandala"

  • "Comparison of dimension reduction algorithms over mandala images generation"

In some sense, when we use SVD or NNMF bases to recognize an image, we reconstruct it by appropriately overlaying basis images. Obviously, such overlaying can be done without any recognition goal, purely to generate new images.
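For example, here is a quick sketch of that idea with SVD (NNMF works similarly, but SVD is built in): derive a basis from images of "5" and overlay random mixtures of the leading basis images:

mnist = ResourceData["MNIST"];
fives = Keys@Select[mnist, Last[#] == 5 &][[;; 500]]; (* 500 images of "5" *)
mat = Flatten[ImageData[#]] & /@ fives;               (* 500 x 784 pixel matrix *)
{u, s, v} = SingularValueDecomposition[mat, 20];      (* keep 20 basis vectors *)
(* each random mixture of basis images yields a new, vaguely "5"-shaped image *)
Table[ImageAdjust@Image@ArrayReshape[RandomReal[{0, 1}, 20].Transpose[v], {28, 28}], 6]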

Update, 2017-06-24

Looking at Michael Curry's answer and running his code (and Taliesin Beynon's code), I see the neural-network approach as something of a long route. The MNIST-based images those codes generate can be produced in a much quicker and more controllable way using SVD or NNMF.

As an example, examine these basis images of handwritten "5"s obtained with NNMF:

[Image: NNMF basis images for the digit "5"]

Bases of this kind were used to generate the (re)constructed handwritten digit images in this MSE answer:

[Image: handwritten digits reconstructed from the NNMF basis]

(The linked answer describes the generation procedure.)