Туториал по созданию системы фильтров Snapchat с использованием Deep Learning

Snapchat

Добро пожаловать всем программистам, которые, возможно, открыли эту статью, увидев слова «Snapchat» и «Deep Learning». Клянусь, эти 2 слова привлекают вас, ребята, как огонь мотылька. Чего уж там, я тоже стал их жертвой, поэтому я часами делал этот проект.

Перевод статьи Creating the Snapchat Filter System using Deep Learning, автор — Rohit Agrawal, ссылка на оригинал — в подвале статьи.

В этой статье я расскажу о процессе и немного о теории проекта из названия. Раскрою все карты — когда я использовал термин «Snapchat» в заголовке, я немного схитрил, потому что, хотя этот проект работает по тому же принципу (использование ключевых точек лица для сопоставления объектов лицу), он сам по себе никак не связан с реализацией подобного функционала в Snapchat с точки зрения сложности и точности. После этого позвольте мне начать с представления датасета, который я использовал.

Датасет

Датасет с ключевыми точками на лицах

Я использовал следующий датасет: https://www.kaggle.com/c/facial-keypoints-detection, предоставленный доктором Йошуа Бенжио из Университета Монреаля.

Каждая прогнозируемая ключевая точка задается парой действительных чисел (x, y) в пространстве индексов пикселей. Есть 15 ключевых точек, которые представляют различные элементы лица. Входное изображение задается в последнем поле файлов данных и состоит из списка пикселей (упорядоченных по строкам) в виде целых чисел в диапазоне (0,255). Изображения имеют размер 96х96 пикселей.

Теперь, когда у нас есть хорошее представление о представлении данных, с которыми мы имеем дело, нам нужно предварительно обработать их, чтобы мы могли использовать их в качестве входных данных для нашей модели.

Шаг 1: Предварительная обработка данных и другие махинации

В приведенном выше датасете есть два файла, к которым мы должны обратиться -  training.csv и test.csv. Файл training содержит 31 столбец: 30 столбцов для координат ключевой точки и последний столбец, содержащий данные изображения в строковом формате. Он содержит 7049 объектов, однако многие из этих примеров имеют значения «NaN» в некоторых ключевых моментах, которые усложняют нам задачу. Поэтому мы будем рассматривать только образцы без каких-либо значений NaN. Вот код, который делает именно это: (Следующий код также нормализует данные изображений и ключевых точек, что является очень распространенным этапом предобработки данных)

# Check if row has any NaN values 
def has_nan(keypoints):
    for i in range(len(keypoints)):
        if math.isnan(keypoints[i]):
            return True
    return False

# Read the data as Dataframes
training = pd.read_csv('data/training.csv')
test = pd.read_csv('data/test.csv')

# Get training data
imgs_train = []
points_train = []
for i in range(len(training)):
    points = training.iloc[i,:-1]
    if has_nan(points) is False:
        test_image = training.iloc[i,-1]        # Get the image data
        test_image = np.array(test_image.split(' ')).astype(int)    
        test_image = np.reshape(test_image, (96,96))        # Reshape into an array of size 96x96
        test_image = test_image/255         # Normalize image
        imgs_train.append(test_image)
        
        keypoints = training.iloc[i,:-1].astype(int).values
        keypoints = keypoints/96 - 0.5  # Normalize keypoint coordinates
        points_train.append(keypoints)

imgs_train = np.array(imgs_train)    
points_train = np.array(points_train)

# Get test data
imgs_test = []
for i in range(len(test)):
    test_image = test.iloc[i,-1]        # Get the image data
    test_image = np.array(test_image.split(' ')).astype(int)
    test_image = np.reshape(test_image, (96,96))        # Reshape into an array of size 96x96
    test_image = test_image/255     # Normalize image
    imgs_test.append(test_image)
    
imgs_test = np.array(imgs_test)

Все хорошо? Не на самом деле нет. Похоже, что было только 2140 образцов, которые не содержали значений NaN. Это гораздо меньше образцов, чем требуется для обучения обобщенной и точной модели. Таким образом, чтобы создать больше данных, нам нужно дополнить наши текущие данные.

Аугментация данных — это методика, в основном используемая для генерации большего количества данных из существующих данных с использованием таких методов, как масштабирование, перемещение, вращение и т.д. В моем случае я отразил каждое изображение и соответствующие ему ключевые точки, поскольку такие методы, как масштабирование и поворот, могли бы исказить изображения лица и, таким образом, испортили бы модель. Наконец, я объединил исходные данные с новыми дополненными данными, чтобы получить в общей сложности 4280 образцов.

# Data Augmentation by mirroring the images
def augment(img, points):
    f_img = img[:, ::-1]        # Mirror the image
    for i in range(0,len(points),2):        # Mirror the key point coordinates
        x_renorm = (points[i]+0.5)*96       # Denormalize x-coordinate
        dx = x_renorm - 48          # Get distance to midpoint
        x_renorm_flipped = x_renorm - 2*dx      
        points[i] = x_renorm_flipped/96 - 0.5       # Normalize x-coordinate
    return f_img, points

aug_imgs_train = []
aug_points_train = []
for i, img in enumerate(imgs_train):
    f_img, f_points = augment(img, points_train[i])
    aug_imgs_train.append(f_img)
    aug_points_train.append(f_points)
    
aug_imgs_train = np.array(aug_imgs_train)
aug_points_train = np.array(aug_points_train)

# Combine the original data and augmented data
imgs_total = np.concatenate((imgs_train, aug_imgs_train), axis=0)       
points_total = np.concatenate((points_train, aug_points_train), axis=0)

Шаг 2: Архитектура модели и обучение

Теперь давайте погрузимся в раздел проекта «Глубокое обучение». Наша цель — предсказать значения координат для каждой ключевой точки, поэтому это задача регрессии. Поскольку мы работаем с изображениями, сверточные нейронные сети являются довольно очевидным выбором для извлечения признаков. Эти извлеченные признаки затем передаются в полносвязную нейронную сеть, которая выводит координаты. В конечном полносвязном слое нужно 30 нейронов, потому что нам нужно 30 значений (15 пар координат (x, y)).

# Define the architecture
def get_model():
    model = Sequential()
    model.add(Conv2D(64, kernel_size=3, strides=2, padding='same', input_shape=(96,96,1), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), padding='same'))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), padding='same'))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), padding='same'))
    model.add(Conv2D(64, kernel_size=1, strides=2, padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), padding='same'))
    model.add(Flatten())
    
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.1))
    model.add(Dense(256, activation='relu'))
    model.add(Dropout(0.2))
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.1))
    model.add(Dense(30))
    return model;

def compile_model(model):       # Compile the model
    model.compile(loss='mean_absolute_error', optimizer='adam', metrics = ['accuracy'])

def train_model(model):         # Fit the model
    checkpoint = ModelCheckpoint(filepath='weights/checkpoint-{epoch:02d}.hdf5')
    model.fit(imgs_train, points_train, epochs=300, batch_size=100, callbacks=[checkpoint])

# Train the model
model = get_model()
compile_model(model)
train_model(model)
  • Активация «ReLu» используются после каждого сверточного и полносвязного слоя, за исключением последнего полносвязного слоя, так как это значения координат, которые нам нужны в качестве выходных данных.
  • Регуляризация Dropout используется для предотвращения переобучения
  • MaxPooling добавлен для уменьшения размерности

Модель смогла достичь минимальных потерь ~ 0,0113 и точности ~ 80% , что, на мой взгляд, было достаточно приличным. Вот несколько результатов производительности модели на тестовой выборке:

def test_model(model):    
    data_path = join('','*g')
    files = glob.glob(data_path)
    for i,f1 in enumerate(files):       # Test model performance on a screenshot for the webcam
        if f1 == 'Capture.PNG':
            img = imread(f1)
            img = rgb2gray(img)         # Convert RGB image to grayscale
            test_img = resize(img, (96,96))     # Resize to an array of size 96x96
    test_img = np.array(test_img)
    test_img_input = np.reshape(test_img, (1,96,96,1))      # Model takes input of shape = [batch_size, height, width, no. of channels]
    prediction = model.predict(test_img_input)      # shape = [batch_size, values]
    visualize_points(test_img, prediction[0])
    
    # Test on first 10 samples of the test set
    for i in range(len(imgs_test)):
        test_img_input = np.reshape(imgs_test[i], (1,96,96,1))      # Model takes input of shape = [batch_size, height, width, no. of channels]
        prediction = model.predict(test_img_input)      # shape = [batch_size, values]
        visualize_points(imgs_test[i], prediction[0])
        if i == 10:
            break

test_model(model)
Результат построения ключевых точек на лицах
Результаты модели на тестовой выборке

Мне также нужно было проверить производительность модели на изображении с моей веб-камеры, потому что это то, что модель получит во время реализации фильтра. Вот как модель работает на этом изображении моего красивого лица:

Не пугайтесь этого страшного лица. Я не кусаюсь

Шаг 3: Приведение модели в действие

Наша модель работает, поэтому все, что нам нужно сейчас сделать, это использовать OpenCV для выполнения следующих действий:

  1. Получить изображения с веб-камеры
  2. Определить область лица на каждом кадре изображения, потому что другие участки изображения бесполезны для модели (я использовал Каскад Хаара, чтобы обрезать область лица)
  3. Предобработать эту обрезанную область путем преобразования в оттенки серого, нормализации и изменения формы
  4. Передать предобработанное изображение в качестве входных данных для модели
  5. Получить прогнозы для ключевых точек и использовать их, чтобы расположить различные фильтры на лице

Когда я начал тестирование, я не имел в виду никаких конкретных фильтров. Идея проекта появилась у меня примерно 22 декабря 2018 года, и, как и любой другой обычный человек, я был большим поклонником Рождества и решил использовать следующие фильтры:

Фильтры для Snapchat

Я использовал определенные ключевые точки для масштабирования и позиционирования каждого из вышеуказанных фильтров:

  • Фильтр очков: расстояние между левой ключевой точкой левого глаза и правой ключевой точкой правого глаза используется для масштабирования. Точка для бровей и левая точка для левого глаза используются для позиционирования очков.
  • Фильтр бороды: расстояние между левой и правой точками губ используется для масштабирования. Верхняя ключевая точка губы и левая точка используются для позиционирования бороды.
  • Фильтр шляпы: Ширина лица используется для масштабирования. Ключевая точка бровей и левая точка левого глаза используются для позиционирования шляпы.

Код, который делает все вышеперечисленное, выглядит следующим образом:

# Implement the model in real-time 

# Importing the libraries
import numpy as np
from training import get_model, load_trained_model, compile_model
import cv2

# Load the trained model
model = get_model()
compile_model(model)
load_trained_model(model)

# Get frontal face haar cascade
face_cascade = cv2.CascadeClassifier('cascades/haarcascade_frontalface_default.xml')

# Get webcam
camera = cv2.VideoCapture(0)

# Run the program infinitely
while True:
    grab_trueorfalse, img = camera.read()       # Read data from the webcam
    
    # Preprocess input fram webcam
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)        # Convert RGB data to Grayscale
    faces = face_cascade.detectMultiScale(gray, 1.3, 5)     # Identify faces in the webcam
    
    # For each detected face using tha Haar cascade
    for (x,y,w,h) in faces:
        roi_gray = gray[y:y+h, x:x+w]
        img_copy = np.copy(img)
        img_copy_1 = np.copy(img)
        roi_color = img_copy_1[y:y+h, x:x+w]
        
        width_original = roi_gray.shape[1]      # Width of region where face is detected
        height_original = roi_gray.shape[0]     # Height of region where face is detected
        img_gray = cv2.resize(roi_gray, (96, 96))       # Resize image to size 96x96
        img_gray = img_gray/255         # Normalize the image data
        
        img_model = np.reshape(img_gray, (1,96,96,1))   # Model takes input of shape = [batch_size, height, width, no. of channels]
        keypoints = model.predict(img_model)[0]         # Predict keypoints for the current input
        
        # Keypoints are saved as (x1, y1, x2, y2, ......)
        x_coords = keypoints[0::2]      # Read alternate elements starting from index 0
        y_coords = keypoints[1::2]      # Read alternate elements starting from index 1
        
        x_coords_denormalized = (x_coords+0.5)*width_original       # Denormalize x-coordinate
        y_coords_denormalized = (y_coords+0.5)*height_original      # Denormalize y-coordinate
        
        for i in range(len(x_coords)):          # Plot the keypoints at the x and y coordinates
            cv2.circle(roi_color, (x_coords_denormalized[i], y_coords_denormalized[i]), 2, (255,255,0), -1)
        
        # Particular keypoints for scaling and positioning of the filter
        left_lip_coords = (int(x_coords_denormalized[11]), int(y_coords_denormalized[11]))
        right_lip_coords = (int(x_coords_denormalized[12]), int(y_coords_denormalized[12]))
        top_lip_coords = (int(x_coords_denormalized[13]), int(y_coords_denormalized[13]))
        bottom_lip_coords = (int(x_coords_denormalized[14]), int(y_coords_denormalized[14]))
        left_eye_coords = (int(x_coords_denormalized[3]), int(y_coords_denormalized[3]))
        right_eye_coords = (int(x_coords_denormalized[5]), int(y_coords_denormalized[5]))
        brow_coords = (int(x_coords_denormalized[6]), int(y_coords_denormalized[6]))
        
        # Scale filter according to keypoint coordinates
        beard_width = right_lip_coords[0] - left_lip_coords[0]
        glasses_width = right_eye_coords[0] - left_eye_coords[0]
        
        img_copy = cv2.cvtColor(img_copy, cv2.COLOR_BGR2BGRA)       # Used for transparency overlay of filter using the alpha channel

        # Beard filter
        santa_filter = cv2.imread('filters/santa_filter.png', -1)
        santa_filter = cv2.resize(santa_filter, (beard_width*3,150))
        sw,sh,sc = santa_filter.shape
        
        for i in range(0,sw):       # Overlay the filter based on the alpha channel
            for j in range(0,sh):
                if santa_filter[i,j][3] != 0:
                    img_copy[top_lip_coords[1]+i+y-20, left_lip_coords[0]+j+x-60] = santa_filter[i,j]
                    
        # Hat filter
        hat = cv2.imread('filters/hat2.png', -1)
        hat = cv2.resize(hat, (w,w))
        hw,hh,hc = hat.shape
        
        for i in range(0,hw):       # Overlay the filter based on the alpha channel
            for j in range(0,hh):
                if hat[i,j][3] != 0:
                    img_copy[i+y-brow_coords[1]*2, j+x-left_eye_coords[0]*1 + 20] = hat[i,j]
        
        # Glasses filter
        glasses = cv2.imread('filters/glasses.png', -1)
        glasses = cv2.resize(glasses, (glasses_width*2,150))
        gw,gh,gc = glasses.shape
        
        for i in range(0,gw):       # Overlay the filter based on the alpha channel
            for j in range(0,gh):
                if glasses[i,j][3] != 0:
                    img_copy[brow_coords[1]+i+y-50, left_eye_coords[0]+j+x-60] = glasses[i,j]
        
        img_copy = cv2.cvtColor(img_copy, cv2.COLOR_BGRA2BGR)       # Revert back to BGR
        
        cv2.imshow('Output',img_copy)           # Output with the filter placed on the face
        cv2.imshow('Keypoints predicted',img_copy_1)        # Place keypoints on the webcam input
        
    cv2.imshow('Webcam',img)        # Original webcame Input
    
    if cv2.waitKey(1) & 0xFF == ord("e"):   # If 'e' is pressed, stop reading and break the loop
        break

Результат

Результат применения фильтра

Выше вы можете увидеть окончательный результат проекта, который содержит видео в реальном времени с фильтрами на моем лице и еще одно видео в реальном времени с нанесенными ключевыми точками.

Ограничения проекта

Хотя проект работает довольно хорошо, я обнаружил несколько недостатков, которые делают его немного не идеальным:

  • Не самая точная модель. Хотя 80%, на мой взгляд, довольно прилично, у него все еще есть много возможностей для улучшения.
  • Эта текущая реализация работает только для выбранного набора фильтров. Мне пришлось выполнить некоторые ручные настройки для более точного позиционирования и масштабирования.
  • Процесс применения фильтра к изображению довольно неэффективен по скорости вычислений, потому что для наложения изображения фильтра .png на изображение веб-камеры на основе альфа-канала мне пришлось применять фильтр попиксельно, где альфа не была равна 0. Иногда это приводит к сбою программы, когда она обнаруживает более одного лица на изображении.

Полный код проекта находится на моем Github: https://github.com/agrawal-rohit/Santa-filter-facial-keypoint-regression

Если вы хотите улучшить проект или у вас есть какие-либо предложения по решению вышеуказанной проблемы, не забудьте оставить ответ и сделать pull request в репозиторий Github. Спасибо, что заглянули, надеюсь, вам понравилось чтение.

Чао!

добавление тегов на фотографии

Туториал: добавление тегов фотографиям с генератором Tagbox для удобства поиска на MacOS

Перевод статьи Building a private, local photo search app using machine learning, автор — Aaron Edell, ссылка на оригинал — в подвале статьи. Это оно. Это лучшая проклятая вещь, которую я когда-либо делал. Обычно я не люблю хвастаться, но я так горжусь собой за…
recommedation system keras

Туториал: cоздание рекомендательной системы c библиотекой FastAI

Метод коллаборативной фильтрации в рекомендательных системах предсказывает оценку или предпочтение, которое пользователь будет отдавать объекту на основе его прежних оценок или предпочтений. Системы рекомендаций используются практически каждой крупной компанией для повышения качества предложения своих услуг. Перевод статьи «Collaborative filtering with FastAI», автор — Gilbert…
style transfer tutorial tensorflow

Туториал: перенос стиля изображений с TensorFlow

Перенос стиля (style transfer)  — одно из наиболее креативных приложений сверточных нейронных сетей. Взяв контент с одного изображения и стиль от второго, нейронная сеть объединяет их в одно художественное произведение. Перевод статьи Introduction to Neural Style Transfer with TensorFlow, автор — Marco Peixeiro, ссылка на оригинал…
big data parellel mapreduce

Туториал: параллельные вычисления больших данных с MapReduce

Метод MapReduce представляет собой технику, которая используется для обработки огромного количества данных (до нескольких петабайт). Существует много реализаций MapReduce, в том числе известный Apache Hadoop. Здесь я не буду говорить о реализациях MapReduce. Я попытаюсь представить концепцию как можно более интуитивно понятным способом, приведу реальные…
nlp javascript

Работа с NLP-моделями Keras в браузере с TensorFlow.js

Этот туториал для тех, кто знаком с основами JavaScript и основами глубокого обучения для задач NLP (RNN, Attention). Если вы плохо разбираетесь в RNN, я рекомендую вам прочитать «Необоснованную эффективность рекуррентных нейронных сетей» Андрея Карпати. Перевод статьи «NLP Keras model in browser with TensorFlow.js», автор…

Искусственная нейронная сеть с нуля на Python c библиотекой NumPy

В туториале показано, как с нуля построить искусственную нейронную сеть на Python с помощью библиотеки NumPy. Сеть будет классифицировать изображения из датасета Fruit360. Материалы туториала, за исключением цветных изображений из сета Fruit360, взяты из книги  «Practical Computer Vision Applications Using Deep Learning with CNNs» автора…
gan python keras tutorial

Туториал: создание простой GAN на Python с библиотекой Keras

В этом туториале я расскажу о генеративно-состязательных нейронных сетях (GAN) не прибегая к математическим деталям модели. Далее будет показано, как написать собственную простую GAN на Python с Keras, которая сможет генерировать знаки. Перед вам перевод статьи Demystifying Generative Adversarial Nets (GANs), опубликованной на Datacamp,…

Как использовать BERT для мультиклассовой классификации текста

Возможно, наиболее важное событие прошедшего года в NLP — релиз BERT, мультиязычной модели на основе трансформера, которая показала state-of-the-art результаты в нескольких задачах NLP. BERT — двунаправленная модель с transformer-архитектурой, заменившая собой последовательные по природе рекуррентные нейронные сети (LSTM и GRU), с более быстрым подходом…

Простая нейронная сеть в 9 строчек кода на Python

Из статьи вы узнаете, как написать свою простую нейросеть на python с нуля, не используя никаких библиотек для нейросетей. Если у вас еще нет своей нейронной сети, вот всего лишь 9 строчек кода: Перед вами перевод поста How to build a simple neural network in…
pytorch tensorflow сходства и отличия

PyTorch и TensorFlow: отличия и сходства фреймворков

В статье будет рассказано о главных сходствах и различиях между двумя популярными фреймворками глубокого обучения — PyTorch и TensorFlow. Почему такой выбор библиотек? Существует много фреймворков глубокого обучения, многие из которых жизнеспособны, но я выбрал только PyTorch и TensorFlow, так как интересно сравнить эти…

Обучение Inception-v3 распознаванию собственных изображений

В моем предыдущем посте мы увидели, как выполнять распознавание изображений с помощью TensorFlow с использованием API Python на CPU без какого-либо обучения. Мы использовали предобученную модель Inception-v3, которую Google уже обучил на тысяче классов, но что, если мы хотим сделать то же самое, но…
tensorflow mobile туториал

TensorFlow для мобильных устройств на Android и iOS

TensorFlow обычно используется для тренировки масштабных моделей на большом наборе данных, но нельзя игнорировать развивающийся рынок смартфонов и необходимость создавать будущее, основанное на глубоком обучении. Перед вами перевод статьи TensorFlow on Mobile: Tutorial, автор — Sagar Sharma. Ссылка на оригинал — в подвале статьи.  Те, кто не…
туториал распознавание изображений tensorflow

Распознавание изображений предобученной моделью Inception-v3 c Python API на CPU

Это самый быстрый и простой способ реализовать распознавание изображений на ноутбуке или стационарном ПК без какого-либо графического процессора, потому что это можно сделать лишь с помощью API, и ваш компьютер отлично справится с этой задачей. Перед вами перевод статьи TensorFlow Image Recognition Python API Tutorial,…

Нейронная сеть на JavaScript в 30 строк кода

В этой статье будет показано, как создать и обучить нейронную сеть на JavaScript, используя Synaptic.js. Этот пакет позволяет реализовывать глубокое обучение в Node.js и в браузере. Будет создана простейшая возможная нейронная сеть — та, которой удается выполнить XOR операцию. Перевод статьи «How to create…
выбор признаков нейронной сети питон

Open source инструмент на Python для выбора признаков нейронной сети

Поиск и выбор наиболее полезных признаков в датасете — одна из наиболее важных частей машинного обучения. Ненужные признаки уменьшают скорость обучения, ухудшают возможности интерпретации результатов и, что самое важное, уменьшают производительность работы.  Перевод статьи «A Feature Selection Tool for Machine Learning in Python» by William…
kaggle competition

Как попасть в топ 2% соревнования Kaggle

Статья основана на реальном опыте участия в соревнованиях на Kaggle, автор — Abhay Pawar. Ссылка на оригинал в подвале статьи.  Участвовать в соревнованиях Kaggle весело и захватывающе! За последние пару лет я разработал несколько простых способов создания более совершенных моделей машинного обучения. Эти простые, но…
transfer-learning-keras

Реализация Transfer learning с библиотекой Keras

Для большинства задач компьютерного зрения не существует больших датасетов (около 50 000 изображений). Даже при экстремальных стратегиях аугментации данных трудно добиться высокой точности. Обучение таких сетей с миллионами параметров обычно имеет тенденцию перегружать модель. В этом случае Transfer learning готов прийти на помощь. Что…

FaceNet — пример простой системы распознавания лиц с открытым кодом Github

Распознавание лица — последний тренд в авторизации пользователя. Apple использует Face ID, OnePlus — технологию Face Unlock. Baidu использует распознавание лица вместо ID-карт для обеспечения доступа в офис, а при повторном пересечении границы в ОАЭ вам нужно только посмотреть в камеру. В статье разбираемся,…

Искусственный интеллект для малого бизнеса: 5 способов применения

В массовой культуре искусственный интеллект (AI) покрыт мифами и считается исключительной силой, подрывающей экономическую стабильность. На самом деле, он похож на любую другую технологию. По мере того, как больше и больше компаний используют AI, увеличивается конкуренция и снижаются издержки — искусственный интеллект становится доступным для…
pytorch bigraph

Сверточная нейронная сеть на PyTorch: пошаговое руководство

В предыдущем вводном туториале по нейронным сетям была создана трехслойная архитектура для классификации рукописных символов датасета MNIST. В конце туториала была показана точность приблизительно 86%. Для простого датасета, как MNIST, это плохое качество. Дальнейшая оптимизация смогла улучшить результат плотно соединенной сети до 97-98% точности.…