본문 바로가기
Machine Learning

통계 분석 #8 : 교차 유효성 검사 Cross-Validation with R

by 무적물리 2020. 3. 26.

교차 유효성 검사는 Cross Validation 이라고 합니다. 주어진 데이터의 일부를 학습시켜 모델을 생성하고 나머지 데이터는 모델을 검증하는데 사용합니다. 회귀 모델이나 머신러닝 모델을 만드는 목적은 예측에 있습니다. 모델을 생성하고 예측이 얼마나 잘 맞는지를 확인해야하는데, 일부의 데이터로 모델을 학습시키고 일부의 데이터는 학습될 모델을 검증하는데 사용한하면 예측의 정확도가 좋은 모델을 만들 수 있습니다. 한마디로, 데이터를 나눠서 모델을 만들고 검증하는 방법입니다.




교차 유효성 검사 개념

교차 유효성 검사는 보통 Training Data Set 80%, Test Data Set 20%로 나누어서 모델을 학습시키고 검증합니다. Training Data Set은 모델을 만드는 데이터이며, Test Data Set은 모델을 검증하는데 사용합니다. 이러한 유효성 검증은 과적합이라는 Overfitting을 방지하기 위해서 사용합니다. 학습한 데이터만을 예측하는 것이 아닌 학습에 사용되지 않은 데이터 역시 예측 가능한 모델을 만드는데 사용합니다. 


교차 유효성 검사 종류

교차 유효성 검사는 크게 세가지로 나눌 수 있습니다. 위에서 설명한 가장 기본적인 Cross-Validation과 LOOCV, k-Fold Cross-Validation이 종류에 해당합니다. 위 세가지 종류 중 데이터와 상황에 맞은 교차 유효성 검증 종류를 선택해서 사용하시면 됩니다. 데이터가 적으면 한가지 데이터만 남기고 학습시키는 것을 n번 반복하는 LOOCV가 유리합니다.


▼ Cross-Validation은 n개의 데이터를 n등분해서 데이터를 훈련용, 테스트용 데이터로 나누어서 모델을 학습하고 평가하는 기본적인 방법입니다. 머신러닝을 배울 때, 가장 먼저 트레이닝 데이터와 테스트 데이터로 데이터를 나누는 것을 배우는데 그 방법의 이름이 Cross-Validation 입니다.



Cross-Validation

장점 : 구현이 간단, 결과 추출 시간이 짧음

단점 : 추출 방법에 영향이 큼, 모델을 정확히 평가하기 어려움


▼ LOOCV라고 불리는 Leave-One-Out Cross-Validation은 n개의 데이터 중 하나의 데이터만 테스트 용으로 사용하고 나머지 데이터로 학습시키는 방법입니다. n번 반복한 후 결과치들의 평균을 도출하여 사용합니다.



LOOCV

장점 : 결과 값이 비교적 정확, 비편향적

단점 : 계산량이 많아 시간이 오래 걸림, 분산이 큼


▼ k-Fold Cross-Validation은 LOOCV의 과다한 연산량을 줄여주는 방법입니다. 데이터를 무작위로 섞은 뒤, k등분을 하여 하나 등분을 테스트 데이터로 사용하는 방법입니다. 이 역시 LOOCV과 같이 결과치들의 평균을 도출하여 사용합니다.



k-Fold Cross-Validation

장점 : 빠른 시간내에 결과 값 도출, LOOCV 결과 차이가 적음

단점 : 데이터셋을 섞음으로 변동성이 수반됨



R을 이용한 교차 유효성 검사

R을 사용해서 교차 유효성 검사를 해보겠습니다. 가장 많이 사용되는 k-Fold Cross-Validation 평가 방법을 사용해보겠습니다. 이를 위해 의사 결정 나무와 교차 검증 패키지를 먼저 설치하겠습니다. 데이터는 R에서 제공하는 iris 데이터를 사용하겠습니다. 


# 의사 결정 나무 패키지 설치

install.package("party")

library(party)


# 교차 검증 패키지 설치

install.package("cvTools")

library(cvTools)


# 데이터 선언

head(iris)

str(iris)


# 교차 검증 k값 선언 k=3

cross=cvfolds(nrow(iris),k=3)

str(cross)

cross


# 균등분할 데이터 셋

cross$which


# 랜덤하게 선정된 행번호

cross$subsets


# 3-Fold 교차검정

k=1:3


# 분류 정확도

acc=numeric()


# index

cnt=1


# k-Fold 교차검증

for(i in k){

data_index=cross$subsets[cross$which==i,1]


# 검정데이터 생성

test=iris[data_index,]


# 트레인 생성

formula=Species~.


# 훈련데이터 생성

train=iris[-data_index,]


# 의사결정나무

model=ctree(formula, data=train)

pred=predict(model, test)


# 정확도 측정

t=table(pred, test$Species)

print(t)

acc[cnt]=(t[1,1]+t[2,2]+t[3,3])/sum(t)

cnt=cnt+1}

acc

mean(acc)


댓글