Notice
Recent Posts
Recent Comments
Link
«   2025/08   »
1 2
3 4 5 6 7 8 9
10 11 12 13 14 15 16
17 18 19 20 21 22 23
24 25 26 27 28 29 30
31
Tags
more
Archives
Today
Total
관리 메뉴

kang's study

10일차 : 교차검증과 그리드 서치 본문

[학습 공간]/[혼공머신러닝]

10일차 : 교차검증과 그리드 서치

보끔밥0302 2022. 3. 3. 06:53

교차 검증(cross validation)과 그리드 서치(Grid Search)

검증 세트 (validation set)
 
 
데이터의 열 중 타깃 배열과 특성 배열을 구분
In [31]:
import pandas as pd

wine = pd.read_csv('https://bit.ly/wine_csv_data')
 
데이터를 훈련세트와 테스트 세트로 나누기
In [32]:
data = wine[['alcohol', 'sugar', 'pH']].to_numpy()
target = wine['class'].to_numpy()
In [33]:
from sklearn.model_selection import train_test_split

train_input, test_input, train_target, test_target = train_test_split(
    data, target, test_size=0.2, random_state=42)
# 테스트 세트는 최종 연습 성능예측용
In [34]:
sub_input, val_input, sub_target, val_target = train_test_split(
    train_input, train_target, test_size=0.2, random_state=42)
# 훈련데이터의 일부를 다시 검증 세트로 분할
# train -> sub (훈련용), value(검증용)
# v 매개변수 튜닝하여 best parameter 찾기
# s+v다시 묶어 최종 훈련한 모델로 테스트 하기
In [35]:
print(sub_input.shape, val_input.shape)
 
(4157, 3) (1040, 3)
 
모델 만들기
In [72]:
from sklearn.tree import DecisionTreeClassifier, export_text

dt = DecisionTreeClassifier(random_state=42)
dt.fit(sub_input, sub_target)

print(dt.score(sub_input, sub_target))
print(dt.score(val_input, val_target))
 
0.9971133028626413
0.864423076923077
 

교차 검증 (cross validation)

(머신러닝에서는 중요: 데이터가 충분하지 않음)
k겹 교차 검증 : 훈련 세트를 몇 부분으로 분할하느냐에 따름
딥러닝은 데이터가 충분하고 너무 많아서 교차검증이 효율적이지 않음
 
 
 
교차검증
함수

단, 훈련 세트를 섞어 폴드를 나누지 않는다.
앞서 train_test_split()로 데이터가 섞여서 상관 없으나 만약 교차검증을 수행하려면 분할기(splitter)를 지정해야함

In [41]:
from sklearn.model_selection import cross_validate 

scores = cross_validate(dt, train_input, train_target) 
# 기본 5폴드 교차검증을 수행, 검증폴드의 점수 test_score
# 결정트리의 검증 점수가 마음에 안들면 트리의 매개변수를 바꾸어서 훈련할 수 있다
print(scores)
 
{'fit_time': array([0.00598407, 0.00598454, 0.00593948, 0.00498652, 0.00496697]), 'score_time': array([0.        , 0.00106192, 0.00099754, 0.00099707, 0.        ]), 'test_score': array([0.86923077, 0.84615385, 0.87680462, 0.84889317, 0.83541867])}
In [42]:
import numpy as np

print(np.mean(scores['test_score']))
 
0.855300214703487
In [43]:
# test_score값만 반환
from sklearn.model_selection import cross_val_score

scores2 = cross_val_score(dt, train_input, train_target) 
print(scores2)
 
[0.86923077 0.84615385 0.87680462 0.84889317 0.83541867]
 

분할기 (splitter)

회귀모델은 KFold, 분류모델은 StratifiedKFold

In [44]:
from sklearn.model_selection import StratifiedKFold

scores = cross_validate(dt, train_input, train_target, cv=StratifiedKFold()) 
# 타깃 클래스를 골고루 나눔
print(np.mean(scores['test_score']))
 
0.855300214703487
 

스플리트 객체

몇 폴드 교차 검증을 할지 정해줌

In [45]:
splitter = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
scores = cross_validate(dt, train_input, train_target, cv=splitter)
print(np.mean(scores['test_score']))
 
0.8574181117533719
 

그리드 서치 (교차 검증 반복작업)

하이퍼파라미터 탐색과 교차 검증을 한 번에 수행
 

하이퍼파라미터 튜닝

 

하이퍼 파라미터 : 모델이 학습할 수 없어서 사용자가 지정해야하는 매개변수
모델 파라미터 : 머신러닝 모델이 학습하는 파라미터

 

max_depth와 min_inpurity_decrease라는 매개변수가 있다.
두 매개변수가 상호간에 영향을 미치므로 순서대로 찾지못한다
동시에 여러 개의 매개변수를 두고 찾아야한다.

In [46]:
from sklearn.model_selection import GridSearchCV
# 여러 개의 매개변수를 바꾸면서 교차검증 하게 해줌 

params = {'min_impurity_decrease': [0.0001, 0.0002, 0.0003, 0.0004, 0.0005]}
# 매개변수를 딕셔너리로 지정
# 정보이득 : 부모와 자식 노드 간의 불순도 차이 (차이가 클수록 분할을 잘한 것)
# 정보이득의 최솟값 정함 (트리의 성장을 막아줌)
In [47]:
gs = GridSearchCV(DecisionTreeClassifier(random_state=42), params, n_jobs=-1)
# 5개의 모델에 5개의매개변수 총 25개의 모델을 훈련
# n_jobs 동시 훈련 -1은 가능한 모든 코어 사용하여 훈련
In [48]:
gs.fit(train_input, train_target)
# 객체를 만들고 fit사용
Out[48]:
GridSearchCV(estimator=DecisionTreeClassifier(random_state=42), n_jobs=-1,
             param_grid={'min_impurity_decrease': [0.0001, 0.0002, 0.0003,
                                                   0.0004, 0.0005]})
In [49]:
dt = gs.best_estimator_ # 최적 모델
print(dt.score(train_input, train_target))
 
0.9615162593804117
In [50]:
print(gs.best_params_) # 매개변수
 
{'min_impurity_decrease': 0.0001}
In [51]:
print(gs.cv_results_['mean_test_score']) # 5번의 교차검증 평균점수 
 
[0.86819297 0.86453617 0.86492226 0.86780891 0.86761605]
In [52]:
best_index = np.argmax(gs.cv_results_['mean_test_score']) # 가장 큰 값의 인덱스
print(gs.cv_results_['params'][best_index])
 
{'min_impurity_decrease': 0.0001}
In [53]:
params = {'min_impurity_decrease': np.arange(0.0001, 0.001, 0.0001),
          'max_depth': range(5, 20, 1),
          'min_samples_split': range(2, 100, 10)  
         } # 9*15*10*5fold = 6750 번 훈련
# min_samples_split은 노드를 나누기 위한 최소 샘플 수
In [54]:
gs = GridSearchCV(DecisionTreeClassifier(random_state=42), params, n_jobs=-1)
gs.fit(train_input, train_target)
Out[54]:
GridSearchCV(estimator=DecisionTreeClassifier(random_state=42), n_jobs=-1,
             param_grid={'max_depth': range(5, 20),
                         'min_impurity_decrease': array([0.0001, 0.0002, 0.0003, 0.0004, 0.0005, 0.0006, 0.0007, 0.0008,
       0.0009]),
                         'min_samples_split': range(2, 100, 10)})
In [55]:
print(gs.best_params_)
 
{'max_depth': 14, 'min_impurity_decrease': 0.0004, 'min_samples_split': 12}
In [56]:
print(np.max(gs.cv_results_['mean_test_score']))
 
0.8683865773302731
 

랜덤 서치

 

매개변수 값이 수치일 때 값의 범위나 간격 설정의 어려움
확률분포 이용

In [57]:
from scipy.stats import uniform, randint # 균등분포 샘플링(실수, 정수)
In [58]:
rgen = randint(0, 10)
rgen.rvs(10) # 랜덤 샘플링 (마지막 포함하지 않음)
Out[58]:
array([0, 2, 4, 9, 9, 4, 0, 4, 4, 0])
In [59]:
np.unique(rgen.rvs(1000), return_counts=True) 
Out[59]:
(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 array([105, 109,  98, 105,  99,  85,  86, 103, 114,  96], dtype=int64))
In [60]:
ugen = uniform(0, 1)
ugen.rvs(10)
Out[60]:
array([0.75764583, 0.21716498, 0.28616839, 0.69301966, 0.60381648,
       0.29050307, 0.37064949, 0.61438674, 0.99894703, 0.75818834])
 

그리드 서치 : 일정한 간격으로 테스트 해준다
촘촘하게 하면 테스트 개수가 많아진다
범위안에 랜덤하게 샘플링해서 모델을 테스트하나 모델을 만드는 횟수를 제한하여 넓은 분포를 잘 탐색할 수 있다.

In [61]:
params = {'min_impurity_decrease': uniform(0.0001, 0.001),
          'max_depth': randint(20, 50),
          'min_samples_split': randint(2, 25),
          'min_samples_leaf': randint(1, 25),
          } # 결정트리라 결정트리의 매개변수를 정해놓았다
In [62]:
from sklearn.model_selection import RandomizedSearchCV

gs = RandomizedSearchCV(DecisionTreeClassifier(random_state=42), params, 
                        n_iter=100, n_jobs=-1, random_state=42)
# random_state 출력결과 동일, n_iter 샘플링 횟수
gs.fit(train_input, train_target)
Out[62]:
RandomizedSearchCV(estimator=DecisionTreeClassifier(random_state=42),
                   n_iter=100, n_jobs=-1,
                   param_distributions={'max_depth': <scipy.stats._distn_infrastructure.rv_frozen object at 0x0000015ACED31820>,
                                        'min_impurity_decrease': <scipy.stats._distn_infrastructure.rv_frozen object at 0x0000015ACE5877F0>,
                                        'min_samples_leaf': <scipy.stats._distn_infrastructure.rv_frozen object at 0x0000015ACDD52790>,
                                        'min_samples_split': <scipy.stats._distn_infrastructure.rv_frozen object at 0x0000015ACED23880>},
                   random_state=42)
In [63]:

 

print(gs.best_params_)
 
{'max_depth': 39, 'min_impurity_decrease': 0.00034102546602601173, 'min_samples_leaf': 7, 'min_samples_split': 13}
In [64]:
print(np.max(gs.cv_results_['mean_test_score'])) # best파라미터로 찾은 검증테스트 결과
 
0.8695428296438884
In [65]:
dt = gs.best_estimator_ # best파라미터로 전체 훈련데이터로 훈련

print(dt.score(test_input, test_target))
 
0.86

 

In [66]:
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
In [69]:
plt.figure(figsize=(20,15))
plot_tree(dt, filled = True, feature_names=['alcohol','sugar','pH'])
plt.show()
 
In [80]:
print(export_text(dt, feature_names = ['alcohol', 'sugar', 'pH'], max_depth=5))
 
|--- sugar <= 4.05
|   |--- sugar <= 1.62
|   |   |--- sugar <= 1.38
|   |   |   |--- pH <= 3.84
|   |   |   |   |--- pH <= 2.90
|   |   |   |   |   |--- sugar <= 0.95
|   |   |   |   |   |   |--- class: 0.0
|   |   |   |   |   |--- sugar >  0.95
|   |   |   |   |   |   |--- class: 1.0
|   |   |   |   |--- pH >  2.90
|   |   |   |   |   |--- sugar <= 1.18
|   |   |   |   |   |   |--- class: 1.0
|   |   |   |   |   |--- sugar >  1.18
|   |   |   |   |   |   |--- truncated branch of depth 7
|   |   |   |--- pH >  3.84
|   |   |   |   |--- class: 0.0
|   |   |--- sugar >  1.38
|   |   |   |--- alcohol <= 10.05
|   |   |   |   |--- pH <= 3.25
|   |   |   |   |   |--- sugar <= 1.48
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- sugar >  1.48
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |--- pH >  3.25
|   |   |   |   |   |--- alcohol <= 9.58
|   |   |   |   |   |   |--- truncated branch of depth 7
|   |   |   |   |   |--- alcohol >  9.58
|   |   |   |   |   |   |--- truncated branch of depth 9
|   |   |   |--- alcohol >  10.05
|   |   |   |   |--- pH <= 3.45
|   |   |   |   |   |--- alcohol <= 11.25
|   |   |   |   |   |   |--- truncated branch of depth 11
|   |   |   |   |   |--- alcohol >  11.25
|   |   |   |   |   |   |--- truncated branch of depth 8
|   |   |   |   |--- pH >  3.45
|   |   |   |   |   |--- alcohol <= 10.85
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |   |--- alcohol >  10.85
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |--- sugar >  1.62
|   |   |--- alcohol <= 11.08
|   |   |   |--- pH <= 3.24
|   |   |   |   |--- sugar <= 3.25
|   |   |   |   |   |--- alcohol <= 9.75
|   |   |   |   |   |   |--- truncated branch of depth 15
|   |   |   |   |   |--- alcohol >  9.75
|   |   |   |   |   |   |--- truncated branch of depth 13
|   |   |   |   |--- sugar >  3.25
|   |   |   |   |   |--- pH <= 3.19
|   |   |   |   |   |   |--- truncated branch of depth 9
|   |   |   |   |   |--- pH >  3.19
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |--- pH >  3.24
|   |   |   |   |--- alcohol <= 9.95
|   |   |   |   |   |--- sugar <= 2.65
|   |   |   |   |   |   |--- truncated branch of depth 12
|   |   |   |   |   |--- sugar >  2.65
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |--- alcohol >  9.95
|   |   |   |   |   |--- pH <= 3.61
|   |   |   |   |   |   |--- truncated branch of depth 13
|   |   |   |   |   |--- pH >  3.61
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |--- alcohol >  11.08
|   |   |   |--- pH <= 3.38
|   |   |   |   |--- pH <= 3.11
|   |   |   |   |   |--- sugar <= 2.05
|   |   |   |   |   |   |--- class: 1.0
|   |   |   |   |   |--- sugar >  2.05
|   |   |   |   |   |   |--- truncated branch of depth 8
|   |   |   |   |--- pH >  3.11
|   |   |   |   |   |--- alcohol <= 12.55
|   |   |   |   |   |   |--- truncated branch of depth 18
|   |   |   |   |   |--- alcohol >  12.55
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |--- pH >  3.38
|   |   |   |   |--- sugar <= 3.05
|   |   |   |   |   |--- alcohol <= 11.95
|   |   |   |   |   |   |--- truncated branch of depth 8
|   |   |   |   |   |--- alcohol >  11.95
|   |   |   |   |   |   |--- truncated branch of depth 7
|   |   |   |   |--- sugar >  3.05
|   |   |   |   |   |--- pH <= 3.55
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |   |--- pH >  3.55
|   |   |   |   |   |   |--- class: 0.0
|--- sugar >  4.05
|   |--- sugar <= 6.15
|   |   |--- pH <= 3.27
|   |   |   |--- sugar <= 4.65
|   |   |   |   |--- pH <= 3.07
|   |   |   |   |   |--- pH <= 3.00
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- pH >  3.00
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |--- pH >  3.07
|   |   |   |   |   |--- pH <= 3.19
|   |   |   |   |   |   |--- class: 1.0
|   |   |   |   |   |--- pH >  3.19
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |--- sugar >  4.65
|   |   |   |   |--- alcohol <= 9.65
|   |   |   |   |   |--- alcohol <= 9.53
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- alcohol >  9.53
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |--- alcohol >  9.65
|   |   |   |   |   |--- pH <= 3.19
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- pH >  3.19
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |--- pH >  3.27
|   |   |   |--- sugar <= 5.45
|   |   |   |   |--- sugar <= 4.67
|   |   |   |   |   |--- alcohol <= 10.85
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |   |--- alcohol >  10.85
|   |   |   |   |   |   |--- truncated branch of depth 7
|   |   |   |   |--- sugar >  4.67
|   |   |   |   |   |--- pH <= 3.32
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |   |--- pH >  3.32
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |--- sugar >  5.45
|   |   |   |   |--- pH <= 3.44
|   |   |   |   |   |--- pH <= 3.32
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |   |--- pH >  3.32
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |--- pH >  3.44
|   |   |   |   |   |--- class: 1.0
|   |--- sugar >  6.15
|   |   |--- pH <= 3.26
|   |   |   |--- sugar <= 6.58
|   |   |   |   |--- sugar <= 6.53
|   |   |   |   |   |--- pH <= 3.21
|   |   |   |   |   |   |--- class: 1.0
|   |   |   |   |   |--- pH >  3.21
|   |   |   |   |   |   |--- truncated branch of depth 7
|   |   |   |   |--- sugar >  6.53
|   |   |   |   |   |--- class: 0.0
|   |   |   |--- sugar >  6.58
|   |   |   |   |--- pH <= 3.16
|   |   |   |   |   |--- alcohol <= 12.25
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- alcohol >  12.25
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- pH >  3.16
|   |   |   |   |   |--- pH <= 3.17
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- pH >  3.17
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |--- pH >  3.26
|   |   |   |--- sugar <= 8.95
|   |   |   |   |--- pH <= 3.27
|   |   |   |   |   |--- alcohol <= 11.40
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- alcohol >  11.40
|   |   |   |   |   |   |--- class: 0.0
|   |   |   |   |--- pH >  3.27
|   |   |   |   |   |--- sugar <= 8.70
|   |   |   |   |   |   |--- truncated branch of depth 7
|   |   |   |   |   |--- sugar >  8.70
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |--- sugar >  8.95
|   |   |   |   |--- alcohol <= 10.45
|   |   |   |   |   |--- class: 1.0
|   |   |   |   |--- alcohol >  10.45
|   |   |   |   |   |--- alcohol <= 10.55
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |   |--- alcohol >  10.55
|   |   |   |   |   |   |--- class: 1.0

In [70]:
print(dt.feature_importances_)
 
[0.16773734 0.69008328 0.14217938]
 

결정트리 클래스 분할기 random

splitter 매개변수의 기본값은 best : 각 노드에서 최선의 분할
random : 무작위로 분할한 다음 가장 좋은 노드

In [81]:
gs = RandomizedSearchCV(DecisionTreeClassifier(splitter='random', random_state=42), params, 
                        n_iter=100, n_jobs=-1, random_state=42)
gs.fit(train_input, train_target)
Out[81]:
RandomizedSearchCV(estimator=DecisionTreeClassifier(random_state=42,
                                                    splitter='random'),
                   n_iter=100, n_jobs=-1,
                   param_distributions={'max_depth': <scipy.stats._distn_infrastructure.rv_frozen object at 0x0000015ACED31820>,
                                        'min_impurity_decrease': <scipy.stats._distn_infrastructure.rv_frozen object at 0x0000015ACE5877F0>,
                                        'min_samples_leaf': <scipy.stats._distn_infrastructure.rv_frozen object at 0x0000015ACDD52790>,
                                        'min_samples_split': <scipy.stats._distn_infrastructure.rv_frozen object at 0x0000015ACED23880>},
                   random_state=42)
In [82]:
print(gs.best_params_)
print(np.max(gs.cv_results_['mean_test_score']))

dt = gs.best_estimator_
print(dt.score(test_input, test_target))
 
{'max_depth': 43, 'min_impurity_decrease': 0.00011407982271508446, 'min_samples_leaf': 19, 'min_samples_split': 18}
0.8458726956392981
0.786923076923077

출처 : 박해선, 『혼자공부하는머신러닝+딥러닝』, 한빛미디어(2021), p242-259

 

Comments