Время прочтения: 5 мин.

Одной из самых распространенных задач машинного обучения является сегментация объектов на изображении. Это крайне важный этап, так как в последующем, сегментированные изображения используются для обучения классификации либо же для вывода уже готового изображения. Поэтому к сегментации стоит относиться трепетно, и сегодня мы рассмотрим один из примеров, как это делать. Пример будет построен на сегментации животных: собак и кошек – на изображении.

Начало работы выглядит одинаково для всех моделей машинного обучения: подключаем библиотеки, загружаем датасет, добавляем фото с помощью поворотов и кадрирования и обрабатываем каждое фото, так же стандартно. Такую последовательность вы можете найти в любой статье, связанной с машинным обучением, в частности, у моего коллеги выходила статья по классификации изображений (ссылка), где он подробно описал каждый шаг. Поэтому не будем останавливаться на этом и сразу перейдем к построению.

Листинг 1. Построим входной конвейер, применив расширение после пакетирования входных данных.

train_batches = (
    train_images
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .map(Augment())
    .prefetch(buffer_size=tf.data.AUTOTUNE))

test_batches = test_images.batch(BATCH_SIZE)

Листинг 2. Посмотрим на пример изображений и маски на них.

def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()


for images, masks in train_batches.take(2):
  sample_image, sample_mask = images[0], masks[0]
  display([sample_image, sample_mask])

Далее нужно определить модель.

Модель используется здесь модифицированная U-Net. U-Net состоит из кодировщика (субдискретизатора) и декодера (повышающего дискретизатора). Чтобы изучить надежные функции и уменьшить количество обучаемых параметров, мы будем использовать предварительно обученную модель — MobileNetV2 — в качестве кодировщика. Для декодера, мы будем использовать повышающий дискретизацию блок, который уже реализован в Pix2pix учебник в TensorFlow Примеры Repo.

Как упоминалось выше, кодер будет предварительно обученная модель MobileNetV2, которая готова к использованию в tf.keras.applications. Кодировщик состоит из определенных выходных данных промежуточных слоев модели. Учтите, что кодировщик не будет обучаться в процессе обучения.

Листинг 3. Определяем модель.

base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

# Используем активации этих слоев
layer_names = [
    ‘block_1_expand_relu’,   # 64×64
    ‘block_3_expand_relu’,   # 32×32
    ‘block_6_expand_relu’,   # 16×16
    ‘block_13_expand_relu’,  # 8×8
    ‘block_16_project’,      # 4×4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# Создадим модель извлечения признаков
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

down_stack.trainable = False

Декодер / повышающий дискретизатор — это просто серия блоков повышающей дискретизации, реализованная в примерах TensorFlow.

up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

def unet_model(output_channels:int):
  inputs = tf.keras.layers.Input(shape=[128, 128, 3])

  # Даунсэмплинг по модели
  skips = down_stack(inputs)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # Повышение дискретизации и установление пропуска соединений
  for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = tf.keras.layers.Concatenate()
    x = concat([x, skip])

  # Это последний слой модели
  last = tf.keras.layers.Conv2DTranspose(
      filters=output_channels, kernel_size=3, strides=2,
      padding='same')  #64x64 -> 128x128

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

Следует отметить, что количество фильтров на последнем слое устанавливается на число output_channels. Это будет один выходной канал на класс.

Далее обучаем модель.

Теперь осталось только скомпилировать и обучить модель.

Поскольку это мультиклассицирует проблему определения классификации CetegforicalCrossentropy с from_logits=True, которая является стандартной функцией потерь. Используем losses.SparseCategoricalCrossentropy(from_logits=True), так как метки являются скалярными целыми вместо vecrtors партитур for баллов для каждого класса for каждого пикселя.

Листинг 4. При выполнении вывода метка, присвоенная пикселю, — это канал с наибольшим значением. Это то, что делает функция create_mask.

OUTPUT_CLASSES = 3

model = unet_model(output_channels=OUTPUT_CLASSES)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

Листинг 5. Обратный вызов, определенный ниже, используется для наблюдения за тем, как модель улучшается во время обучения.

class DisplayCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    clear_output(wait=True)
    show_predictions()
    print ('\nSample Prediction after epoch {}\n'.format(epoch+1))

EPOCHS = 20
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_batches, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_batches,
                          callbacks=[DisplayCallback()])
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

plt.figure()
plt.plot(model_history.epoch, loss, 'r', label='Training loss')
plt.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

Делаем предсказания

Теперь сделаем несколько прогнозов. В целях экономии времени количество эпох оставлено небольшим, но вы можете установить его больше, чтобы получить более точные результаты.

show_predictions(test_batches, 3)

Вывод: здесь мы рассмотрели методы сегментации объектов на изображении. Это будет полезно как при участии в различных конкурсах, направленных на машинное обучение, так и в повседневной работе. Например, когда на фотографиях вам надо найти заданные объекты: документ, запрещенный прибор, средства пожарной безопасности и многое другое. В целом, возможности безграничны, все упирается в достаточно большой и качественный датасет. Ведь минимум 50% успеха работы нейросети будет зависеть от данных, на которых она обучается. Желаю успехов в будущих проектах.