Время прочтения: 4 мин.
PyCaret — достаточно свежая Python библиотека, предназначенная для машинного обучения. Она достаточно проста в использовании и интуитивно понятна. Сегодня постараемся разобраться в некоторых тонкостях данной библиотеки на примере встроенного в PyCaret датасета. Для начала установим библиотеку и проверим ее версию:
pip install pycaret
from pycaret.utils import version
version()
Актуальная на данный момент версия 2.3.1. Далее импортируем нужные модули с датасетами и классификацией, и выведем стандартные для этой библиотеки датасеты:
from pycaret.datasets import get_data
from pycaret.classification import *
index = get_data('index')
Разберем задачу классификации на датасете ‘Juice’ (размером в 1070 строк и 19 столбцов):
data = get_data('juice')
Датасет представляет из себя стандартный pandas dataframe.
Библиотека работает достаточно быстро, давайте посмотрим на свойства датафрейма, передадим в setup данные и целевую переменную:
clf = setup(data, target = 'Purchase', session_id = 1)
Узнаем информацию о задаче (бинарная классификация), пропущенных значениях, размерности (также размерности при делении на train и test), и многое другое (в нашем случае 58 строк различной информации).
Давайте найдем оптимальные модели для наших данных и сравним их:
%%time
best_model = compare_models()
На выходе получаем результаты 13 различных моделей за 20 секунд. И все это в одну строчку кода. Модели разные — от логистической регрессии до наивного Байеса, а обучение происходит на 10 фолдах. В таблице указаны основные метрики (Accuracy, AUC, F1) и интересная информация о времени выполнения. В нашем случае логистическая регрессия показывает accuracy — 0.8343, а Ridge Classifier имеет 0.8289, но отрабатывает в несколько раз быстрее.
Для тех, кому хочется знать все, что “под капотом” у данных моделей:
models()
Видим библиотеки, с помощью которых и функционирует PyCaret.
Также можем создать отдельно модель (с самым высоким Accuracy) логистической регрессии и подробнее рассмотреть метрики полученные с помощью данной модели (построим ROC-кривую и матрицу ошибок):
lr = create_model(‘lr’)
plot_model(lr)
plot_model(lr, plot = 'confusion_matrix')
Хочется отметить, что графики получаются достаточно информативными и красивыми из коробки. Также можем построить график с feature importance и посмотреть на важность признаков:
plot_model(lr, plot = 'feature')
Ну и конечно можно сделать предсказание:
pred_holdouts = predict_model(lr)
В итоге получаем очень мощный инструмент, который способен ускорить получение некого бейзлайна для некоторых задач. Конечно, стоит отметить, что в данной статье мы использовали датасет с 1 тысячей строк и все может быть не так радужно на больших наборах данных.