Computer vision, Machine Learning

Обзор задачи Style Transfer (GANs) с применением сети MSG-Net-100

Время прочтения: 4 мин.

После прочтения этой статьи вы сможете самостоятельно преобразовывать картинки под выбранный вами стиль.

Суть данного подхода заключается в том, что некая картинка преобразуется в новую с другим стилем, который был задан. На входе нейронной сети подаются две картинки: контент и стиль. Чтобы было более понятно, предлагается рассмотреть пример Style Transfer на некоторых примерах.

Слева представлены изображения контента, по центру — стиль, который мы хотим применить к контенту, справа — преобразованная картинка.

В качестве основы я взял статью с arxiv’а Multi-style Generative Network for Real-time Transfer (https://arxiv.org/pdf/1703.06953.pdf), написанную Hang Zhang и Kristin Dana. Авторы статьи рассказывают как применять мульти-стайл преобразование в режиме реального времени. Код имплементации на фреймворке PyTorch выложен в репозиторий на гитхабе (https://github.com/zhanghang1989/PyTorch-Multi-Style-Transfer).  Обзор сети MSG представлен на рисунке снизу. Сначала происходит генерация картинок с помощью сиамской сети (архитектура нейросети, которая обучается дифференцированию входных данных, при этом каждый класс берёт лишь по одному примеру из выборки).  Затем, авторы статьи добавляют предобученную нейронную сеть (в данном случае — VGG). Сеть выдаёт всё более релевантные результаты благодаря минимизации функции потерь (loss function), которая сводит к минимуму разницу между основной картинкой и стилем с таргетами.

https://arxiv.org/pdf/1703.06953.pdf

Существует несколько подходов для измерения функции потерь в задаче Style Transfering (переноса стиля). В данной статье авторы минимизируют взвешенное сочетание разницы стиля и содержания фотографии выходов генеративной сети от таргетов для данной предобученной сети ‘F’. Обозначим генеративную сеть как ‘G’. Индекс ‘c’ обозначает изображения контента (content), а индекс ‘s’ — изображения стиля. В итоге стоит задача минимизации потерь. Лямбда с индексами ‘c’ и ‘s’ — это сбалансированные веса (схоже с классической регуляризацией). ‘l’ с индексом ‘TV’ — это полная регуляризация вариаций, которая использовалась для поощрения хорошо сгенерированных изображений.

https://arxiv.org/pdf/1703.06953.pdf

Авторы статьи используют предобученную на датасете ImageNet сеть VGG с 16 слоями. Функции активации — ReLU. Датасет Microsoft COCO был использован для отбора фотографий контента изображений, который насчитывает около 80000 изображений. Изображения для стиля содержат 1000 картинок, которые состоят из 100 изображений с предыдущих работ с задачами style transfer и 900 изображений живописи, взятых из датасета wikiart.org. В качестве оптимизатора выступал Adam. Для обучения нейронной сети был произведён ресайз каждой фотографии к размеру 256х256, сеть обучалось с батч сайзом 4, всего было 80000 итераций. Сеть MSG-Net-100 обучалось около 8 часов на видеокарте TITAN, а вес модели занимает 9.6 Мб, в результате у сети было 2.3 миллиона параметров. Также, авторы статьи обучили сеть MSG-Net-1K, которая обучалась на 320 тысячах итерациях (в 4 раза больше, чем MSG-Net-100) на картинках размером 64×64, в результате которой было 8.9 миллионов параметров.

Архитектура сети состоит из нескольких частей. Inspiration слой настраивает карту признаков с помощью матрицы Грама. ConvLayer представляет собой обычный сверточный слой Conv2d с паддингом. Bottleneck — это слой, который содержит несколько узлов по сравнению с предыдущими слоями. Его можно использовать для получения представления ввода с уменьшенной размерностью. По ссылке в конце статьи, вы можете просмотреть сеть в виде кода самостоятельно. В репозитории приложены веса модели, которые можно загрузить с помощью метода load_state_dict.

Чтобы посмотреть, как работает данная сеть на собственных примерах, предлагается ознакомиться с кодом, который представлен в ноутбуке и может быть открыт в google colab или jupyter notebook. Ссылка на ноутбук: https://github.com/germanjke/StyleTransformerGANs/blob/master/StyleTransferGANs_demo_ru.ipynb

Полезные ссылки для более детального погружения в мир GANов (генеративно-состязательных сетей):

Обзор GANов в статье https://arxiv https://arxiv.org/abs/1710.07035

Введение в GANы для новичков https://www.youtube.com/watch?v=8L11aMN5KY8

Семинар от Deep Learning School МФТИ (на русском языке) https://www.youtube.com/watch?v=u2HDm7YSwoA

Советуем почитать