DecisionTreeClassifierの使い方
公式ドキュメント
- sklearn.tree.DecisionTreeClassifier — scikit-learn 0.18.1 documentation
パラメータ
- criterion
- splitter
- max_features
- max_depth
- min_samples_split
- min_samples_leaf
- min_weight_fraction_leaf
- max_leaf_nodes
- class_weight
- random_state
- min_impurity_split
- presort
パラメータを変えて様子をみる。
サンプルデータ
decision treeで分類しやすいように格子状のデータを作成する。
パラメータを変えて実験
デフォルトのパラメータのまま
- 分類結果
- 分岐ツリー
- 重要度
{'x1': 0.69674185463659166, 'x0': 0.30325814536340839}
- ちゃんとすべて分類できている。線形で分類できるからな。
- feature importance で本当はx0とx1の重要度は同じはず、ただ、最初の分岐がx0だったので、x0の方が高くなったいる。使用する乱数によって変わる。
criterion=‘entropy’
- 分類結果
- 分岐ツリー
- 今回のケースだと何もかわらず。。
- どう結果に影響するかを他のケースを作って検証しよう(また今度)
splitter=‘random’
- 分類結果
- 分岐ツリー
- 重要度
{'x1': 0.51479236812570128, 'x0': 0.48520763187429866}
- 境界をランダムに決めているので、分岐ツリーが巨大になっている。
- 重要度は近い値に。
max_features=1
- 今、特徴数は2なので、1とした場合をみる。2の場合は良い方で分岐するが、1だと乱数次第。
- 分類結果
- 分岐ツリー
- 重要度
{'x0': 0.052631578947368363, 'x1': 0.94736842105263164}
- 複雑な分岐になる。
- 重要度も乱数次第なので、当てにならない。
max_depth=3
- ツリーの深さの最大を3とする。3に達したらそこで終了
- 分類結果
- 分岐ツリー
min_samples_split=80
- nodeに残っているサンプル数が80個未満になったら分岐せずにそこで終了。
- 分類結果
- 分岐ツリー
min_samples_leaf=11
- リーフに残るサンプル数の最低数を指定
- 分類結果
- 分岐ツリー
min_weight_fraction_leaf = 6 / 400
- class_weightを指定していないので、min_samples_leaf=6と同じ。
- 分類結果
- 分岐ツリー
max_leaf_nodes=10
- リーフの最大個数を指定。今回は10個のリーフができている。
- 分類結果
- 分岐ツリー
min_impurity_split=0.48
- criterionが指定した値以下になったら、そこで分岐終了
- 分類結果
- 分岐ツリー
まとめ
- 今回は各パラメータの意味を確認できた。
- キレイに分類できるようにデータを作っている&予測をしていないので、当然、パラメータ調整しても意味がない。
ソースコードの一部
import numpy as np import pandas as pd from sklearn import tree import nsb import pydotplus def grid_data(): n = 20 x0, x1 = np.meshgrid(np.linspace(0, 2, n), np.linspace(0, 2, n)) def is_even(x): return np.round(x) % 2 == 0 y = np.logical_xor(is_even(x0), np.logical_not(is_even(x1))).astype(np.int64) return pd.DataFrame({'x0': x0.ravel(), 'x1': x1.ravel(), 'y': y.ravel()}) if __name__ == '__main__': df = grid_data() clf = tree.DecisionTreeClassifier(min_impurity_split=0.48) clf.fit(df[['x0', 'x1']], df['y']) nsb.plot.scatter_with_boundary(df['x0'], df['x1'], df['y'], clf, 'fig.png') dot_data = tree.export_graphviz(clf, out_file=None, feature_names=['x0', 'x1']) graph = pydotplus.graph_from_dot_data(dot_data) graph.write_png('tree.png') print(dict(zip(['x0', 'x1'], clf.feature_importances_)))