[회고] 신입 iOS 개발자가 되기까지 feat. 카카오 자세히보기

🛠 기타/Data & AI

[scikit-learn 라이브러리] 교차 검증

inu 2020. 8. 23. 16:04
반응형

교차 검증 (Cross Validation)

  • 일반화 성능 향상을 위해 훈련 세트와 테스트 세트를 한 번만 나누는 것보다 더 안정적이고 뛰어난 평가 방법이다.
  • 여러개의 세트로 구성된 학습 데이터와 테스트 데이터로 학습과 평가를 수행한다.
  • k-겹 교차검증 : 데이터를 폴드(fold)라는 거의 비슷한 크기의 부분집합 k개로 분리하고 각 부분집합의 정확도를 측정한다.
  • 교차 검증의 점수가 높을수록 데이터셋에 있는 모든 샘플에 대해 모델이 잘 일반화되게 된다.
  • 하지만 연산비용이 늘어나게 된다는 단점이 있다.
  • scikit-learn에서 교차 검증은 model_selection 모듈의 cross_val_score라는 함수로 구현되어 있다.

cross_val_score()

cross_val_score(estimator, X, y=None, *, groups=None, scoring=None, cv=None, n_jobs=None, verbose=0, fit_params=None, pre_dispatch='2*n_jobs', error_score=nan)
  • estimator : 평가하려는 모델
  • X : 훈련 데이터
  • y : 타깃 레이블
  • cv : 교차 검증 분할 수 (default: 5)
  • 리턴값 : 교차 검증 결과 정확도 점수의 배열
from sklearn.model_selection import cross_val_score
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
model_lr = LogisticRegression()
model_dt = DecisionTreeClassifier(random_state=0)

scores = cross_val_score(model_lr, iris.data, iris.target)
print("교차 검증 점수: ", scores)

scores = cross_val_score(model_dt, iris.data, iris.target)
print("교차 검증 점수: ", scores)
==결과==
교차 검증 점수:  [0.96666667 1.         0.93333333 0.96666667 1.        ]
교차 검증 점수:  [0.96666667 0.96666667 0.9        0.96666667 1.        ]

cross_validate()

  • 같은 교차검증 함수지만, 여러 기준으로 정확도를 구해 리턴하게 된다.
from sklearn.model_selection import cross_validate

res = cross_validate(model_lr, iris.data, iris.target, return_train_score=True)
res
==결과==
{'fit_time': array([0.01603723, 0.016011  , 0.0120194 , 0.01600003, 0.01199794]),
 'score_time': array([0., 0., 0., 0., 0.]),
 'test_score': array([0.96666667, 1.        , 0.93333333, 0.96666667, 1.        ]),
 'train_score': array([0.96666667, 0.96666667, 0.98333333, 0.98333333, 0.975     ])}

계층별 k-겹 교차검증

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
  • 위와 같은 데이터를 단순히 교차검증하면, 각 폴드에 클래스들이 몰리게 되어 정확한 학습 및 테스트가 어려워진다.
  • 따라서 데이터를 세밀하게 분할하고 섞어줄 수 있는 kfold를 사용한다.
from sklearn.model_selection import KFold

kfold = KFold(n_splits=3, shuffle=True, random_state=0)
scores = cross_val_score(model_lr, iris.data, iris.target, cv=kfold)
print("교차 검증 점수: {}".format(scores))
  • cv값에 kfold 객체를 넣을 수 있다.
  • shuffle값을 True로 주어 데이터를 랜덤하게 배분한다.

  • 이 외에도 GroupKFold(), ShuffleSplit() 등 다양한 분할 기법이 존재한다.
반응형