728x90
교차 검증
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
# K폴드 교차 검증
# K개의 데이터 폴드 세트를 만들어서 K번만큼 각 폴드 세트에 학습과 검증을 반복적으로 수행하는 방법
# 쉽게 말해서 5개의 데이터(1,2,3,4,5)가 한세트라고 할때 K가 5이면 1 2 3 4 5로 데이터 세트를 5개로 나눈 다음에 1 2 3 4 학습 5 검증 -> 1 2 3 5 학습 4 검증 -> ... 이렇게 5번에 학습을 해서 나온 값들을 평균을 내는 것이 K폴드 교차 검증이다.
# Stratified K 폴드
# 불균형한 분포도를 가진 레이블 데이터 집합을 위한 K 폴드 방식이다.
# 작은 비율로 레이블 값이 있다면 K 폴드로 랜덤하게 학습 및 테스트 세트의 인덱스를 고르더라도 레이블 값의 비율이 제대로 반영 하지 못하는 경우가 발생, 이를 위해서 Stratified K 폴드는 원본 데이터의 레이블 분포를 먼저 고려 한 뒤 이 분포와 동일하게 학습과 검증 데이터 세트를 분리한다.
# 교차 검증 API - cross_val_score()
from sklearn.model_selection import cross_val_score,cross_validate
iris_data = load_iris()
model = DecisionTreeClassifier(random_state=156)
data = iris_data.data
label = iris_data.target
scores = cross_val_score(model,data,label,scoring='accuracy',cv = 3) # cv는 교차 검증 폴드수를 의미 한다.
print('교차 검증 정확도 :',scores)
print('평균교차 검증 정확도 : ',np.round(np.mean(scores),4))
|
cs |
위와 같이 교차검증으로 정확도가 구해지고 평균을 내서도 구해진다.
파라미터 튜닝
1
2
3
4
5
6
7
8
9
|
# gridsearchcv - 교차검증과 최적 하이퍼 파라미터 튜닝을 한번에
from sklearn.model_selection import GridSearchCV
x_train,x_test,y_train,y_test = train_test_split(iris_data.data,iris_data.target,test_size = 0.2,random_state = 121)
parameters = {'max_depth':[1,2,3],'min_samples_split':[2,3]}
grid_dtree = GridSearchCV(model,param_grid = parameters,cv = 3,refit = True) # refit가 True 일시 가장 최적의 하이퍼 파라미터를 찾은 뒤 입력된 estimator(model)객체를 해당 하이퍼 파라미터로 재학습시킵니다. 디폴드는 True
grid_dtree.fit(x_train,y_train)
df = pd.DataFrame(grid_dtree.cv_results_)
df[['params','mean_test_score','rank_test_score','split0_test_score','split1_test_score','split2_test_score']]
|
cs |
위와 같이 설정한 파라미터들로 수행했을때 결과를 볼 수 있다.
1
2
3
|
estimator = grid_dtree.best_estimator_ # refit을 True로 하면 최적의 model이 best_estimator로 저장된다.
pred = estimator.predict(x_test)
print('테스트 데이터 예측 정확도:',accuracy_score(y_test,pred))
|
cs |
테스트 데이터 예측 정확도: 0.9666666666666667
위와 같이 가장 최적의 모델을 구해서 정확도를 구해보았다.
728x90
'파이썬 머신러닝' 카테고리의 다른 글
[ML/AI][Bitcoin Data Analysis And RNN Model Prediction] (0) | 2022.09.20 |
---|---|
[ML/AI][SimpleRNN with Keras] (2) | 2022.09.20 |
[머신러닝][kaggle 실습- 보험 비용 예측하기] (2) | 2021.02.23 |
[머신러닝][Deep learning을 이용한 XOR문제해결] (0) | 2021.02.14 |
[머신러닝][Classification 알고리즘 실습-Mushroom Classification] (0) | 2021.02.14 |