Время прочтения: 4 мин.
Обучать модель регрессии будем на примере решения классической задачи по предсказанию выживаемости пассажиров Титаника.
Подготовка IDE
Первым делом необходимо обновить Visual Studio и установить расширение ML.NET Model Builder (ссылка).
На момент написания статьи последняя версия библиотеки Microsoft.ML 1.6.0. Она написана под .NET Standard 2.0, таким образом ее можно использовать в проектах .NET Framework 4.6.1+\.NET Core 2.0+\.NET 5.0+.
Создание проекта
Я создал проект .net Framework 4.7.2. Расширение ML.NET Model Builder находится в стадии активной разработки и для его включения нужно перейти в раздел меню средства\параметры\окружение\функции предварительной версии и поставить галочку:
Обучение модели
Добавим модель машинного обучения в проект. Для этого откройте меню «Добавить» и выберите «Машинной обучение»
Назовем модель «TitanikModel»
После добавления должен открыться Model Builder.
Мы решаем задачу регрессии, поэтому выберем сценарий «Прогнозирование значений».
Я предварительно загрузил файлы Train.csv и Test.csv. Укажем путь до первого в разделе «Данные». Выберем прогнозируемый столбец «Survived».
Можно более детально настроить данные перейдя в расширенные параметры:
Идентификатор пассажира, ФИО, номер билета и номер каюты я счел лишними и проставил напротив этих столбцов значение «Ignore».
Перейдем в раздел «обучение». Здесь нужно указать время обучения. Чем больше это значение, тем точнее будет модель. Я оставлю 10 секунд и запущу обучение.
После обучения в проект добавится 3 файла:
TitanikModel.consumption.cs – описание входных\выходных данных, реализация функции Predict
TitanikModel.training.cs – код обучения модели
TitanikModel.zip – модель
Использование модели
Попробуем предсказать выживаемость для тестовой выборки
string testPath = "test.csv";
string resultPath="predict.csv";
//индексы колонок в .csv файле
int colPassengerId=0,
colPclass=1,
colName=2,
colSex=3,
colAge=4,
colSibSp=5,
colParch=6,
colTicket=7,
colFare=8,
colCabin=9,
colEmbarked=10;
float thresold = 0.5f;
Regex CSVParser = new Regex(",(?=(?:[^\"]*\"[^\"]*\")*(?![^\"]*\"))");
using (StreamWriter writer = File.CreateText(resultPath))
{
writer.WriteLine("PassengerId,Survived"); //заполняю заголовок
string[] testRows = File.ReadAllLines(testPath);
for(int i=1;i<testRows.Count();i++)
{
string[] row = CSVParser.Split(testRows[i]);
//Подготовка данных
TitanikModel.ModelInput sampleData = new TitanikModel.ModelInput()
{
Pclass = float.Parse(row[colPclass], CultureInfo.InvariantCulture.NumberFormat),
Sex = row[colSex],
Age = string.IsNullOrEmpty(row[colAge])?0f:float.Parse(row[colAge], CultureInfo.InvariantCulture.NumberFormat),
SibSp = float.Parse(row[colSibSp], CultureInfo.InvariantCulture.NumberFormat),
Parch = float.Parse(row[colParch], CultureInfo.InvariantCulture.NumberFormat),
Fare = string.IsNullOrEmpty(row[colFare]) ? 0f : float.Parse(row[colFare], CultureInfo.InvariantCulture.NumberFormat),
Embarked = row[colEmbarked],
};
float result = TitanikModel.Predict(sampleData).Score; //Predict
string toWrite = $"{row[colPassengerId]},{(result > thresold ? "1" : "0")}";
Console.WriteLine(toWrite);
writer.WriteLine(toWrite); //если predict > порогового значения - значит выжил
}
}
Console.WriteLine("Завершено");
Console.ReadKey();
Этот код предскажет значение столбца Survived для каждой строки выборки и сохранит его в predict.csv. Загрузим этот файл на Kaggle
Я никак не подготавливал данные, не экспериментировал с пороговым значением (threshold) и получил результат статистически выше среднего. Библиотека Micorsoft.ML показала себя с хорошей стороны: наличие документации, постоянные обновления, хорошая производительность и графический интерфейс для обучения моделей.
Исходный код проекта можно скачать с Github.