DecisionTreeClassifierの使い方

公式ドキュメント

パラメータを変えて様子をみる。

サンプルデータ

decision treeで分類しやすいように格子状のデータを作成する。 f:id:nsb248:20170224141608p:plain

パラメータを変えて実験

デフォルトのパラメータのまま

  • 分類結果 f:id:nsb248:20170224141857p:plain:w200
  • 分岐ツリー f:id:nsb248:20170224141902p:plain:w200
  • 重要度
{'x1': 0.69674185463659166, 'x0': 0.30325814536340839}
  • ちゃんとすべて分類できている。線形で分類できるからな。
  • feature importance で本当はx0とx1の重要度は同じはず、ただ、最初の分岐がx0だったので、x0の方が高くなったいる。使用する乱数によって変わる。

criterion=‘entropy’

  • 分類結果 f:id:nsb248:20170224144138p:plain:w200
  • 分岐ツリー f:id:nsb248:20170224144147p:plain:w200
  • 今回のケースだと何もかわらず。。
  • どう結果に影響するかを他のケースを作って検証しよう(また今度)

splitter=‘random’

  • 分類結果 f:id:nsb248:20170224144549p:plain:w200
  • 分岐ツリー f:id:nsb248:20170224144600p:plain:w200
  • 重要度
{'x1': 0.51479236812570128, 'x0': 0.48520763187429866}
  • 境界をランダムに決めているので、分岐ツリーが巨大になっている。
  • 重要度は近い値に。

max_features=1

  • 今、特徴数は2なので、1とした場合をみる。2の場合は良い方で分岐するが、1だと乱数次第。
  • 分類結果 f:id:nsb248:20170224145125p:plain:w200
  • 分岐ツリー f:id:nsb248:20170224145138p:plain:w200
  • 重要度
{'x0': 0.052631578947368363, 'x1': 0.94736842105263164}
  • 複雑な分岐になる。
  • 重要度も乱数次第なので、当てにならない。

max_depth=3

  • ツリーの深さの最大を3とする。3に達したらそこで終了
  • 分類結果 f:id:nsb248:20170224150236p:plain:w200
  • 分岐ツリー f:id:nsb248:20170224150239p:plain:w200

min_samples_split=80

  • nodeに残っているサンプル数が80個未満になったら分岐せずにそこで終了。
  • 分類結果 f:id:nsb248:20170224150033p:plain:w200
  • 分岐ツリー f:id:nsb248:20170224150035p:plain:w200

min_samples_leaf=11

  • リーフに残るサンプル数の最低数を指定
  • 分類結果 f:id:nsb248:20170224151055p:plain:w200
  • 分岐ツリー f:id:nsb248:20170224151058p:plain:w200

min_weight_fraction_leaf = 6 / 400

  • class_weightを指定していないので、min_samples_leaf=6と同じ。
  • 分類結果 f:id:nsb248:20170224151323p:plain:w200
  • 分岐ツリー f:id:nsb248:20170224151325p:plain:w200

max_leaf_nodes=10

  • リーフの最大個数を指定。今回は10個のリーフができている。
  • 分類結果 f:id:nsb248:20170224151457p:plain:w200
  • 分岐ツリー f:id:nsb248:20170224151501p:plain:w200

min_impurity_split=0.48

  • criterionが指定した値以下になったら、そこで分岐終了
  • 分類結果 f:id:nsb248:20170224151817p:plain:w200
  • 分岐ツリー f:id:nsb248:20170224151819p:plain:w200

まとめ

  • 今回は各パラメータの意味を確認できた。
  • キレイに分類できるようにデータを作っている&予測をしていないので、当然、パラメータ調整しても意味がない。

ソースコードの一部

import numpy as np
import pandas as pd
from sklearn import tree
import nsb
import pydotplus


def grid_data():
    n = 20
    x0, x1 = np.meshgrid(np.linspace(0, 2, n), np.linspace(0, 2, n))

    def is_even(x):
        return np.round(x) % 2 == 0

    y = np.logical_xor(is_even(x0), np.logical_not(is_even(x1))).astype(np.int64)
    return pd.DataFrame({'x0': x0.ravel(), 'x1': x1.ravel(), 'y': y.ravel()})


if __name__ == '__main__':
    df = grid_data()
    clf = tree.DecisionTreeClassifier(min_impurity_split=0.48)
    clf.fit(df[['x0', 'x1']], df['y'])
    nsb.plot.scatter_with_boundary(df['x0'], df['x1'], df['y'], clf, 'fig.png')
    dot_data = tree.export_graphviz(clf, out_file=None, feature_names=['x0', 'x1'])
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_png('tree.png')

    print(dict(zip(['x0', 'x1'], clf.feature_importances_)))