Время прочтения: 4 мин.

Обучать модель регрессии будем на примере решения классической задачи по предсказанию выживаемости пассажиров Титаника.

Подготовка IDE

Первым делом необходимо обновить Visual Studio и установить расширение ML.NET Model Builder (ссылка).

На момент написания статьи последняя версия библиотеки Microsoft.ML 1.6.0. Она написана под .NET Standard 2.0, таким образом ее можно использовать в проектах .NET Framework 4.6.1+\.NET Core 2.0+\.NET 5.0+.

Создание проекта

Я создал проект .net Framework 4.7.2. Расширение ML.NET Model Builder находится в стадии активной разработки и для его включения нужно перейти в раздел меню средства\параметры\окружение\функции предварительной версии и поставить галочку:

Обучение модели

Добавим модель машинного обучения в проект. Для этого откройте меню «Добавить» и выберите «Машинной обучение»

Назовем модель «TitanikModel»

После добавления должен открыться Model Builder.

Мы решаем задачу регрессии, поэтому выберем сценарий «Прогнозирование значений».

Я предварительно загрузил файлы Train.csv и Test.csv. Укажем путь до первого в разделе «Данные». Выберем прогнозируемый столбец «Survived».

Можно более детально настроить данные перейдя в расширенные параметры:

Идентификатор пассажира, ФИО, номер билета и номер каюты я счел лишними и проставил напротив этих столбцов значение «Ignore».

Перейдем в раздел «обучение». Здесь нужно указать время обучения. Чем больше это значение, тем точнее будет модель. Я оставлю 10 секунд и запущу обучение.

После обучения в проект добавится 3 файла:

TitanikModel.consumption.cs – описание входных\выходных данных, реализация функции Predict

TitanikModel.training.cs – код обучения модели

TitanikModel.zip – модель

Использование модели

Попробуем предсказать выживаемость для тестовой выборки

string testPath = "test.csv";
string resultPath="predict.csv";
//индексы колонок в .csv файле
int colPassengerId=0, 
    colPclass=1, 
    colName=2, 
    colSex=3, 
    colAge=4, 
    colSibSp=5, 
    colParch=6, 
    colTicket=7, 
    colFare=8, 
    colCabin=9, 
    colEmbarked=10;
float thresold = 0.5f;

Regex CSVParser = new Regex(",(?=(?:[^\"]*\"[^\"]*\")*(?![^\"]*\"))");
using (StreamWriter writer = File.CreateText(resultPath)) 
{
    writer.WriteLine("PassengerId,Survived");           //заполняю заголовок
    string[] testRows = File.ReadAllLines(testPath);
    for(int i=1;i<testRows.Count();i++)
    {
        string[] row = CSVParser.Split(testRows[i]);
        //Подготовка данных
        TitanikModel.ModelInput sampleData = new TitanikModel.ModelInput()
        {
            Pclass = float.Parse(row[colPclass], CultureInfo.InvariantCulture.NumberFormat),
            Sex = row[colSex],
            Age = string.IsNullOrEmpty(row[colAge])?0f:float.Parse(row[colAge], CultureInfo.InvariantCulture.NumberFormat),
            SibSp = float.Parse(row[colSibSp], CultureInfo.InvariantCulture.NumberFormat),
            Parch = float.Parse(row[colParch], CultureInfo.InvariantCulture.NumberFormat),
            Fare = string.IsNullOrEmpty(row[colFare]) ? 0f : float.Parse(row[colFare], CultureInfo.InvariantCulture.NumberFormat),
            Embarked = row[colEmbarked],
        };
        float result = TitanikModel.Predict(sampleData).Score; //Predict

        string toWrite = $"{row[colPassengerId]},{(result > thresold ? "1" : "0")}";
        Console.WriteLine(toWrite);
        writer.WriteLine(toWrite); //если predict > порогового значения - значит выжил
    }
}
Console.WriteLine("Завершено");
Console.ReadKey();

Этот код предскажет значение столбца Survived для каждой строки выборки и сохранит его в predict.csv. Загрузим этот файл на Kaggle

Я никак не подготавливал данные, не экспериментировал с пороговым значением (threshold) и получил результат статистически выше среднего. Библиотека Micorsoft.ML показала себя с хорошей стороны: наличие документации, постоянные обновления, хорошая производительность и графический интерфейс для обучения моделей.

Исходный код проекта можно скачать с Github.