11.9 - R Scripts
1. Acquire Data
The diabetes data set is the Pima Indians Diabetes Database, originally from the UCI Machine Learning Repository and also available on Kaggle:
- 768 samples in the dataset
- 8 quantitative predictor variables
- 2 classes: with or without signs of diabetes
Load data into R as follows:
# set the working directory
setwd("C:/STAT 897D data mining")
# comma-delimited data with no header row
RawData = read.table("diabetes.data", sep = ",", header = FALSE)
In RawData, the response variable is the last column, and the remaining columns are the predictor variables.
# the response is the last column; ncol() gives its index
responseY = as.matrix(RawData[, ncol(RawData)])
# the predictors are all the other columns
predictorX = as.matrix(RawData[, 1:(ncol(RawData) - 1)])
data.train = as.data.frame(cbind(responseY, predictorX))
names(data.train) = c("Y", "X1", "X2", "X3", "X4", "X5", "X6", "X7", "X8")
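Note that dim() returns two numbers (rows, then columns), so the index of the last column is ncol(RawData) or dim(RawData)[2], not dim(RawData) itself. The minimal sketch below checks this on a simulated stand-in for RawData; the random values are hypothetical and are not the diabetes data, which lets the code run without the file.

```r
# Hypothetical stand-in for RawData: 768 rows, 8 numeric predictors, 0/1 class in column 9
set.seed(1)
RawData <- as.data.frame(matrix(round(rnorm(768 * 8), 2), ncol = 8))
RawData$V9 <- rbinom(768, size = 1, prob = 0.35)

print(dim(RawData))    # 768 9: two numbers, rows then columns
print(ncol(RawData))   # 9: a single number, the index of the response column

# Split into response and predictors by column position
responseY  <- as.matrix(RawData[, ncol(RawData)])
predictorX <- as.matrix(RawData[, 1:(ncol(RawData) - 1)])
data.train <- as.data.frame(cbind(responseY, predictorX))
names(data.train) <- c("Y", paste0("X", 1:8))

str(data.train)
table(data.train$Y)    # number of rows in each class
```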
2. Classification and Regression Trees
Generating a tree involves two steps: growing the tree and then pruning it.
2.1 Grow a Tree
In R, the tree library can be used to construct classification and regression trees (see R Lab 8). Alternatively, trees can be generated with the rpart package, whose rpart(formula) function grows a tree from the data. The method argument sets the type of response: rpart(formula, method="class") treats the response as categorical, and rpart(formula, method="anova") treats it as continuous. If method is omitted, rpart guesses from the response: a factor response gives "class" and a numeric response gives "anova". Because Y is coded numerically as 0/1 here, method="class" must be given explicitly.
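The effect of the method argument can be checked directly, since the fitted rpart object records the method it used. This is a minimal sketch on simulated data (the data frame sim and its columns are hypothetical), comparing a numeric 0/1 response with and without method = "class" and a response converted to a factor.

```r
library(rpart)

# Hypothetical simulated data with a numeric 0/1 response
set.seed(19)
n <- 300
sim <- data.frame(X1 = rnorm(n), X2 = rnorm(n))
sim$Y <- as.numeric(sim$X1 + rnorm(n, sd = 0.5) > 0)

fit.default <- rpart(Y ~ X1 + X2, data = sim)                    # numeric Y: method guessed as "anova"
fit.class   <- rpart(Y ~ X1 + X2, data = sim, method = "class")  # classification requested explicitly
fit.factor  <- rpart(factor(Y) ~ X1 + X2, data = sim)            # factor Y: method guessed as "class"

fit.default$method   # "anova"
fit.class$method     # "class"
fit.factor$method    # "class"
```

Converting the response with factor() is an alternative to passing method = "class"; either way a classification tree is grown.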
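library(rpart)
# rpart cross-validates internally with random folds; fixing the seed makes cptable reproducible
set.seed(19)
model.tree <- rpart(Y ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8,
                    data = data.train, method = "class")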
To plot the tree, run the following code. plot(model.tree, uniform=TRUE) draws the tree with the nodes equally spaced vertically, and text(model.tree, use.n=TRUE) then labels each split with its decision rule and, because use.n=TRUE, each terminal node with the number of observations in each class.
plot(model.tree, uniform = TRUE)
text(model.tree, use.n = TRUE)
In Figure 1, the plot shows the predictor and threshold used at each split of the tree, and the number of observations in each class at each terminal node. Specifically, the numbers of points in Class 0 and Class 1 are displayed as n0/n1.
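When no screen device is available, the same plot can be written to a file, and print() gives a text version of the tree with the class counts and probabilities at every node. The sketch below assumes simulated data (sim and fit are hypothetical names); the margin argument of plot() and cex passed through text() only adjust spacing and label size.

```r
library(rpart)

# Hypothetical simulated data with a 0/1 response (not the diabetes data)
set.seed(19)
n <- 400
sim <- data.frame(X1 = rnorm(n), X2 = rnorm(n))
sim$Y <- as.numeric(sim$X1 + 0.5 * sim$X2 + rnorm(n, sd = 0.5) > 0)
fit <- rpart(Y ~ X1 + X2, data = sim, method = "class")

# Draw the tree into a temporary PDF; margin adds white space so labels are not clipped
pdf.file <- tempfile(fileext = ".pdf")
pdf(pdf.file)
plot(fit, uniform = TRUE, margin = 0.1)
text(fit, use.n = TRUE, cex = 0.8)   # use.n = TRUE adds the per-class counts at the leaves
invisible(dev.off())

# Text version: each line shows node), split, n, loss, yval, (yprob)
print(fit)
```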
At each node, we move down the left branch if the decision statement is true and down the right branch if it is false. The functions in the rpart library draw the tree so that left branches are constrained to have a higher proportion of 0 values for the response variable than right branches. As a result, some decision statements contain "less than" (<) symbols and others contain "greater than or equal to" (>=) symbols, whichever is needed to satisfy this constraint. By contrast, the functions in the tree library write every decision statement with a "less than" (<) symbol, so either branch may have the higher proportion of 0 values for the response variable.
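To see the left/right rule applied to one observation, rpart's path.rpart() prints the decision statements that lead from the root to a given node. This is a minimal sketch on simulated data (sim, fit and leaf.of.first are hypothetical names); it relies on rpart's node numbering, in which the root is node 1 and the left and right children of node k are 2k and 2k + 1.

```r
library(rpart)

# Hypothetical simulated data with a 0/1 response (not the diabetes data)
set.seed(19)
n <- 400
sim <- data.frame(X1 = rnorm(n), X2 = rnorm(n))
sim$Y <- as.numeric(sim$X1 + 0.5 * sim$X2 + rnorm(n, sd = 0.5) > 0)
fit <- rpart(Y ~ X1 + X2, data = sim, method = "class")

# fit$where gives, for each training row, its row in fit$frame;
# the row names of fit$frame are the rpart node numbers
leaf.of.first <- as.numeric(rownames(fit$frame)[fit$where[1]])

# Print (and return) the decision statements followed from the root to that leaf
path <- path.rpart(fit, nodes = leaf.of.first)

# The class predicted for the same observation
predict(fit, newdata = sim[1, ], type = "class")
```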
2.2 Prune a Tree
To obtain a right-sized tree that avoids overfitting, extract the cptable element of the object returned by rpart; it can be displayed by printing model.tree$cptable.
The cptable provides a brief summary of the overall fit of the model. The table is printed from the smallest tree (no splits) to the largest tree. The "CP" column lists the values of the complexity parameter, the "nsplit" column gives the number of splits, and the "xerror" column contains the cross-validated classification error rates, scaled relative to the error of the root node; the standard deviations of the cross-validated error rates are in the "xstd" column. Normally, we select the tree size that minimizes the cross-validated error, that is, the row with the smallest value in the "xerror" column of model.tree$cptable.
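Because "rel error" and "xerror" are relative to the root node error, multiplying them by that error puts them back on the misclassification-rate scale. A minimal sketch on simulated data (sim, fit, root.err and cv.misclass are hypothetical names), assuming the default unit losses, for which the root node error is the misclassification rate of the tree with no splits, as reported by printcp().

```r
library(rpart)

# Hypothetical simulated data with a 0/1 response (not the diabetes data)
set.seed(19)
n <- 400
sim <- data.frame(X1 = rnorm(n), X2 = rnorm(n))
sim$Y <- as.numeric(sim$X1 + 0.5 * sim$X2 + rnorm(n, sd = 0.5) > 0)
fit <- rpart(Y ~ X1 + X2, data = sim, method = "class")

cp.tab <- fit$cptable
print(colnames(cp.tab))   # "CP" "nsplit" "rel error" "xerror" "xstd"
print(cp.tab)

# Root node error: misclassified count at the root divided by the number of rows
root.err <- fit$frame$dev[1] / fit$frame$n[1]
cv.misclass <- cp.tab[, "xerror"] * root.err   # cross-validated error rate on the original scale
print(cbind(nsplit = cp.tab[, "nsplit"], cv.misclass))

printcp(fit)   # the same table, preceded by the root node error
```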
Selection of the optimal subtree can also be done automatically using the following code:
# CP value in the row with the smallest cross-validated error
opt <- model.tree$cptable[which.min(model.tree$cptable[, "xerror"]), "CP"]
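The selection line only indexes a matrix, so its mechanics can be seen without fitting a tree. The sketch below uses a small made-up cptable (every number is hypothetical) with the same column names rpart uses; which.min() returns the row of the smallest "xerror", and that row's "CP" value becomes opt.

```r
# Hypothetical cptable with made-up numbers, in rpart's column layout
cp.tab <- rbind(
  c(0.250, 0, 1.000, 1.000, 0.050),
  c(0.060, 1, 0.750, 0.790, 0.046),
  c(0.020, 3, 0.630, 0.700, 0.044),
  c(0.010, 6, 0.570, 0.720, 0.045)
)
colnames(cp.tab) <- c("CP", "nsplit", "rel error", "xerror", "xstd")

best.row <- which.min(cp.tab[, "xerror"])   # row with the smallest cross-validated error
opt <- cp.tab[best.row, "CP"]               # complexity parameter of that row

best.row   # 3
opt        # 0.02
```

If several rows tie for the smallest "xerror", which.min() returns the first of them, which corresponds to the smallest of the tied trees.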
opt stores the optimal complexity parameter. The tree is then pruned at this value with the prune function, which takes the full tree as its first argument and the chosen complexity parameter as its second:
model.ptree <- prune(model.tree, cp = opt)
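The grow, select and prune steps can be run end to end on simulated data to compare tree sizes. In this minimal sketch the data are hypothetical, n.leaves is a hypothetical helper, and the tree is deliberately grown large with rpart.control(cp = 0, minsplit = 10) so that pruning has something to remove; these control settings are illustrative choices, not the ones used above.

```r
library(rpart)

# Hypothetical simulated data with a 0/1 response (not the diabetes data)
set.seed(19)
n <- 400
sim <- data.frame(X1 = rnorm(n), X2 = rnorm(n))
sim$Y <- as.numeric(sim$X1 + 0.5 * sim$X2 + rnorm(n, sd = 0.5) > 0)

# Grow a large tree: cp = 0 switches off the complexity-based stopping rule
model.tree <- rpart(Y ~ X1 + X2, data = sim, method = "class",
                    control = rpart.control(cp = 0, minsplit = 10))

# Choose the CP with the smallest cross-validated error and prune at it
opt <- model.tree$cptable[which.min(model.tree$cptable[, "xerror"]), "CP"]
model.ptree <- prune(model.tree, cp = opt)

# Hypothetical helper: number of terminal nodes in a fitted rpart tree
n.leaves <- function(fit) sum(fit$frame$var == "<leaf>")
c(full = n.leaves(model.tree), pruned = n.leaves(model.ptree))
```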
Figure 2 shows the pruned tree, drawn with the same plotting functions used to create Figure 1.
Further information on the pruned tree can be accessed using the summary() function.
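The summary() output is long, so it can help to capture it and look at parts of it. A minimal sketch on simulated data (sim, fit and out are hypothetical names; X3 is pure noise by construction): summary() prints the cptable, the variable importance rescaled to percentages, and a node-by-node description including surrogate splits, while the unscaled importance scores are stored in the fitted object.

```r
library(rpart)

# Hypothetical simulated data; X3 is unrelated to the response
set.seed(19)
n <- 400
sim <- data.frame(X1 = rnorm(n), X2 = rnorm(n), X3 = rnorm(n))
sim$Y <- as.numeric(sim$X1 + 0.5 * sim$X2 + rnorm(n, sd = 0.5) > 0)
fit <- rpart(Y ~ X1 + X2 + X3, data = sim, method = "class")

# Capture the printed summary and show only its first lines
out <- capture.output(summary(fit))
head(out, 15)

# Unscaled variable importance scores stored on the fitted tree
print(fit$variable.importance)
```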