Время прочтения: 5 мин.

Задача классификации в работе крупных компаний встречается довольно часто — определить принадлежность задания к отделу, отнести полученное на почте письмо к полезному или нежелательному, выявить неисправность колёсных пар вагона по издаваемому при движении звуку и многое другое.

Одна из важнейших задач, решаемых, например, в банке — одобрить или отказать в выдаче кредита.

Для решения такого рода задач разрабатываются различные алгоритмы, библиотеки или фреймворки. Один из фреймворков — Microsoft Cognitive Toolkit (CNTK), проект с открытым исходным кодом для глубокого обучения на базе сверточных нейронных сетей. С помощью данного фреймворка можно использовать и комбинировать различные модели машинного обучения (DNN, CNN, RNN). Также фреймворк имеет возможность обучения на нескольких видеоадаптерах одновременно.

Один из возможных вариантов применения фреймворка — создание и обучение сверточной нейронной сети по модели логистической регрессии.

Логистическая регрессия — это простая модель для выполнения бинарной классификации. Как и другие модели регрессии, логистическая регрессия моделирует взаимосвязь между независимой переменной x и зависимой переменной z посредством линейной комбинации x с параметрами w и b. Модель изучает веса w и b и определяет, что при высокой вероятности того, что значение x соответствует у = 1, то и вес w должен быть большим, а при высокой вероятности соответствия x y=0 — вес минимален.

Так как мы заинтересованы в предсказании двух меток 0 и 1 мы преобразуем предсказание с помощью логистической функции: sigma (x) = 1 / (1 + exp (-x)).

Таким образом, применение логистической функции дает «округления» предсказаний. Значения больших положительных свидетельств становятся близкими к 1, а значения больших отрицательных свидетельств становятся близкими к 0. Это позволяет нашей модели выводить вероятности.

Для примера использования мы взяли небольшой датасет кредитного скоринга, на котором обучили линейную модель логистической регрессии для классификации лиц запрашивающих кредит на надежных и ненадежных.

Таблица в файле:

Код написан на языке программирования C#. Сперва необходимо загрузить данные из нашего датасета, выделить поле для классификации и параметр, по которому будет происходить классификация. Для этого через OLEDB адаптер загружаем наш датасет в объект DataTable, затем извлекаем из него необходимые столбцы и преобразуем их в Value объект:

private static void GetData(string path, bool isFirstRowHeader, string fCol, string lCol,  int sampleSize, int inputDim, int numOutputClasses, out Value featureValue, out Value labelValue, DeviceDescriptor device)
{
var dataTable = new DataTable();
var header = isFirstRowHeader ? "Yes" : "No";
var pathOnly = Path.GetDirectoryName(path);
var fileName = Path.GetFileName(path);
var sql = @"SELECT * FROM [" + fileName + "]";

using (OleDbConnection connection = new OleDbConnection(
@"Provider=Microsoft.Jet.OLEDB.4.0;Data Source=" + pathOnly + ";Extended Properties=\"Text;HDR=" + header + "\""))
using (OleDbCommand command = new OleDbCommand(sql, connection))
using (OleDbDataAdapter adapter = new OleDbDataAdapter(command))
            {
                dataTable.Locale = CultureInfo.CurrentCulture;
                adapter.Fill(dataTable);
            }
var oneHotLabels = dataTable.Rows.OfType<DataRow>().Select(dr => dr.Field<string>(lCol) == "Fully Paid" ? 1.0f: 0.0f).ToArray();
var features = dataTable.Rows.OfType<DataRow>().Select(dr => dr.Field<float>(fCol)).ToArray();
featureValue = Value.CreateBatch<float>(new int[] { inputDim }, features, device);
labelValue = Value.CreateBatch<float>(new int[] { numOutputClasses }, oneHotLabels, device);
}

После получения данных необходимо создать модель для обучения.

Plus, Times, Sigmoid — базовые функции CNTK для создания моделей. Входным аргументом может быть переменная или другая функция CNTK. Эти методы создают простую сеть с параметрами, которые настраиваются на этапе обучения, чтобы получился мультиклассовый классификатор. Для нашей задачи была использована функция Times():

private static Function CreateLinearModel(Variable input, int outputDim, DeviceDescriptor device)
{
var inputDim = input.Shape[0];
var wParam = new Parameter(new int[] { outputDim, inputDim }, DataType.Float, 1, device, "w");
var bParam = new Parameter(new int[] { outputDim }, DataType.Float, 0, device, "b");
return CNTKLib.Times(weightParam, input) + biasParam;
}
var loss = CNTKLib.CrossEntropyWithSoftmax(classifierOutput, labelVariable);
var evalErr = CNTKLib.ClassificationError(classifierOutput, labelVariable);
CNTK.TrainingParameterScheduleDouble learningRatePerSample = new CNTK.TrainingParameterScheduleDouble(0.02, 1);
IList<Learner> parameterLearners = List<Learner>() { Learner.SGDLearner(classifierOutput.Parameters(), learningRatePerSample) };
var trainer = Trainer.CreateTrainer(classifierOutput, loss, evalError, parameterLearners);

Для обучения CNTK использует алгоритм стохастического градиентного спуска(SGD). SGD итеративно просматривает некоторое фиксированное подмножество обучающих примеров и обновляет параметры в направлении градиентов функции стоимости после каждого шага. Если смотреть на мини-пакет вместо полных данных, каждое обновление следует только за грубым приближением к конечной цели обучения. «Стохастический» характер SGD помогает выйти из локальных оптимумов и оказался одновременно простым и эффективным для поиска решений.

В нашей задаче мы использовали стандартные параметры для обучения, поэтому дополнительной настройки алгоритма не потребовалось:

// train the model
for (int minibatchCount = 0; minibatchCount < numMinibatchesToTrain; minibatchCount++)
{
Value features, labels;
GetData("./credit_test.csv", true, "Number of Credit Problems", "Loan Status", minibatchSize, inputDim, numOutputClasses, out features, out labels, device);
trainer.TrainMinibatch(new Dictionary<Variable, Value>() { { featureVariable, features }, { labelVariable, labels } }, device);
TestHelper.PrintTrainingProgress(trainer, minibatchCount, updatePerMinibatches);
}

После обучения с помощью функции Evaluate() точность модели возрастает. На вход подается массив данных из тестовой выборки, полученный через свойство модели Output вывод сравнивается с эталонным и выводится количество отклонений и общее число вариантов:

// test and validate the model
var testSize = 100;
Value testFeatureValue, expectedLabelValue;
GetData("./credit_test.csv", true, "Number of Credit Problems", "Loan Status", testSize, inputDim, numOutputClasses, out testFeatureValue, out expectedLabelValue, device);

// GetDenseData just needs the variable's shape
var expectOneH= expectedLabelValue.GetDenseData<float>(labelVariable);
var expectLabels = expectedOneHot.Select(l => l.IndexOf(1.0F)).ToList();

var inputDataMap = new Dictionary<Variable, Value>() { { featureVariable, testFeatureValue } };
var outputDataMap = new Dictionary<Variable, Value>() { { classifierOutput.Output, null } };
classifierOutput.Evaluate(inputDataMap, outputDataMap, device);
var outputValue = outputDataMap[classifierOutput.Output];
IList<IList<float>> actualLabelSoftMax = outputValue.GetDenseData<float>(classifierOutput.Output);
var actualLabels = actualLabelSoftMax.Select((IList<float> l) => l.IndexOf(l.Max())).ToList();
int misMatches = actualLabels.Zip(expectedLabels, (a, b) => a.Equals(b) ? 0 : 1).Sum();
Console.WriteLine($"Validating Model: Total Samples = {testSize}, Misclassify Count = {misMatches}");

В результате обучения на датасете из 100000 записей получились следующие показатели модели:

Была продемонстрирована довольно высока точность классификатора, равная 0.94.

В ходе использования данного фреймворка я выделил для себя следующие положительные стороны:

— быстрое обучение, так как используется графический адаптер, а при наличии нескольких — скорость еще возрастет;

— возможность использования с разными языками программирования (есть API для c#, c++, python);

Из минусов — сложная реализация и недостаточное документирование, это значительно увеличивает время разработки приложения, а также сложность в настройке для начинающих пользователей.

Помимо моделей классификации на основе логистической регрессии, фреймворк имеет другие возможности для работы со звуком, видео, текстом. Предлагаю всем изучить его и поделиться своими наблюдениями.

Спасибо за внимание.