決定木による学習

決定木は、根から順番に条件分岐を辿っていくことで結果を返します。特に分類問題で用いられる場合には「分類木」と呼ばれます。ここでは、与えられたデータを元に、自動的に分類木を構成する方法について説明します。機械学習の手法の中でも、学習結果を人間が解釈し易いことが特徴で、データの特徴を掴む場合によく用いられるようです。(「花びらの長さがθ以下ならチューリップ」などが分かれば、それだけで面白いですよね!)今回はアルゴリズムについて説明をした後、pythonでの実装例を紹介したいと思います。

CARTによる分類木の構築

CARTとはClassification And Regression Treesの略で、2進木の決定木の学習アルゴリズムです。決定木の学習アルゴリズムは他にもいくつか種類が有りますが、ここではCARTに絞って動作を見ていきたいと思います。

動作の流れ

まず最初にルートノードに全データが与えられている状態を考えます。

この全データを最も上手く分割する基準を探し(下で説明します)、データを2グループに分けます。自分の下に子ノードを2つ作成し、先ほど分割したデータをそれぞれ割り当てます。

左の子ノードは1種類のデータのみになったのでここで終了です。右の子ノードはまだデータが混ざっているので、分割条件を探してデータを2つに分割します。あとはこの動作を繰り返し、再帰的に分割を続けていくと決定木が出来上がっていく、という流れです。

分割規則

ここではデータを2つに分割する方法について考えてみます。最初に、ノードに含まれるデータの「不純度」を定義し、分割による不純度の変化から、分割の良さを表す「分割指数」を定義します。最後に、分割の候補の選び方を紹介します。

(1) ノードの不純度

ノードtに含まれるデータの不純度をΙ(t)と定義します。ここでは不純度の評価にGini関数を使います。

(2) 分割の評価方法

次に、分割条件の善し悪しを判断するために、分割前後の不純度の変化を次式で定義します。(不純度にGini関数を用いる場合はGini Indexと呼ばれます。)

tはノードt、tLとtRは分割後の2つのグループを表します。pL、pRは分割後のノードの割合です。複数の分割条件sを試し、この分割指数が最大の分割を採用します。

(3) 分割候補の選び方

特徴量が数値の場合は次の考え方で候補を挙げることが出来ます。 「データをある特徴量に着目して小さい順に並べ、隣接する各ペアの中間地点を分割候補点とする。」 上の図では、特徴量はx1とx2です。まず、x1の小さい順でデータを並べ、隣り合うデータのx1の平均値が分割候補となります。x2についても同様に小さい順に並べ、各中間点を候補とします。これら全ての候補点で(2)の分割指数を計算し、最も分割指数の高い分割方法を採用します。

木の剪定

(1) 過学習の問題

まず質問。この分類をみてどう思いますか?

明らかにやり過ぎですね。与えられたデータはきちんと分けられていますが、過学習に陥ってしまっています。この状態に成らないために、「丁度良いところで止めたい」と思うのは自然だと思います。

多少分類しきれないデータもありますが、この方がデータの特性をよく表していますよね。

(2) バランスの取り方

「学習度合い」と「木の複雑さ」のバランスを取るための方法は大きく分けると以下の2つ。

  • a: 複雑になりすぎるまえに分岐処理を止める
  • b: とりあえず分岐させておいて、後で不要な枝を剪定する

今回の実装編は(b)の方法でやっています。ここで問題になるのが「どの枝を剪定するか」という部分。以下のような基準で剪定する枝を選びます。

Fr = (分割指数) × (ノードに割り当てられたデータ数の割合)

分割指数が小さいのは、分割しても不純度があまり減らない「良くない分割」です。また、データ数の割合が小さいのは、細かい部分の分割を表しています。この2つの積であるFrに閾値を設けて、閾値に満たないノードは分割を取り消します。この評価を、先端のノードから順に行い、統合するかどうかを判断していきます。ちなみに、この閾値を調整することで、木の複雑さを制御できるようになります。

実装編

(1) 方針について

オブジェクト指向っぽく書きます。

まずは学習器本体の”DecisionTree”クラス。機能は以下の通り

  • 生成時に調整パラメータを指定
  • fit(data, target)で学習
  • predict(data)で予測
  • print_tree()でツリーの情報を吐き出し

次に木構造の単一ノードを表す”_Node”クラス。自ノードの配下に子ノードを保持することが出来ます。ファイル内部でしか使わないので、クラス名の頭にアンダースコアをつけてます。機能は以下の通り。

  • build(data, target)で分類木の構築を行う
    受け取ったデータを分割する規則を探し、子ノードを作って再帰的に呼び出します。
  • prune(criterion, numall)で木の剪定を行う
  • print_tree()でツリー情報を吐き出し

(2) コード

# -*- coding: utf-8 -*-

import numpy as np


class _Node:
    """決定木のノードクラス"""

    def __init__(self):
        """初期化処理

        left       : 左の子ノード(しきい値未満)
        right      : 右の子ノード(しきい値以上)
        feature    : 分割する特徴番号
        threshold  : 分割するしきい値
        label      : 割り当てられたクラス番号
        numdata    : 割り当てられたデータ数
        gini_index : 分割指数(Giniインデックス)
        """

        self.left = None
        self.right = None
        self.feature = None
        self.threshold = None
        self.label = None
        self.numdata = None
        self.gini_index = None

    def build(self, data, target):
        """木の構築を行う

        data   : ノードに与えられたデータ
        target : データの分類クラス
        """

        self.numdata = data.shape[0]
        num_features = data.shape[1]

        # 全データが同一クラスとなったら分割終了
        if len(np.unique(target)) == 1:
            self.label = target[0]
            return

        # 自分のクラスを設定(各データの多数決)
        class_cnt = {i: len(target[target==i]) for i in np.unique(target)}
        self.label= max(class_cnt.items(), key=lambda x:x[1])[0]

        # 最良の分割を記憶する変数
        best_gini_index = 0.0 # 不純度変化なし
        best_feature = None
        best_threshold = None

        # 自分の不純度は先に計算しておく
        gini = self.gini_func(target)

        for f in range(num_features):

            # 分割候補の計算
            data_f = np.unique(data[:, f]) # f番目の特徴量(重複排除)
            points = (data_f[:-1] + data_f[1:]) / 2.0 # 中間の値を計算

            # 各分割を試す
            for threshold in points:

                # しきい値で2グループに分割
                target_l = target[data[:, f] <  threshold]
                target_r = target[data[:, f] >= threshold]

                # 分割後の不純度からGiniインデックスを計算
                gini_l = self.gini_func(target_l)
                gini_r = self.gini_func(target_r)
                pl = float(target_l.shape[0]) / self.numdata
                pr = float(target_r.shape[0]) / self.numdata
                gini_index = gini - (pl * gini_l + pr * gini_r)

                # より良い分割であれば記憶しておく
                if gini_index > best_gini_index:
                    best_gini_index = gini_index
                    best_feature = f
                    best_threshold = threshold

        # 不純度が減らなければ終了
        if best_gini_index == 0:
            return

        # 最良の分割を保持する
        self.feature = best_feature
        self.gini_index = best_gini_index
        self.threshold = best_threshold

        # 左右の子を作って再帰的に分割させる
        data_l   =   data[data[:, self.feature] <  self.threshold]
        target_l = target[data[:, self.feature] <  self.threshold]
        self.left = _Node()
        self.left.build(data_l, target_l)

        data_r   =   data[data[:, self.feature] >= self.threshold]
        target_r = target[data[:, self.feature] >= self.threshold]
        self.right = _Node()
        self.right.build(data_r, target_r)

    def gini_func(self, target):
        """Gini関数の計算

        target : 各データの分類クラス
        """
        classes = np.unique(target)
        numdata = target.shape[0]

        # Gini関数本体
        gini = 1.0
        for c in classes:
            gini -= (len(target[target == c]) / numdata) ** 2.0

        return gini

    def prune(self, criterion, numall):
        """木の剪定を行う

        criterion  : 剪定条件(この数以下は剪定対象)
        numall    : 全ノード数
        """

        #自分が葉ノードであれば終了
        if self.feature == None:
            return

        # 子ノードの剪定
        self.left.prune(criterion, numall)
        self.right.prune(criterion, numall)

        # 子ノードが両方葉であれば剪定チェック
        if self.left.feature == None and self.right.feature == None:

            # 分割の貢献度:GiniIndex * (データ数の割合)
            result = self.gini_index * float(self.numdata) / numall

            # 貢献度が条件に満たなければ剪定する
            if result < criterion:
                self.feature = None
                self.left = None
                self.right = None

    def predict(self, d):
        """入力データ(単一)の分類先クラスを返す"""

        # 自分が節の場合は条件判定
        if self.feature != None:
            if d[self.feature] < self.threshold:
                return self.left.predict(d)
            else:
                return self.right.predict(d)

        # 自分が葉の場合は自分の分類クラスを返す
        else:
            return self.label

    def print_tree(self, depth, TF):
        """分類条件を出力する"""

        head = "    " * depth + TF + " -> "

        # 節の場合
        if self.feature != None:
            print head + str(self.feature) + " < " + str(self.threshold) + "?"
            self.left.print_tree(depth + 1, "T")
            self.right.print_tree(depth + 1, "F")

        # 葉の場合
        else:
            print head + "{" + str(self.label) + ": " + str(self.numdata) + "}"


class DecisionTree:
    """CARTによる分類木学習器"""

    def __init__(self, criterion=0.1):
        """初期化処理

        root : 決定木のルートノード
        criterion : 剪定の条件
            (a) criterion(大) -> 木が浅くなる
            (b) criterion(小) -> 木が深くなる
        """
        self.root = None
        self.criterion = criterion

    def fit(self, data, target):
        """学習を行い決定木を構築する

        data   : 学習データ
        target : 各データの分類クラス
        """
        self.root = _Node()
        self.root.build(data, target)
        self.root.prune(self.criterion, self.root.numdata)
        pass

    def predict(self, data):
        """分類クラスの予測を行う

        data : テストデータ
        """
        ans = []
        for d in data:
            ans.append(self.root.predict(d))
        return np.array(ans)

    def print_tree(self):
        """分類木の情報を表示する"""
        self.root.print_tree(0, " ")