fbpx

Туториал: создание простой GAN на Python с библиотекой Keras

gan python keras tutorial

В этом туториале я расскажу о генеративно-состязательных нейронных сетях (GAN) не прибегая к математическим деталям модели. Далее будет показано, как написать собственную простую GAN на Python с Keras, которая сможет генерировать знаки.

Перед вам перевод статьи Demystifying Generative Adversarial Nets (GANs), опубликованной на Datacamp, автор — Stefan Hosien. Ссылка на оригинал — в подвале статьи.

Аналогия

Проще всего понять, что такое GAN, обратившись к следующей аналогии. Представьте, что есть магазин, который покупает определенные сорта вина у своих поставщиков, которые он затем будет перепродавать.

gan туториал

Есть нечестные поставщики, которые продают поддельное вино, чтобы получить деньги. В таком случае руководство магазина должно уметь различать поддельные и подлинные вина.

gan на python туториал

Можно предположить, что изначально мошенники могли сделать много ошибок при попытке продать поддельное вино, а руководство магазина с легкостью определяло фальшивые экземпляры. Методом проб и ошибок мошенники пробовали разные техники, чтобы имитировать подлинное вино, и в конечном счете им это удалось. Теперь когда мошенники знают, как сделать так, чтобы вино прошло контроль в магазине, они начинают дальше улучшать свой продукт.

В то же время руководство магазина может получать фидбэк от других магазинов или экспертов о том, что некоторые их вина неоригинальные. Поэтому магазину приходится улучшать свои методи определения поддельных вин. Цель мошенников — создание неотличимых от оригинала вин, руководство магазина же стремится точно определить подлинность вина.

Такое взаимное состязание является идеей, лежащей в основе GAN.

Архитектура генеративно-состязательной сети

Используя пример, о котором было сказано выше, можно прийти к архитектуре GAN.

как работает gan

Очевидно, что в GAN должны быть две основные части — генератор и, так называемый, дискриминатор. Руководство магазина в примере сверху — дискриминаторная сеть, которая обычно представляет из себя сверточную нейросеть, CNN, (так как сети GAN в основном используются для задач, связанных с изображениями), которая приписывает изображению процент соответствия подлинности.

Мошенником в GAN выступает генеративная сеть, которая также является сверточной сетью со слоем развертки (deconvolution layer). Эта сеть накладывает шум на изображение (использую вектор шума) и выводит его. Во время тренировки генеративная сеть изучает, какие области изображения необходимо изменить или улучшить, чтобы дискриминатору понадобилось больше времени для определения подлинности сгенерированного изображения.

Генеративная сеть с каждым разом производит изображение, которое все больше походит на реальное, в то время как дискриминативная сеть пытается найти различия между реальным и искусственным изображением. Главная цель — сделать такую генеративную сеть, которая сможет воспроизводить неотличимые от реальных изображения.

Простая генеративно-состязательная сеть в Keras

Теперь когда вы поняли, что такое GAN,  какие компоненты у нее есть, начнем писать код. Будем использовать Keras, а если вы не знакомы с этим фреймворком Python, перед началом работы посмотрите этот туториал. В основе этого туториала лежит простая и понятная GAN, разработанная здесь.

Для начала вам необходимо с помощью pip установить следующие пакеты:

- keras
- matplotlib
- tensorflow
- tqdm

Мы будем использовать matplotlib для отрисовки графиков, tensorflow в качестве необходимого для Keras бэкграунда, tqdm для красивой визуализации прогресса с каждой эпохой, итерацией.

Следующий шаг — создание скрипта на Python. В этом скрипте сначала необходимо импортировать все модули и функции для работы. Объяснение работы каждого модуля будет дано позже.

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Dense, Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers

Теперь определим некоторые переменные:

# Let Keras know that we are using tensorflow as our backend engine
os.environ["KERAS_BACKEND"] = "tensorflow"
# To make sure that we can reproduce the experiment and get the same results
np.random.seed(10)
# The dimension of our random noise vector.
random_dim = 100

Перед тем как начать строить дискриминатор и генератор, нужно собрать данные и сделать их предварительную обработку. Будем использовать известный датасет MNIST, который представляет из себя набор изображений цифр от 0 до 9.

генерация символов с gan
Пример символов из датасета MNIST
def load_minst_data():
    # load the data
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # normalize our inputs to be in the range[-1, 1] 
    x_train = (x_train.astype(np.float32) - 127.5)/127.5
    # convert x_train with a shape of (60000, 28, 28) to (60000, 784) so we have
    # 784 columns per row
    x_train = x_train.reshape(60000, 784)
    return (x_train, y_train, x_test, y_test)

Заметим, что команда mnist.load_data() является частью Keras и позволяет легко импортировать датасет в рабочее пространство.

Теперь мы можем создать сети генератора и дискриминатора. Для обеих сетей используем оптимизатор Adam. В обоих случаях сеть будет состоять из трех скрытых слоев с активационной функцией Leaky Relu. Также следует добавить в дискриминатор dropout слои, чтобы улучшить его надежность, качество (robustness) на изображениях, которые не были показаны.

def get_optimizer():
    return Adam(lr=0.0002, beta_1=0.5)

def get_generator(optimizer):
    generator = Sequential()
    generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(512))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(1024))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(784, activation='tanh'))
    generator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return generator

def get_discriminator(optimizer):
    discriminator = Sequential()
    discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))

    discriminator.add(Dense(512))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))

    discriminator.add(Dense(256))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))

    discriminator.add(Dense(1, activation='sigmoid'))
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)

Осталась только соединить генератор с дискриминатором.

def get_gan_network(discriminator, random_dim, generator, optimizer):
    # We initially set trainable to False since we only want to train either the 
    # generator or discriminator at a time
    discriminator.trainable = False
    # gan input (noise) will be 100-dimensional vectors
    gan_input = Input(shape=(random_dim,))
    # the output of the generator (an image)
    x = generator(gan_input)
    # get the output of the discriminator (probability if the image is real or not)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input, outputs=gan_output)
    gan.compile(loss='binary_crossentropy', optimizer=optimizer)
    return gan

Дополнительно можно создать функцию, сохраняющую сгенерированные изображения через каждые 20 эпох. Так как этот шаг не является основным в туториале, вам не обязательно полностью понимать выводящую изображение функцию.

def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, random_dim])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples, 28, 28)

    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('gan_generated_image_epoch_%d.png' % epoch)

Мы написали большую часть нашей сети. Осталось только обучить нейросеть и посмотреть на результаты — изображения.

def train(epochs=1, batch_size=128):
    # Get the training and testing data
    x_train, y_train, x_test, y_test = load_minst_data()
    # Split the training data into batches of size 128
    batch_count = x_train.shape[0] / batch_size

    # Build our GAN netowrk
    adam = get_optimizer()
    generator = get_generator(adam)
    discriminator = get_discriminator(adam)
    gan = get_gan_network(discriminator, random_dim, generator, adam)

    for e in xrange(1, epochs+1):
        print '-'*15, 'Epoch %d' % e, '-'*15
        for _ in tqdm(xrange(batch_count)):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]

            # Generate fake MNIST images
            generated_images = generator.predict(noise)
            X = np.concatenate([image_batch, generated_images])

            # Labels for generated and real data
            y_dis = np.zeros(2*batch_size)
            # One-sided label smoothing
            y_dis[:batch_size] = 0.9

            # Train discriminator
            discriminator.trainable = True
            discriminator.train_on_batch(X, y_dis)

            # Train generator
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            y_gen = np.ones(batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y_gen)

        if e == 1 or e % 20 == 0:
            plot_generated_images(e, generator)

if __name__ == '__main__':
    train(400, 128)

После обучения на 400 эпохах, можем посмотреть сгенерированные изображения. Глядя на произведенные после первой эпохи изображения, вы можете заметить, что они не имеют реальную структуру. После 40 эпох изображения приобретают нужную форму, а после 400 эпох изображения четкие и почти неотличимые от настоящих, за исключением пары штук.

gan на python туториал генерация цифр
Результат после 1 эпохи
результат работы генеративной сети
Результат после 40 эпох
результат gan
Результат после 400 эпох

Главная причина, по которой был выбран этот код, это скорость выполнения. Во время тренировки на CPU для каждой эпохи требуется примерно 2 минуты. Вы можете сами поэкспериментировать с кодом, добавляя эпохи или слои (не обязательно такие же) в генератор и дискриминатор. Однако, если вы работаете с CPU, использование более сложных и глубоких архитектур потребует большего времени на тренировку. Но этот факт не должен останавливать, экспериментируйте!

Заключение

Поздравляю, вы дошли до конца туториала и получили интуитивное понимание генеративно-состязательных сетей GAN. Помимо понимания вы реализовали свою собственную сеть с помощью библиотеки Keras.