Week 03: Predicting with Trees


이론적인 내용은 이전 포스트를 참조 한다.


Tree 방식의 Classification의 장단점은 아래와 같다.

Pros:

쉬운 이해

비선형 데이터에 대한 좋은 성능

Cons:

pruning이나 cross-validation을 하지 않으면 over-fitting의 문제가 발생함.

linear regression model같은 것들보다 불확실한 데이터 즉 트레이닝에서 발견되지 않았던 많이 다른 데이터를

추정하기아 어렵다.

결과가가 항상 특정 tree를 구성하는 node 값에의해서 결정되므로 트리에 구조에따라서 결과가 자주 변한다.



예제는 아래와 같다.

2010년 미국 대통령 선거당시 힐러리 클린턴과 버락 오바에 대한 것이다.

각각의 질의에 따라 어떻게 투표 결과가 이뤄 질지에대한 예측을 나타내는 트리이다.



예제: Iris Data


해당 코드는 caret package를 이용하며 해당 package의 party, rpart 방식을 이용해서 tree를 구성한 것이다.

다른 package로는 tree package가 존재 한다.


데이터를 불러온다.

data(iris); library(ggplot2)
names(iris)
[1] "Sepal.Length" "Sepal.Width"  "Petal.Length" "Petal.Width"  "Species"

각 테이블의 양을 파악한다

table(iris$Species)
    setosa versicolor  virginica 
        50         50         50 

트레이닝 데이터와 테스트 데이터를 분할 한다.

inTrain <- createDataPartition(y=iris$Species,
                               p=0.7, list=FALSE)
training <- iris[inTrain,]
testing <- iris[-inTrain,]
dim(training); dim(testing)
[1] 251   5
[1] 33  5

그래프를 그려서 세개의 분포를 확인해 보자.

qplot(Petal.Width,Sepal.Width,colour=Species,data=training)

아래와 같이 세개의 종들이 나름 극명하게 나뉘는 것을 알 수 있다. 


이제 caret package를 이용해서 모델을 학습해서 수행해 보자.

~,의 의미는 모든 variables를 species를 prediction하기위해서 사용하겠다는 것이다.

library(caret)
modFit <- train(Species ~ .,method="rpart",data=training)
print(modFit$finalModel)
n= 117 
node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 117 76 virginica (0.34188034 0.30769231 0.35042735)  
  2) Petal.Length< 2.45 40  0 setosa (1.00000000 0.00000000 0.00000000) *
  3) Petal.Length>=2.45 77 36 virginica (0.00000000 0.46753247 0.53246753)  
    6) Petal.Width< 1.75 40  5 versicolor (0.00000000 0.87500000 0.12500000) *
    7) Petal.Width>=1.75 37  1 virginica (0.00000000 0.02702703 0.97297297) *


모델을 그래프로 표현한다. 좀 더 명확히 보기위해서 rattle package를 이용해서 그린다.

아래의 그림은 일반 plot()이용해서 그린 그래프 이다.

library(rattle)
fancyRpartPlot(modFit$finalModel)
#draw plot using nomal plot
plot(modFit$finalModel, uniform=TRUE, 
     main="Classification Tree")
text(modFit$finalModel, use.n=TRUE, all=TRUE, cex=.8)



fancyRpartPlot으로 그리면 아래와 같다.



normal plot으로 그리면 아래와 같다.


생성된 tree 모델을 이용해서 예측을 수행 한다.

predict(modFit,newdata=testing)
 [1] setosa     setosa     setosa     setosa     setosa    
 [6] setosa     setosa     setosa     setosa     setosa    
[11] versicolor versicolor versicolor versicolor versicolor
[16] versicolor versicolor versicolor versicolor versicolor
[21] versicolor versicolor versicolor versicolor virginica 
[26] virginica  virginica  virginica  virginica  virginica 
[31] virginica  virginica  virginica 
Levels: setosa versicolor virginica



+ Recent posts