[회고] 신입 iOS 개발자가 되기까지 feat. 카카오 자세히보기

🛠 기타/Data & AI

[scikit-learn 라이브러리] DecisionTreeClassifier (결정트리분류기)

inu 2020. 8. 5. 00:41

결정트리

  • 분할과 가지치기 과정을 반복하면서 모델을 생성한다.
  • 결정트리에는 분류와 회귀 모두에 사용할 수 있다.
  • 여러개의 모델을 함께 사용하는 앙상블 모델이 존재한다. (RandomForest, GradientBoosting, XGBoost)
  • 각 특성이 개별 처리되기 때문에 데이터 스케일에 영향을 받지 않아 특성의 정규화나 표준화가 필요 없다.
  • 시계열 데이터와 같이 범위 밖의 포인트는 예측 할 수 없다.
  • 과대적합되는 경향이 있다. 이는 본문에 소개할 가지치기 기법을 사용해도 크게 개선되지 않는다.

DecisionTreeClassifier()

DecisionTreeClassifier(criterion, splitter, max_depth, min_samples_split, min_samples_leaf, min_weight_fraction_leaf, max_features, random_state, max_leaf_nodes, 
min_impurity_decrease, min_impurity_split, class_weight, presort)
  • criterion : 분할 품질을 측정하는 기능 (default : gini)
  • splitter : 각 노드에서 분할을 선택하는 데 사용되는 전략 (default : best)
  • max_depth : 트리의 최대 깊이 (값이 클수록 모델의 복잡도가 올라간다.)
  • min_samples_split : 자식 노드를 분할하는데 필요한 최소 샘플 수 (default : 2)
  • min_samples_leaf : 리프 노드에 있어야 할 최소 샘플 수 (default : 1)
  • min_weight_fraction_leaf : min_sample_leaf와 같지만 가중치가 부여된 샘플 수에서의 비율
  • max_features : 각 노드에서 분할에 사용할 특징의 최대 수
  • random_state : 난수 seed 설정
  • max_leaf_nodes : 리프 노드의 최대수
  • min_impurity_decrease : 최소 불순도
  • min_impurity_split : 나무 성장을 멈추기 위한 임계치
  • class_weight : 클래스 가중치
  • presort : 데이터 정렬 필요 여부

사용 예제

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# 데이터 로드
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=11)

# 모델 학습
model = DecisionTreeClassifier(random_state=42)
model.fit(X_train, y_train)
  • iris 데이터를 활용해 학습 데이터와 테스트 데이터를 만든다.
  • random_state 값만 주어 model을 만들고 학습 데이터를 fitting 시킨다.
# 결정트리 규칙을 시각화
import matplotlib.pyplot as plt
from sklearn import tree

plt.figure( figsize=(20,15) )
tree.plot_tree(model, 
               class_names=iris.target_names,
               feature_names=iris.feature_names,
               impurity=True, filled=True,
               rounded=True)
  • sklearn의 tree 모듈을 활용해 완성된 결정트리를 그린다.

  • 모두 분류될 때까지 (지니불순도가 0이 될때까지 리프 노드를 확장한 것을 확인할 수 있다.

가지치기

  • 위에서 만든 결정트리는 학습 데이터에 완전 적합(과적합)되어 있다. 따라서 다른 데이터엔 적절하게 사용될 수 없을 것이다.
  • 따라서 적절히 가치지기를 수행한다.
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(X_train, y_train)

# 결정트리 규칙 시각화
plt.figure( figsize=(20,15) )
tree.plot_tree(model, 
               class_names=iris.target_names,
               feature_names=iris.feature_names,
               impurity=True, filled=True,
               rounded=True)
  • max_depth를 3으로 주고 시각화를 진행했다.

  • 깊이 3일 때까지만 분류를 진행했음을 알 수 있다.
  • 이 외에도 min_samples_split(리프 노드가 될 수 있는 샘플 데이터 최소값), min_samples_leaf(리프 노드가 될 수 있는 샘플 데이터 최소값), max_leaf_nodes (리프 노드가 될 수 있는 샘플 데이터 최대값) 등을 조절해 가지치기를 할 수 있다.

특성 중요도

  • 트리 분류 모델 형성에 각 특성이 얼마나 작용했는지 평가하는 지표이다.
  • 해당 지표는 0~1사이의 값을 가지며, 0이면 특성이 전혀 작용하지 않았음을 의미하고 1은 완전하게 작용했음을 의미한다.
  • 특성 중요도 전체의 합은 1이다.
import seaborn as sns
import numpy as np
%matplotlib inline

# feature별 importance 매핑
for name, value in zip(iris.feature_names , model.feature_importances_):
    print('{} : {:.3f}'.format(name, value))

# feature importance를 column 별로 시각화 하기 
sns.barplot(x=model.feature_importances_ , y=iris.feature_names)
==결과==
sepal length (cm) : 0.006
sepal width (cm) : 0.000
petal length (cm) : 0.546
petal width (cm) : 0.448

  • 특정 모델 트리의 feature_importances_ 변수를 확인함으로서 특성 중요도를 체크할 수 있다.
  • (cf. 해당 특성중요도는 위의 트리와는 연관이 없다.)