The objective of this analysis is to take a deep dive into the OJ dataset, which is part of the ISLR package. This dataset contains sales information for the Citrus Hill and Minute Maid brands of orange juice. After some basic exploratory data analysis, we will try to predict which brand a customer purchases. We will do so by fitting decision trees to the data, with Purchase as the response and the other variables as predictors.
The data contains 1070 purchases where the customer either purchased Citrus Hill or Minute Maid Orange Juice. A number of characteristics of the customer and product are recorded.
A data frame with 1070 observations on the following 18 variables.
Purchase: A factor with levels CH and MM indicating whether the customer purchased Citrus Hill or Minute Maid Orange Juice
WeekofPurchase: Week of purchase
StoreID: Store ID
PriceCH: Price charged for CH
PriceMM: Price charged for MM
DiscCH: Discount offered for CH
DiscMM: Discount offered for MM
SpecialCH: Indicator of special on CH
SpecialMM: Indicator of special on MM
LoyalCH: Customer brand loyalty for CH
SalePriceMM: Sale price for MM
SalePriceCH: Sale price for CH
PriceDiff: Sale price of MM less sale price of CH
Store7: A factor with levels No and Yes indicating whether the sale is at Store 7
PctDiscMM: Percentage discount for MM
PctDiscCH: Percentage discount for CH
ListPriceDiff: List price of MM less list price of CH
STORE: Which of 5 possible stores the sale occurred at
Source: Stine, Robert A., Foster, Dean P., and Waterman, Richard P. (1998). Business Analysis Using Regression. Springer.
Let us first import the data, look at the structure and see if there are any missing values:
library(ISLR)
attach(OJ)
str(OJ)
## 'data.frame': 1070 obs. of 18 variables:
## $ Purchase : Factor w/ 2 levels "CH","MM": 1 1 1 2 1 1 1 1 1 1 ...
## $ WeekofPurchase: num 237 239 245 227 228 230 232 234 235 238 ...
## $ StoreID : num 1 1 1 1 7 7 7 7 7 7 ...
## $ PriceCH : num 1.75 1.75 1.86 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceMM : num 1.99 1.99 2.09 1.69 1.69 1.99 1.99 1.99 1.99 1.99 ...
## $ DiscCH : num 0 0 0.17 0 0 0 0 0 0 0 ...
## $ DiscMM : num 0 0.3 0 0 0 0 0.4 0.4 0.4 0.4 ...
## $ SpecialCH : num 0 0 0 0 0 0 1 1 0 0 ...
## $ SpecialMM : num 0 1 0 0 0 1 1 0 0 0 ...
## $ LoyalCH : num 0.5 0.6 0.68 0.4 0.957 ...
## $ SalePriceMM : num 1.99 1.69 2.09 1.69 1.69 1.99 1.59 1.59 1.59 1.59 ...
## $ SalePriceCH : num 1.75 1.75 1.69 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
## $ PriceDiff : num 0.24 -0.06 0.4 0 0 0.3 -0.1 -0.16 -0.16 -0.16 ...
## $ Store7 : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 2 2 2 2 2 ...
## $ PctDiscMM : num 0 0.151 0 0 0 ...
## $ PctDiscCH : num 0 0 0.0914 0 0 ...
## $ ListPriceDiff : num 0.24 0.24 0.23 0 0 0.3 0.3 0.24 0.24 0.24 ...
## $ STORE : num 1 1 1 1 0 0 0 0 0 0 ...
summary(OJ)
## Purchase WeekofPurchase StoreID PriceCH PriceMM
## CH:653 Min. :227.0 Min. :1.00 Min. :1.690 Min. :1.690
## MM:417 1st Qu.:240.0 1st Qu.:2.00 1st Qu.:1.790 1st Qu.:1.990
## Median :257.0 Median :3.00 Median :1.860 Median :2.090
## Mean :254.4 Mean :3.96 Mean :1.867 Mean :2.085
## 3rd Qu.:268.0 3rd Qu.:7.00 3rd Qu.:1.990 3rd Qu.:2.180
## Max. :278.0 Max. :7.00 Max. :2.090 Max. :2.290
## DiscCH DiscMM SpecialCH SpecialMM
## Min. :0.00000 Min. :0.0000 Min. :0.0000 Min. :0.0000
## 1st Qu.:0.00000 1st Qu.:0.0000 1st Qu.:0.0000 1st Qu.:0.0000
## Median :0.00000 Median :0.0000 Median :0.0000 Median :0.0000
## Mean :0.05186 Mean :0.1234 Mean :0.1477 Mean :0.1617
## 3rd Qu.:0.00000 3rd Qu.:0.2300 3rd Qu.:0.0000 3rd Qu.:0.0000
## Max. :0.50000 Max. :0.8000 Max. :1.0000 Max. :1.0000
## LoyalCH SalePriceMM SalePriceCH PriceDiff
## Min. :0.000011 Min. :1.190 Min. :1.390 Min. :-0.6700
## 1st Qu.:0.325257 1st Qu.:1.690 1st Qu.:1.750 1st Qu.: 0.0000
## Median :0.600000 Median :2.090 Median :1.860 Median : 0.2300
## Mean :0.565782 Mean :1.962 Mean :1.816 Mean : 0.1465
## 3rd Qu.:0.850873 3rd Qu.:2.130 3rd Qu.:1.890 3rd Qu.: 0.3200
## Max. :0.999947 Max. :2.290 Max. :2.090 Max. : 0.6400
## Store7 PctDiscMM PctDiscCH ListPriceDiff
## No :714 Min. :0.0000 Min. :0.00000 Min. :0.000
## Yes:356 1st Qu.:0.0000 1st Qu.:0.00000 1st Qu.:0.140
## Median :0.0000 Median :0.00000 Median :0.240
## Mean :0.0593 Mean :0.02731 Mean :0.218
## 3rd Qu.:0.1127 3rd Qu.:0.00000 3rd Qu.:0.300
## Max. :0.4020 Max. :0.25269 Max. :0.440
## STORE
## Min. :0.000
## 1st Qu.:0.000
## Median :2.000
## Mean :1.631
## 3rd Qu.:3.000
## Max. :4.000
There are no missing values, and nothing looks out of the ordinary.
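For an explicit check, we can also count the missing values directly:
sum(is.na(OJ)) # total number of NA values; 0 for a complete data frame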
We could examine the relationships between the variables with a full correlation matrix, but since we have so many variables, it is easier to pull a sorted list of variable pairs with their correlation coefficients. We first need to convert the factor variables Purchase and Store7 to numeric so that we can use the cor() function.
OJ.numeric = OJ
OJ.numeric$Purchase = as.numeric(OJ.numeric$Purchase)
OJ.numeric$Store7 = as.numeric(OJ.numeric$Store7)
cor.table=cor(OJ.numeric)
cor.table[lower.tri(cor.table,diag=TRUE)]=NA #Prepare to drop duplicates and meaningless information
cor.table=as.data.frame(as.table(cor.table)) #Turn into a 3-column table
cor.table=na.omit(cor.table) #Get rid of the junk we flagged above
cor.table=cor.table[order(-abs(cor.table$Freq)),] #Sort by highest correlation (whether +ve or -ve)
head(cor.table, 15)
## Var1 Var2 Freq
## 276 DiscCH PctDiscCH 0.9990225
## 259 DiscMM PctDiscMM 0.9987932
## 237 StoreID Store7 0.9301615
## 263 SalePriceMM PctDiscMM -0.8567490
## 227 SalePriceMM PriceDiff 0.8527979
## 187 DiscMM SalePriceMM -0.8468676
## 265 PriceDiff PctDiscMM -0.8280972
## 223 DiscMM PriceDiff -0.8239080
## 320 Store7 STORE -0.8054470
## 282 SalePriceCH PctDiscCH -0.7227756
## 204 DiscCH SalePriceCH -0.7112738
## 56 WeekofPurchase PriceCH 0.7043241
## 293 PriceMM ListPriceDiff 0.6651870
## 163 Purchase LoyalCH -0.6405824
## 76 PriceCH PriceMM 0.6164017
We show only the top 15 pairs to limit the output. There appear to be a good number of strong relationships between the variables. Some of them are obvious: for example, DiscCH and PctDiscCH are strongly correlated because both describe the discount on Citrus Hill, one in dollars and one as a percentage.
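More interesting for our purposes is that LoyalCH has the strongest correlation with the response Purchase in the list above, so it is worth a quick look before we start modelling. A simple sketch of this relationship is a boxplot of brand loyalty by purchased brand:
boxplot(LoyalCH ~ Purchase, data = OJ,
        xlab = "Purchase", ylab = "LoyalCH",
        main = "Brand loyalty for Citrus Hill by purchased brand")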
Now let us begin building our decision tree model. Decision trees are a class of predictive models that can handle either a categorical or a continuous response variable. They get their name from the structure of the fitted model: a series of binary decisions segments the data into increasingly homogeneous subgroups, a process known as recursive partitioning. When drawn out graphically, the model resembles a tree with branches.
We will first split the data into a training set containing a random sample of 800 observations, and a test set containing the remaining observations.
set.seed(1013)
train = sample(dim(OJ)[1], 800)
OJ.train = OJ[train, ]
OJ.test = OJ[-train, ]
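As a quick sanity check on the split, we can compare the class balance of the training sample to the full data:
prop.table(table(OJ.train$Purchase)) # share of CH vs MM in the training set
prop.table(table(OJ$Purchase))       # share of CH vs MM in the full data
With this seed, roughly 60% of the training purchases are CH, close to the roughly 61% in the full data set.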
Now let us import the tree library to build our decision tree:
library(tree)
oj.tree = tree(Purchase ~ ., data = OJ.train)
summary(oj.tree)
##
## Classification tree:
## tree(formula = Purchase ~ ., data = OJ.train)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff"
## Number of terminal nodes: 7
## Residual mean deviance: 0.7517 = 596.1 / 793
## Misclassification error rate: 0.155 = 124 / 800
We see that only two variables - LoyalCH and PriceDiff - were used in the tree construction. There are seven terminal nodes with a training error rate of 0.155 or 15.5%.
oj.tree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 800 1075.00 CH ( 0.60250 0.39750 )
## 2) LoyalCH < 0.5036 359 422.80 MM ( 0.27577 0.72423 )
## 4) LoyalCH < 0.276142 170 119.10 MM ( 0.11176 0.88824 ) *
## 5) LoyalCH > 0.276142 189 257.50 MM ( 0.42328 0.57672 )
## 10) PriceDiff < 0.05 79 76.79 MM ( 0.18987 0.81013 ) *
## 11) PriceDiff > 0.05 110 148.80 CH ( 0.59091 0.40909 ) *
## 3) LoyalCH > 0.5036 441 343.30 CH ( 0.86848 0.13152 )
## 6) LoyalCH < 0.764572 186 210.30 CH ( 0.74731 0.25269 )
## 12) PriceDiff < -0.165 29 34.16 MM ( 0.27586 0.72414 ) *
## 13) PriceDiff > -0.165 157 140.90 CH ( 0.83439 0.16561 )
## 26) PriceDiff < 0.265 82 95.37 CH ( 0.73171 0.26829 ) *
## 27) PriceDiff > 0.265 75 31.23 CH ( 0.94667 0.05333 ) *
## 7) LoyalCH > 0.764572 255 90.67 CH ( 0.95686 0.04314 ) *
Let us take a deeper dive into the tree and look at one of the terminal nodes. Let's pick the node labeled 11. The condition defining this node is PriceDiff > 0.05, i.e. PriceDiff is the splitting variable at its parent and the split value is 0.05. There are 110 observations in this node, and the deviance for the points it contains is 148.80. The * at the end of the line denotes that this is a terminal node. The prediction at this node is Purchase = CH: about 59.1% of the points in this node have CH as the value of Purchase, while the remaining 40.9% have MM.
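To see where the deviance figure comes from: for a classification tree, the node deviance is -2 times the sum over classes of n_k * log(p_k), where n_k is the class count in the node and p_k the class proportion. For node 11, with 110 observations split roughly 65 CH and 45 MM:
n = c(CH = 65, MM = 45)  # counts implied by 110 * (0.59091, 0.40909)
p = n / sum(n)           # class proportions within the node
-2 * sum(n * log(p))     # about 148.8, matching the deviance printed above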
Now let's plot the tree:
plot(oj.tree)
text(oj.tree, pretty = 0)
LoyalCH is the most important variable in the tree; in fact, the top three splits all use LoyalCH. If LoyalCH < 0.27, the tree predicts MM. If LoyalCH > 0.76, the tree predicts CH. For intermediate values of LoyalCH, the decision also depends on the value of PriceDiff.
Now let's predict the response on the test data and produce a confusion matrix comparing the test labels to the predicted test labels.
oj.pred = predict(oj.tree, OJ.test, type = "class")
table(OJ.test$Purchase, oj.pred)
## oj.pred
## CH MM
## CH 152 19
## MM 32 67
(152+67)/270
## [1] 0.8111111
We obtain a classification rate of 81.1%, or in other words a test error rate of 18.9%.
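The same numbers can be computed directly from the predictions rather than by hand:
mean(oj.pred == OJ.test$Purchase) # classification accuracy on the test set
mean(oj.pred != OJ.test$Purchase) # test error rate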
Let’s apply the cv.tree() function to the training set in order to determine the optimal tree size.
cv.oj = cv.tree(oj.tree, FUN = prune.tree)
cv.oj
## $size
## [1] 7 6 5 4 3 2 1
##
## $dev
## [1] 697.7742 672.3676 721.1061 722.5296 785.2427 786.4040 1078.4624
##
## $k
## [1] -Inf 14.33140 31.91234 35.17952 42.37864 46.23075 309.00727
##
## $method
## [1] "deviance"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
We see that the optimal tree size is six terminal nodes with a cross-validation error of 672.3676.
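Rather than reading it off the printout, the best size can also be extracted programmatically:
cv.oj$size[which.min(cv.oj$dev)] # tree size with the lowest cross-validation deviance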
Cross-validation plots of the deviance as a function of both tree size and the cost-complexity parameter k:
par(mfrow=c(1,2))
plot(cv.oj$size, cv.oj$dev, type = "b", xlab = "Tree Size", ylab = "Deviance")
plot(cv.oj$k, cv.oj$dev, type = "b", xlab = "Cost-Complexity Parameter (k)", ylab = "Deviance")
A tree size of 6 corresponds to the lowest cross-validation error.
oj.pruned = prune.tree(oj.tree, best = 6)
summary(oj.pruned)
##
## Classification tree:
## snip.tree(tree = oj.tree, nodes = 13L)
## Variables actually used in tree construction:
## [1] "LoyalCH" "PriceDiff"
## Number of terminal nodes: 6
## Residual mean deviance: 0.7689 = 610.5 / 794
## Misclassification error rate: 0.155 = 124 / 800
The training misclassification error rate is 0.155, or 15.5%, exactly the same as for the full, unpruned tree (the residual mean deviance is only slightly higher, 0.7689 versus 0.7517).
plot(oj.pruned)
text(oj.pruned, pretty = 0)
Now, how do the test error rates for the unpruned and pruned trees compare?
pred.unpruned = predict(oj.tree, OJ.test, type = "class")
misclass.unpruned = sum(OJ.test$Purchase != pred.unpruned)
misclass.unpruned/length(pred.unpruned)
## [1] 0.1888889
pred.pruned = predict(oj.pruned, OJ.test, type = "class")
misclass.pruned = sum(OJ.test$Purchase != pred.pruned)
misclass.pruned/length(pred.pruned)
## [1] 0.1888889
We see that the test error rates of the unpruned and pruned trees are exactly the same, at 18.89%.
We first fit a full decision tree on the OJ data set using Purchase (a factor with levels CH and MM indicating whether the customer purchased Citrus Hill or Minute Maid orange juice) as the response and the remaining 17 variables as predictors. The test error rate for the full tree was 18.89%. We then pruned the tree using cross-validation, obtaining a tree with six terminal nodes that had the lowest cross-validation error. When comparing the two trees, both had exactly the same test error rate. In most situations, however, an unpruned and a pruned tree would not have identical test error rates; the match here may simply be due to the particular random sample used for the training/test split.