Introduction

The objective of this analysis is to take a deep dive into the OJ dataset, which is part of the ISLR package. This dataset contains sales information for the Citrus Hill and Minute Maid brands of orange juice. After running some basic exploratory data analysis, we will try to predict which brand a customer purchases. We will do so by fitting decision trees to the data, with Purchase as the response and the other variables as predictors.

Data

The data contains 1070 purchases in which the customer bought either Citrus Hill or Minute Maid orange juice. A number of characteristics of the customer and product are recorded.

Attributes

A data frame with 1070 observations on the following 18 variables.

Purchase: A factor with levels CH and MM indicating whether the customer purchased Citrus Hill or Minute Maid Orange Juice

WeekofPurchase: Week of purchase

StoreID: Store ID

PriceCH: Price charged for CH

PriceMM: Price charged for MM

DiscCH: Discount offered for CH

DiscMM: Discount offered for MM

SpecialCH: Indicator of special on CH

SpecialMM: Indicator of special on MM

LoyalCH: Customer brand loyalty for CH

SalePriceMM: Sale price for MM

SalePriceCH: Sale price for CH

PriceDiff: Sale price of MM less sale price of CH

Store7: A factor with levels No and Yes indicating whether the sale is at Store 7

PctDiscMM: Percentage discount for MM

PctDiscCH: Percentage discount for CH

ListPriceDiff: List price of MM less list price of CH

STORE: Which of 5 possible stores the sale occurred at

Source

Stine, Robert A., Foster, Dean P., Waterman, Richard P. Business Analysis Using Regression (1998). Published by Springer.

Exploratory Data Analysis

Let us first import the data, look at the structure and see if there are any missing values:

library(ISLR)
attach(OJ)
str(OJ)
## 'data.frame':    1070 obs. of  18 variables:
##  $ Purchase      : Factor w/ 2 levels "CH","MM": 1 1 1 2 1 1 1 1 1 1 ...
##  $ WeekofPurchase: num  237 239 245 227 228 230 232 234 235 238 ...
##  $ StoreID       : num  1 1 1 1 7 7 7 7 7 7 ...
##  $ PriceCH       : num  1.75 1.75 1.86 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
##  $ PriceMM       : num  1.99 1.99 2.09 1.69 1.69 1.99 1.99 1.99 1.99 1.99 ...
##  $ DiscCH        : num  0 0 0.17 0 0 0 0 0 0 0 ...
##  $ DiscMM        : num  0 0.3 0 0 0 0 0.4 0.4 0.4 0.4 ...
##  $ SpecialCH     : num  0 0 0 0 0 0 1 1 0 0 ...
##  $ SpecialMM     : num  0 1 0 0 0 1 1 0 0 0 ...
##  $ LoyalCH       : num  0.5 0.6 0.68 0.4 0.957 ...
##  $ SalePriceMM   : num  1.99 1.69 2.09 1.69 1.69 1.99 1.59 1.59 1.59 1.59 ...
##  $ SalePriceCH   : num  1.75 1.75 1.69 1.69 1.69 1.69 1.69 1.75 1.75 1.75 ...
##  $ PriceDiff     : num  0.24 -0.06 0.4 0 0 0.3 -0.1 -0.16 -0.16 -0.16 ...
##  $ Store7        : Factor w/ 2 levels "No","Yes": 1 1 1 1 2 2 2 2 2 2 ...
##  $ PctDiscMM     : num  0 0.151 0 0 0 ...
##  $ PctDiscCH     : num  0 0 0.0914 0 0 ...
##  $ ListPriceDiff : num  0.24 0.24 0.23 0 0 0.3 0.3 0.24 0.24 0.24 ...
##  $ STORE         : num  1 1 1 1 0 0 0 0 0 0 ...
summary(OJ)
##  Purchase WeekofPurchase     StoreID        PriceCH         PriceMM     
##  CH:653   Min.   :227.0   Min.   :1.00   Min.   :1.690   Min.   :1.690  
##  MM:417   1st Qu.:240.0   1st Qu.:2.00   1st Qu.:1.790   1st Qu.:1.990  
##           Median :257.0   Median :3.00   Median :1.860   Median :2.090  
##           Mean   :254.4   Mean   :3.96   Mean   :1.867   Mean   :2.085  
##           3rd Qu.:268.0   3rd Qu.:7.00   3rd Qu.:1.990   3rd Qu.:2.180  
##           Max.   :278.0   Max.   :7.00   Max.   :2.090   Max.   :2.290  
##      DiscCH            DiscMM         SpecialCH        SpecialMM     
##  Min.   :0.00000   Min.   :0.0000   Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:0.00000   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:0.0000  
##  Median :0.00000   Median :0.0000   Median :0.0000   Median :0.0000  
##  Mean   :0.05186   Mean   :0.1234   Mean   :0.1477   Mean   :0.1617  
##  3rd Qu.:0.00000   3rd Qu.:0.2300   3rd Qu.:0.0000   3rd Qu.:0.0000  
##  Max.   :0.50000   Max.   :0.8000   Max.   :1.0000   Max.   :1.0000  
##     LoyalCH          SalePriceMM     SalePriceCH      PriceDiff      
##  Min.   :0.000011   Min.   :1.190   Min.   :1.390   Min.   :-0.6700  
##  1st Qu.:0.325257   1st Qu.:1.690   1st Qu.:1.750   1st Qu.: 0.0000  
##  Median :0.600000   Median :2.090   Median :1.860   Median : 0.2300  
##  Mean   :0.565782   Mean   :1.962   Mean   :1.816   Mean   : 0.1465  
##  3rd Qu.:0.850873   3rd Qu.:2.130   3rd Qu.:1.890   3rd Qu.: 0.3200  
##  Max.   :0.999947   Max.   :2.290   Max.   :2.090   Max.   : 0.6400  
##  Store7      PctDiscMM        PctDiscCH       ListPriceDiff  
##  No :714   Min.   :0.0000   Min.   :0.00000   Min.   :0.000  
##  Yes:356   1st Qu.:0.0000   1st Qu.:0.00000   1st Qu.:0.140  
##            Median :0.0000   Median :0.00000   Median :0.240  
##            Mean   :0.0593   Mean   :0.02731   Mean   :0.218  
##            3rd Qu.:0.1127   3rd Qu.:0.00000   3rd Qu.:0.300  
##            Max.   :0.4020   Max.   :0.25269   Max.   :0.440  
##      STORE      
##  Min.   :0.000  
##  1st Qu.:0.000  
##  Median :2.000  
##  Mean   :1.631  
##  3rd Qu.:3.000  
##  Max.   :4.000

There are no missing values, and nothing looks out of the ordinary.
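As a quick sanity check, we could also count missing values directly; a minimal sketch (output omitted):

#Confirm there are no NAs anywhere in the data
sum(is.na(OJ))        #Total number of missing values
colSums(is.na(OJ))    #Missing values per variable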

We could examine the relationships between the variables with a correlation matrix, but since we have so many variables, we will instead pull a sorted list of variable pairs with their correlation coefficients. We first need to convert the factor variables Purchase and Store7 to numeric so that we can use the correlation matrix function.

OJ.numeric = OJ
OJ.numeric$Purchase = as.numeric(OJ.numeric$Purchase)
OJ.numeric$Store7 = as.numeric(OJ.numeric$Store7)
cor.table=cor(OJ.numeric)
cor.table[lower.tri(cor.table,diag=TRUE)]=NA  #Prepare to drop duplicates and meaningless information
cor.table=as.data.frame(as.table(cor.table))  #Turn into a 3-column table
cor.table=na.omit(cor.table)  #Get rid of the junk we flagged above
cor.table=cor.table[order(-abs(cor.table$Freq)),]    #Sort by highest correlation (whether +ve or -ve)
head(cor.table, 15)
##               Var1          Var2       Freq
## 276         DiscCH     PctDiscCH  0.9990225
## 259         DiscMM     PctDiscMM  0.9987932
## 237        StoreID        Store7  0.9301615
## 263    SalePriceMM     PctDiscMM -0.8567490
## 227    SalePriceMM     PriceDiff  0.8527979
## 187         DiscMM   SalePriceMM -0.8468676
## 265      PriceDiff     PctDiscMM -0.8280972
## 223         DiscMM     PriceDiff -0.8239080
## 320         Store7         STORE -0.8054470
## 282    SalePriceCH     PctDiscCH -0.7227756
## 204         DiscCH   SalePriceCH -0.7112738
## 56  WeekofPurchase       PriceCH  0.7043241
## 293        PriceMM ListPriceDiff  0.6651870
## 163       Purchase       LoyalCH -0.6405824
## 76         PriceCH       PriceMM  0.6164017

We show only the top 15 pairs to limit the output. There appear to be a good number of strong relationships between the variables. Some of them are obvious; for example, DiscCH and PctDiscCH are strongly correlated because both describe the discount on the same brand.
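Since LoyalCH shows the strongest correlation with the response Purchase, a quick visual check of that relationship could look like the sketch below (plot not shown here):

#Compare the distribution of CH brand loyalty between the two purchase outcomes
boxplot(LoyalCH ~ Purchase, data = OJ,
        xlab = "Purchase", ylab = "LoyalCH",
        main = "CH brand loyalty by purchased brand")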

Analysis

Now let us begin building our decision tree model. Decision trees are a class of predictive data mining tools which predict either a categorical or continuous response variable. They get their name from the structure of the models built. A series of decisions are made to segment the data into homogeneous subgroups. This is also called recursive partitioning. When drawn out graphically, the model can resemble a tree with branches.
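As a rough illustration of what a single partition does, we can tabulate Purchase on either side of an arbitrary LoyalCH cutoff (the value 0.5 here is purely illustrative, not a fitted split):

#Each side of the cut is noticeably more homogeneous than the data as a whole
table(OJ$Purchase, OJ$LoyalCH < 0.5)
prop.table(table(OJ$Purchase, OJ$LoyalCH < 0.5), margin = 2)  #Proportions within each side of the cut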

We will first split the data into a training set containing a random sample of 800 observations, and a test set containing the remaining observations.

set.seed(1013)

train = sample(dim(OJ)[1], 800)
OJ.train = OJ[train, ]
OJ.test = OJ[-train, ]

Now let us import the tree library to build our decision tree:

library(tree)
oj.tree = tree(Purchase ~ ., data = OJ.train)
summary(oj.tree)
## 
## Classification tree:
## tree(formula = Purchase ~ ., data = OJ.train)
## Variables actually used in tree construction:
## [1] "LoyalCH"   "PriceDiff"
## Number of terminal nodes:  7 
## Residual mean deviance:  0.7517 = 596.1 / 793 
## Misclassification error rate: 0.155 = 124 / 800

We see that only two variables, LoyalCH and PriceDiff, were used in constructing the tree. There are seven terminal nodes, and the training error rate is 0.155, or 15.5%.
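These quantities can also be pulled straight from the fitted object; the sketch below assumes the tree's frame component marks terminal nodes with "<leaf>", as the tree package does:

#Number of terminal nodes and the variables actually used in splits
sum(oj.tree$frame$var == "<leaf>")
unique(as.character(oj.tree$frame$var[oj.tree$frame$var != "<leaf>"]))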

oj.tree
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 800 1075.00 CH ( 0.60250 0.39750 )  
##    2) LoyalCH < 0.5036 359  422.80 MM ( 0.27577 0.72423 )  
##      4) LoyalCH < 0.276142 170  119.10 MM ( 0.11176 0.88824 ) *
##      5) LoyalCH > 0.276142 189  257.50 MM ( 0.42328 0.57672 )  
##       10) PriceDiff < 0.05 79   76.79 MM ( 0.18987 0.81013 ) *
##       11) PriceDiff > 0.05 110  148.80 CH ( 0.59091 0.40909 ) *
##    3) LoyalCH > 0.5036 441  343.30 CH ( 0.86848 0.13152 )  
##      6) LoyalCH < 0.764572 186  210.30 CH ( 0.74731 0.25269 )  
##       12) PriceDiff < -0.165 29   34.16 MM ( 0.27586 0.72414 ) *
##       13) PriceDiff > -0.165 157  140.90 CH ( 0.83439 0.16561 )  
##         26) PriceDiff < 0.265 82   95.37 CH ( 0.73171 0.26829 ) *
##         27) PriceDiff > 0.265 75   31.23 CH ( 0.94667 0.05333 ) *
##      7) LoyalCH > 0.764572 255   90.67 CH ( 0.95686 0.04314 ) *

Let us take a deeper dive into the tree and look at one of the terminal nodes, say the node labeled 11. The split that defines this node is on PriceDiff, with a splitting value of 0.05. There are 110 observations in this node, and their deviance is 148.80. The * at the end of the line denotes that this is in fact a terminal node. The prediction at this node is Purchase = CH: about 59.1% of the observations in this node have CH as the value of Purchase, and the remaining 40.9% have MM.
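If we want these node-level numbers without reading them off the printout, we can index the tree's frame by node number (a sketch, assuming the frame's row names follow the node numbering shown above):

#Details of terminal node 11: split variable, n, deviance, fitted class and class probabilities
oj.tree$frame["11", ]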

Now let's plot the tree:

plot(oj.tree)
text(oj.tree, pretty = 0)

LoyalCH is the most important variable in the tree; in fact, the top three splits all use LoyalCH. If LoyalCH < 0.27, the tree predicts MM. If LoyalCH > 0.76, the tree predicts CH. For intermediate values of LoyalCH, the decision also depends on the value of PriceDiff.
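To see this behavior in action, we can look at the fitted predictions for a few test customers spanning low, intermediate, and high CH loyalty (the particular rows chosen here are just for illustration):

#Test customers with the lowest, median, and highest LoyalCH values
idx = order(OJ.test$LoyalCH)[c(1, round(nrow(OJ.test)/2), nrow(OJ.test))]
data.frame(LoyalCH = round(OJ.test$LoyalCH[idx], 3),
           PriceDiff = OJ.test$PriceDiff[idx],
           Predicted = predict(oj.tree, OJ.test[idx, ], type = "class"))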

Now let's predict the response on the test data and produce a confusion matrix comparing the test labels to the predicted test labels.

oj.pred = predict(oj.tree, OJ.test, type = "class")
table(OJ.test$Purchase, oj.pred)
##     oj.pred
##       CH  MM
##   CH 152  19
##   MM  32  67
(152+67)/270
## [1] 0.8111111

We obtain a classification rate of 81.1%, or in other words a test error rate of 18.9%.
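Equivalently, we could compute these rates without reading counts off the confusion matrix:

mean(oj.pred == OJ.test$Purchase)    #Classification (accuracy) rate
mean(oj.pred != OJ.test$Purchase)    #Test error rate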

Let’s apply the cv.tree() function to the training set in order to determine the optimal tree size.

cv.oj = cv.tree(oj.tree, FUN = prune.tree)
cv.oj
## $size
## [1] 7 6 5 4 3 2 1
## 
## $dev
## [1]  697.7742  672.3676  721.1061  722.5296  785.2427  786.4040 1078.4624
## 
## $k
## [1]      -Inf  14.33140  31.91234  35.17952  42.37864  46.23075 309.00727
## 
## $method
## [1] "deviance"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"

We see that the optimal tree size is six terminal nodes with a cross-validation error of 672.3676.
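Rather than reading the optimum off the printout, we can also select it programmatically:

#Tree size with the lowest cross-validation deviance
best.size = cv.oj$size[which.min(cv.oj$dev)]
best.size
min(cv.oj$dev)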

Below are cross-validation plots of the deviance as a function of tree size and of the cost-complexity parameter k:

par(mfrow=c(1,2))
plot(cv.oj$size, cv.oj$dev, type = "b", xlab = "Tree Size", ylab = "Deviance")
plot(cv.oj$k, cv.oj$dev, type = "b", xlab = "Cost-Complexity Parameter (k)", ylab = "Deviance")

A tree size of 6 corresponds to the lowest cross-validation error.

oj.pruned = prune.tree(oj.tree, best = 6)
summary(oj.pruned)
## 
## Classification tree:
## snip.tree(tree = oj.tree, nodes = 13L)
## Variables actually used in tree construction:
## [1] "LoyalCH"   "PriceDiff"
## Number of terminal nodes:  6 
## Residual mean deviance:  0.7689 = 610.5 / 794 
## Misclassification error rate: 0.155 = 124 / 800

The training misclassification error rate is 0.155, or 15.5%, which is exactly the same as the error rate for the full, unpruned tree.

plot(oj.pruned)
text(oj.pruned, pretty = 0)

Now let's see how the test error rates for the unpruned and pruned trees compare:

pred.unpruned = predict(oj.tree, OJ.test, type = "class")
misclass.unpruned = sum(OJ.test$Purchase != pred.unpruned)
misclass.unpruned/length(pred.unpruned)
## [1] 0.1888889
pred.pruned = predict(oj.pruned, OJ.test, type = "class")
misclass.pruned = sum(OJ.test$Purchase != pred.pruned)
misclass.pruned/length(pred.pruned)
## [1] 0.1888889

We see that the test error rates for the unpruned and pruned trees are exactly the same, at 18.89%. This is not too surprising here: pruning only collapsed node 13, and its two children (nodes 26 and 27) both predicted CH, so the predicted class for every test observation is unchanged.

Conclusion and Summary

We first fit a full decision tree on the OJ data set, using Purchase (a factor with levels CH and MM indicating whether the customer purchased Citrus Hill or Minute Maid Orange Juice) as the response and the remaining 17 variables as predictors. The test error rate for the full tree was 18.89%. We then performed pruning using cross-validation, obtaining a tree with six terminal nodes, the size with the lowest cross-validation error. When comparing the two trees, both had exactly the same test error rate. In most situations an unpruned and a pruned tree would not have identical error rates; here the pruning happened to collapse a subtree whose leaves all made the same prediction, and with a different random training sample the two trees could well have differed.