Generative Adversarial Networks

We give a brief introduction to GANs and code up a simple GAN as described in the original paper by Ian Goodfellow.

In 2014, Ian Goodfellow (Director of Machine Learning at Apple at the time of writing) and his co-authors published a paper called "Generative Adversarial Nets", or GAN for short. It describes a system of two neural networks, a Generator and a Discriminator, that learns to generate images, or any other data similar to a provided dataset, from essentially random noise.

The diagram below shows the basic architecture of a GAN.


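The Generator takes random noise as input and turns it into a fake sample, which at the start of training still looks like noise. This output is passed to the Discriminator along with real images (or other data) that serve as ground truth, and the Discriminator tries to tell the two apart. Both networks are trained on the result: the Discriminator learns to separate real from fake, and the Generator learns to fool it.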
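
For reference, this two-player game is usually written as the minimax objective from the original paper. Here $D(x)$ is the Discriminator's estimate of the probability that $x$ is real, $G(z)$ is the Generator's output for a noise vector $z$, and $p_{\text{data}}$ and $p_z$ are the data and noise distributions:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

The Discriminator tries to make this value large by scoring real data near 1 and generated data near 0, while the Generator tries to make it small by pushing $D(G(z))$ towards 1.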

As you can see, the concept of a GAN is very simple. When I started learning about GANs, I thought I could easily implement one. But boy, was I wrong. In reality, training a GAN is extremely hard: the Generator and Discriminator must be trained side by side, and if one overpowers the other, training won't work. We will talk about the problems you might face while training a GAN. But for now, let's look at a simple GAN in PyTorch. We will be using the MNIST dataset for this post.

Understanding the Dataset

The MNIST dataset is a large database of handwritten digits from 0 to 9, commonly used for Optical Character Recognition, that is, reading numbers from an image.

This dataset consists of 28x28 grayscale images of handwritten digits, where each pixel holds an intensity value from 0 to 255. Torchvision's ToTensor transform scales these values to the range 0 to 1.

The computer vision extension of PyTorch, Torchvision, provides this dataset which we can download using the following code snippet.

mnist = datasets.MNIST(root='datasets', train=True, transform=transformations, download=True)

But before we can run this code, we need to import some libraries.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

You can install the torchvision library using this pip command.

pip install torchvision

The DataLoader class loads the data into memory in batches, which prevents your system from running out of memory while training.
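
To see what batching looks like in practice, here is a minimal sketch that uses made-up random tensors instead of MNIST (the `images`, `labels`, and `batch_sizes` names are illustrative only, not part of this post's code):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 100 fake "images" of 784 values with dummy labels.
images = torch.randn(100, 784)
labels = torch.zeros(100, dtype=torch.long)
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Collect the size of every batch the loader yields in one pass over the data.
batch_sizes = [batch.shape[0] for batch, _ in loader]
print(batch_sizes)
```

Because 100 is not a multiple of 32, the last batch is smaller than the rest. This is worth keeping in mind whenever code assumes a fixed batch size.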

The transforms module can apply random augmentations to an image, such as random rotation, resize, or crop. For now, we will only convert each image to a tensor and normalise it to the range -1 to 1.

The entire function that you can copy-paste is this.

def load_dataset():
    transformations = transforms.Compose([
        transforms.ToTensor(),  # convert the PIL image to a float tensor in [0, 1]
        transforms.Normalize((0.5, ), (0.5, ))  # shift and scale to [-1, 1]
    ])
    mnist = datasets.MNIST(root='datasets', train=True, transform=transformations, download=True)
    return DataLoader(mnist, batch_size=32, shuffle=True)

Implementing the Generator

In this section, we will implement a generator network that takes a random noise vector of size 100 and converts it into a vector of size 784, which we will then reshape into a 28x28 image.

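class Generator(nn.Module):
    def __init__(self):
        super().__init__()  # required so nn.Module can register the layers
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.01),  # non-linearity; without it the two linear layers collapse into one
            nn.Linear(256, 784),
            nn.Tanh(),  # make outputs [-1, 1]
        )

    def forward(self, x):
        return self.gen(x)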
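
To illustrate the reshaping step, here is a minimal sketch that turns flat 784-value outputs into 28x28 images. It assumes a tiny stand-in network, `toy_gen`, which is a hypothetical name with the same input and output sizes as the generator, not the model defined above:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in with the same input and output sizes as the generator.
toy_gen = nn.Sequential(nn.Linear(100, 784), nn.Tanh())

noise = torch.randn(16, 100)        # 16 noise vectors of size 100
with torch.no_grad():               # no gradients needed when only sampling
    flat = toy_gen(noise)           # shape (16, 784), values in [-1, 1]

images = flat.view(-1, 1, 28, 28)   # (batch, channels, height, width)
images = (images + 1) / 2           # map [-1, 1] back to [0, 1] for display
print(images.shape)
```

The final rescaling undoes the -1 to 1 normalisation, so the result can be passed to an image viewer or plotting library that expects values between 0 and 1.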

We are implementing the original GAN described by Ian Goodfellow in his paper.

Implementing the Discriminator

Next, we will implement a discriminator that takes a vector of size 784, which may come from the generator or from a real image, and outputs the probability that it is real.

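class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()  # required so nn.Module can register the layers
        self.disc = nn.Sequential(
            nn.Linear(784, 128),
            nn.LeakyReLU(0.01),  # non-linearity; without it the two linear layers collapse into one
            nn.Linear(128, 1),
            nn.Sigmoid(),  # make outputs [0, 1]
        )

    def forward(self, x):
        return self.disc(x)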
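
The discriminator returns one probability per input, with shape (batch, 1). The training loop later in this post calls `.view(-1)` on that output, and the sketch below shows what that does. It assumes a tiny stand-in network, `toy_disc`, which is a hypothetical name and not the model defined above:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in with the same input and output sizes as the discriminator.
toy_disc = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())

batch = torch.randn(8, 784)       # 8 flattened 28x28 "images"
scores = toy_disc(batch)          # shape (8, 1): one probability per image
flat_scores = scores.view(-1)     # shape (8,): flattened for the loss function

print(scores.shape, flat_scores.shape)
```

Flattening makes the predictions line up with target tensors created by `torch.ones_like` or `torch.zeros_like`.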


As I mentioned earlier, GANs are extremely hard to train. One of the reasons is that a GAN is very sensitive to its initial values and hyperparameters. It helps to follow the settings from the original papers to get the training right; otherwise, you might spend many days tuning the hyperparameters.

config = {
    'device': "cuda" if torch.cuda.is_available() else "cpu",
    'lr': 3e-4,
    'epochs': 50,
    'batch_size': 32
}

In our case, we will use a learning rate of 0.0003, 50 epochs, and a batch size of 32.

What is a learning rate?

The simple answer: it's the step size that controls how much the weights change on each update. The smaller the number, the slower the model learns; the higher the number, the faster it learns, although a value that is too high can make training unstable. For more info, check out our post about Perceptrons.
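
To make this concrete, here is a small sketch in plain Python that minimises the toy function f(w) = (w - 3)^2 with gradient descent at three different learning rates. The function and the `gradient_descent` helper are made up for illustration and have nothing to do with the GAN itself:

```python
def gradient_descent(lr, steps=20, start=0.0):
    """Minimise f(w) = (w - 3)^2 with plain gradient descent and return the final w."""
    w = start
    for _ in range(steps):
        grad = 2 * (w - 3)   # derivative of (w - 3)^2
        w = w - lr * grad    # move against the gradient, scaled by the learning rate
    return w

slow = gradient_descent(lr=0.01)      # small steps: still far from 3 after 20 updates
fast = gradient_descent(lr=0.1)       # larger steps: much closer to 3
unstable = gradient_descent(lr=1.1)   # too large: overshoots further on every update

print(slow, fast, unstable)
```

The best value of w is 3. A small learning rate creeps towards it, a moderate one gets close quickly, and one that is too large bounces past the minimum and moves further away with every step.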

Training Time

Finally, it's time to generate some handwritten numbers, that is, to train our GAN. First, let's create the Generator and Discriminator objects and their optimizers. We will be using the Adam optimizer.

Adam (Adaptive Moment Estimation) is an optimization algorithm that updates the weights so that the overall error goes down.

disc = Discriminator().to(config['device'])
gen = Generator().to(config['device'])

optimiser_g = optim.Adam(params=gen.parameters(), lr=config['lr'])
optimiser_d = optim.Adam(params=disc.parameters(), lr=config['lr'])

We also need to define a loss function before we train. We will be using the Binary Cross Entropy loss function to calculate the error of our model.

Binary Cross Entropy (BCE) Loss

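This loss measures how far a predicted probability is from a binary (0 or 1) target. In the formula, $x_n$ is the predicted probability for sample $n$, $y_n$ is its target label, and $w_n$ is an optional per-sample weight.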

$$\ell(x, y) = L = (l_1, \dots, l_N), \quad l_n = -w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log(1 - x_n) \right]$$

loss_fn = nn.BCELoss()
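
As a sanity check, the formula can be evaluated by hand and compared with `nn.BCELoss`. This is a minimal sketch with made-up predictions and targets, assuming every weight $w_n$ is 1 and the per-sample losses are averaged, which is the default reduction of `nn.BCELoss`:

```python
import torch
import torch.nn as nn

preds = torch.tensor([0.9, 0.2, 0.7])     # made-up discriminator outputs
targets = torch.tensor([1.0, 0.0, 1.0])   # 1 = real, 0 = fake

# Per-sample loss from the formula, with every weight w_n set to 1.
per_sample = -(targets * torch.log(preds) + (1 - targets) * torch.log(1 - preds))
manual = per_sample.mean()                # nn.BCELoss averages by default

builtin = nn.BCELoss()(preds, targets)
print(manual.item(), builtin.item())
```

Both numbers agree. Confident correct predictions (0.9 for a real sample) contribute a small loss, while less confident ones (0.7 for a real sample) contribute more.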

Now we will write our training loop. Each step generates a batch of fake digits, has the discriminator judge them alongside real digits, and updates both models.

train_data = load_dataset()

for epoch in range(config['epochs']):
    for batch_idx, (real, label) in enumerate(train_data):
        real = real.view(-1, 784).to(config['device'])  # flatten each 28x28 image to 784 values
        batch_size = real.shape[0]  # the last batch may be smaller than config['batch_size']

        noise = torch.randn(batch_size, 100).to(config['device'])  # sample random noise vectors
        fake = gen(noise)  # generate a batch of fake numbers
        disc_real = disc(real).view(-1)  # pass the real numbers through the discriminator
        lossD_real = loss_fn(disc_real, torch.ones_like(disc_real))  # loss for real images (target 1)
        disc_fake = disc(fake.detach()).view(-1)  # detach so this step does not backpropagate into the generator
        lossD_fake = loss_fn(disc_fake, torch.zeros_like(disc_fake))  # loss for fake images (target 0)
        lossD = (lossD_real + lossD_fake) / 2  # calculate the average loss

        # update the weights for the discriminator
        optimiser_d.zero_grad()
        lossD.backward()
        optimiser_d.step()

        output = disc(fake).view(-1)
        lossG = loss_fn(output, torch.ones_like(output))  # the generator wants its fakes to be scored as real

        # update the weights for the generator
        optimiser_g.zero_grad()
        lossG.backward()
        optimiser_g.step()
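
When filling in the two update steps, one detail worth checking is which network receives gradients. Calling `detach()` on the fake batch during the discriminator step is one common way to keep that step from backpropagating into the generator. Here is a minimal sketch, assuming tiny stand-in networks (`toy_gen` and `toy_disc` are hypothetical names, not the models above):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins, kept tiny so the example runs instantly.
toy_gen = nn.Sequential(nn.Linear(100, 784), nn.Tanh())
toy_disc = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())
loss_fn = nn.BCELoss()

fake = toy_gen(torch.randn(4, 100))

# Discriminator step: detach() cuts the graph, so no gradient reaches toy_gen.
disc_fake = toy_disc(fake.detach()).view(-1)
loss_d = loss_fn(disc_fake, torch.zeros_like(disc_fake))
loss_d.backward()
gen_untouched = all(p.grad is None for p in toy_gen.parameters())

# Generator step: no detach, so the gradient flows through toy_disc into toy_gen.
output = toy_disc(fake).view(-1)
loss_g = loss_fn(output, torch.ones_like(output))
loss_g.backward()
gen_has_grads = all(p.grad is not None for p in toy_gen.parameters())

print(gen_untouched, gen_has_grads)
```

Both flags come out as True: the discriminator loss leaves the generator untouched, and the generator loss reaches the generator only by passing through the discriminator.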

This code is heavily inspired by a YouTube video by Aladdin Persson.

Find the complete code at our GitHub Repo

If you have any questions, feel free to comment below.

