Анализ процессов, Графы

Реализация деревьев решений простым кодом

Время прочтения: 5 мин.

Пришлось нам как-то проводить аудит одной модели. В модели на основе достаточно большого набора данных с помощью хитрого алгоритма принимается решение. Нам нужно было проверить насколько адекватно считаются показатели при принятии решения.

Так вот, проверяем мы данные и видим, что имеют место быть аномалии, т.е. отклонения от нормального значения данных. Соответственно, эти аномалии надо проверить. Для этого нужно воспроизвести модель, по которой эти аномалии возникают.

Анализ показал, что решение наиболее просто получить именно с помощью решающего дерева. Так вот, займемся его воспроизведением.

Данные у нас берутся не из воздуха, а хранятся в БД — у нас это MS SQL Server. Сделаем обвязку для считывания. На вход подаем в виде строки DBpath – путь к серверу, DBname – наименование базы данных.

package TreeProject;

import java.io.InputStream;
import java.sql.*;
import java.util.HashMap;
import java.util.Objects;
import java.util.Scanner;
import java.util.Vector;

public class DBConnector {
    private final String DBpath;
    private final String DBname;
    private final String driver;
    private String query;

    public DBConnector(String DBpath, String DBname){
        this.DBpath = DBpath;
        this.DBname = DBname;
        this.driver = "com.microsoft.sqlserver.jdbc.SQLServerDriver";
    }

    public void getQueryString(String filename){
        InputStream file = getClass().getClassLoader().getResourceAsStream(filename);
        Scanner scanner = new Scanner(Objects.requireNonNull(file));
        this.query = "";
        scanner.useDelimiter("line.separator");
        while (scanner.hasNext()){
            this.query += scanner.next();
        }
    }

    public Vector<HashMap<String, String>> getData(){
        Connection conn;
        String connStr = "jdbc:sqlserver://"+ this.DBpath +";databaseName="+ this.DBname +";integratedSecurity=true";
        Statement stmt;

        try {
            Class.forName(driver);
            conn = DriverManager.getConnection(connStr);
            stmt = conn.createStatement();
            ResultSet response = stmt.executeQuery(this.query);
            Vector<HashMap<String, String>> data = new Vector<>();
            while (response.next()) {

                HashMap<String, String> row = new HashMap<>();
                row.put("DATA_A", response.getString("DATA_A"));
                row.put("DATA_B", response.getString("DATA_B"));
                row.put("DATA_C", response.getString("DATA_C"));

                if (response.getString("NODE") == "null" || response.getString("NODE") == null)
                {
                    row.put("NODE", "-1");
                }
                else
                {
                    row.put("NODE", response.getString("NODE"));
                }
                data.add((HashMap<String, String>) row.clone());
            }
            conn.close();
            return data;
        } catch (Exception e) {
            e.printStackTrace();
            return  null;
        }
    }
}

Код решающего дерева реализован в самом элементарном виде – через if/else. Но тем не менее даже в таком виде он исправно работает. Потому что самое главное в решающем дереве это грамотная разбивка данных в узлах решающего дерева. Точность разбивки обеспечивается во-первых – огромными массивами данных для определения весов в узлах. А далее периодической калибровкой этих весов на обновленных данных. Само дерево выглядит примерно так:

package TreeProject;

import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Date;

/**
* решающее дерево для расчета показателей
*/
public class Tree {
    private Double DATA_A; 
    private Double DATA_B; 
    private Double DATA_C; 

    // result
    private int NODE;

    Tree() {}

    public Tree(Double DATA_A, Double DATA_B, Double DATA_) {
        this.DATA_A = DATA_A;
        this.DATA_B = DATA_B;
        this.DATA_C = DATA_C;
    }

    public void calcNode() {
        if(DATA_A != null) {
            if(DATA_B >= 10) {
                if(DATA_B != null && DATA_B > 5) {
                    NODE  = 1;
                } else {
                    NODE  = 2;
                }
            } else {
                if(DATA_B != null && DATA_B < 5) {
                    NODE  = 3;
                } else {
                    NODE  = 4;
                }
            }
        } else {
            NODE  = 5;
        }
    }
}

Далее подружим данные, полученные из базы данных, и методы, реализованные в решающем дереве. На выходе получим расчет показателей (записываются в файл output.csv), так же для работы с БД используется внешний файл с SQL запросом — запрос элементарный вида SELECT * FROM table, поэтому его код не приводится:

package Result;

import DBConnector;
import Tree;

import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.text.ParseException;
import java.util.HashMap;
import java.util.Vector;

public class Calc {

    public static void main(String[] args) throws ParseException, UnsupportedEncodingException, FileNotFoundException {
        getAndCalc("data.sql", "output.csv");
    }

    public static void getAndCalc(String queryFile, String outFile) throws ParseException, FileNotFoundException, UnsupportedEncodingException {
        DBConnector dbConnector = new DBConnector("server", "database");
        dbConnector.getQueryString(queryFile);
        Vector<HashMap<String, String>> data= dbConnector.getData();
        PrintWriter writer = new PrintWriter(outFile, "UTF-8");

        writer.println("DATA_A," + "DATA_B," + "DATA_C");

        for (HashMap<String, String> dataRow : data) {
            System.out.println(dataRow);

            Tree tree = new Tree(
                    Double.parseDouble(dataRow.get("DATA_A")),
                    Double.parseDouble(dataRow.get("DATA_B")),
                    Double.parseDouble(dataRow.get("DATA_C"))
            );
            writer.println(tree.calcNode());
        }
        writer.close();
    }
}

Результат работы модели доступен для анализа, и мы можем проверить итоги на наличие аномалий. Результат выглядит примерно вот так:

Наименование показателяЗначение при долгом циклеЗначение при коротком циклеВзвешенное значениеИтоговая NODE
DATA_A18,11005003500,21
DATA_B2,08,55,54
DATA_C15,612,26

И уже возможна проверка данных на воспроизведенной модели и сравнение полученных данных из воспроизведённой модели с тем, что у нас считает «рабочая» модель. Так же проводится выявление аномалий – к примеру, выявление возраста человека более 100 лет или подобные не логичные данные в результатах.

Мы видим, что модель дерева легко реализуется даже без всяких библиотек и при этом очень хорошо понятна логика, по которой модель думает.

Показатели для выбора условий в дереве определяются как раз на основе БОЛЬШИХ ДАННЫХ, про которые мы упоминали выше. Именно там определяют, что к примеру фактор f1 — это фактор важнее фактора f2, а фактор f3 вообще по порядку шестнадцатый.

Если кто внезапно не помнит или не знает, как именно определяется значимость факторов – читаем про критерий Джини. Вкратце, это означает, что изначально данные предполагаются абсолютно неупорядоченными и оценивается показатель данной неупорядоченности — энтропия.

Так вот, любое разбиение данных деревом (или ифами в нашем случае) должно эту энтропию снижать. И в итоге мы должны получить живую и румяную разбитую и упорядоченную выборку. Собственно, так мы и получаем итог, когда на основании кучи входных данных мы получаем однозначный ответ.

Итого, мы показали, что самые простые решения в большей части и самые рабочие. Кроме того, умение применять на практике простые решения там, где можно обойтись без сложных и делает тебя профи.

Советуем почитать