# Run this code block to load the Tidyverse and the other packages used in this lab
.libPaths(new = "~/Rlibs")
library(tidyverse)
library(modelr)
library(readxl)
library(rpart)
library(stats)
library(randomForest)
library(rattle)
library(rpart.plot)
library(RColorBrewer)
library(caret)
# To change the size of any plots, copy the code snippet
# below, uncomment it, and set the size of the width
# and height.
# Note: All subsequent figures will use the same size,
# unless you change the options() snippet and run it
# again.
# options(repr.plot.width=6, repr.plot.height=4)
For the final lab of the semester, we will work through a tutorial-style Jupyter notebook that gives a "real-world" demonstration of how to use machine learning to model a dataset, check the model with cross-validation, and then make predictions. We take a "black box" approach and rely on R packages that handle much of the hard work for us. In essence, you will see that you do not need a deep and sophisticated knowledge of the models and machine learning algorithms in order to do useful work!
First, we need to install two more packages into our SageMathCloud environment. Run the command below to install them (you only need to do this once). Note: This may take a couple of minutes to finish running.
# Install packages that you don't currently have
install.packages(c("rattle", "rpart.plot"), lib = c("~/Rlibs"))
We will use the Titanic dataset to demonstrate the power and flexibility of the machine learning packages available within R. You may remember working through this dataset for Homework 1 in CDS-101. If not, no worries, we recap the basic details of the dataset here.
This dataset contains information about passengers on the Titanic, the British passenger liner that crashed into an iceberg during its maiden voyage and sank early in the morning on April 15, 1912. The tragedy stands out as one of the deadliest commercial maritime disasters during peacetime in history. More than half of the passengers and crew died, due in large part to poor safety standards, such as not having enough lifeboats or not ensuring all lifeboats were filled to capacity during evacuation.
This dataset presents the most up-to-date knowledge about the passengers that were on the Titanic, including whether or not they survived. We will be using the dataset to build models that try to predict whether or not a passenger survived or died.
The columns in the dataset are:
Variable | Description |
---|---|
passengerid | A unique index used to uniquely identify each passenger in the dataset |
pclass | Passenger Class (1 = 1st; 2 = 2nd; 3 = 3rd) |
survived | Survival (0 = No; 1 = Yes) |
name | Name |
sex | Sex |
age | Age |
sibsp | Number of Siblings/Spouses Aboard |
parch | Number of Parents/Children Aboard |
ticket | Ticket Number |
fare | Passenger Fare (British pound) |
cabin | Cabin |
embarked | Port of Embarkation (C = Cherbourg; Q = Queenstown; S = Southampton) |
Also note that the following definitions were used for `sibsp` and `parch`:
Relative | Description |
---|---|
Sibling: | Brother, Sister, Stepbrother, or Stepsister of Passenger Aboard Titanic |
Spouse: | Husband or Wife of Passenger Aboard Titanic (Mistresses and Fiances Ignored) |
Parent: | Mother or Father of Passenger Aboard Titanic |
Child: | Son, Daughter, Stepson, or Stepdaughter of Passenger Aboard Titanic |
The dataset is saved as an Excel file in the xlsx format. We can use the `read_excel()` function that is part of the `readxl` library (loaded at the top of the page) to extract all the relevant information. For convenience, we also redefine several columns as the `factor` data type instead of the `character` data type.
# Load and prepare dataset
dataset <- suppressMessages(read_excel("titanic3.xlsx"))
dataset <- dataset %>%
mutate(
passengerid = as.integer(passengerid),
pclass = as.factor(pclass),
survived = as.factor(survived),
name = as.factor(name),
sex = as.factor(sex),
sibsp = as.factor(sibsp),
parch = as.factor(parch),
ticket = as.factor(ticket),
fare = as.double(fare),
cabin = as.factor(cabin),
embarked = as.factor(embarked)
) %>%
select(passengerid, survived, pclass, name, everything())
We check the dataset to make sure that everything loaded correctly.
glimpse(dataset)
The following key questions will drive our analysis,
During the last 3 weeks, we have worked with linear models. For the labs, we only considered one-term models, while in CDS-101 we also explored multi-term models. These are all examples of the standard class of linear models, which have the general form $f(x_1, x_2, \ldots, x_n) = a_1 x_1 + a_2 x_2 + \ldots + a_n x_n$. Note that, if you use transformations as we did in labs 12 and 13, this amounts to changing the fitting parameters $a_1, a_2, \ldots, a_n$.
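As a quick reminder of what one-term and multi-term linear models look like in code, here is a minimal sketch. It uses `mtcars`, a dataset built into R, purely for illustration; it is not part of this lab's data. Note that `lm()` also adds an intercept term by default, which the general form above leaves out.

```r
# Illustrative only: mtcars is a built-in R dataset, not part of this lab.
# One-term model: mpg explained by weight alone
model_one <- lm(mpg ~ wt, data = mtcars)

# Multi-term model: add horsepower as a second explanatory variable
model_two <- lm(mpg ~ wt + hp, data = mtcars)

# Each coefficient is one of the fitting parameters a_1, a_2, ...
coef(model_one)
coef(model_two)
```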
The models we've used so far have assumed that the response variable (`survived` in this dataset) is continuous. Here it is not: it is categorical and binary (a passenger either survives or doesn't). The `lm()` function isn't built to handle this well, so we use the `glm()` function instead. The initials stand for Generalized Linear Model. As the name implies, these models work in more situations than `lm()`, including when the response variable isn't continuous.
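To make the idea concrete before applying it to the Titanic data, the sketch below fits a logistic regression to a small synthetic dataset. The data frame `toy` and its columns are made up for illustration. The point to notice is that asking for predictions with `type = "response"` returns probabilities between 0 and 1 rather than the 0/1 values themselves.

```r
# Synthetic data, made up for illustration (not the Titanic dataset)
set.seed(1)
toy <- data.frame(x = seq(-3, 3, length.out = 60))
# Binary outcome that becomes more likely as x grows
toy$y <- rbinom(nrow(toy), size = 1, prob = plogis(1.5 * toy$x))

# Logistic regression: the binomial family with a logit link
toy_model <- glm(y ~ x, data = toy, family = binomial(link = "logit"))

# type = "response" returns probabilities rather than 0/1 labels
toy_probs <- predict(toy_model, newdata = toy, type = "response")
range(toy_probs)
```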
The code below sets up a calculation using `glm()` and models the data using the passengers' ages only. Run it, then print out the model to see what you find.
dataset.model.glm.malefemale <- glm(survived ~ age, data = dataset, family = binomial(link = "logit"))
print(dataset.model.glm.malefemale)
Question
Does the output of the `glm()` model resemble `lm()`'s output, or not?
It does somewhat resemble the lm output. I remember the lm output having more information and no AIC value.
Next, let's visualize the model and its predictions. When using `glm()`, you need to be careful with how you make your plot, so the code is provided for you below.
options(repr.plot.width=6, repr.plot.height=4)
grid.glm.malefemale <- dataset %>% data_grid(age) %>% mutate(pred = predict(dataset.model.glm.malefemale, newdata = ., type = 'response'))
ggplot(grid.glm.malefemale) +
geom_point(
data = mutate(dataset, survived = as.numeric(levels(survived))[survived]),
mapping = aes(x = age, y = survived)) +
geom_line(mapping = aes(x = age, y = pred), color = "red", size = 2)
Question
Compare the output values from the model (in red) to the values in the dataset (in black).
What is the model doing that the data is not?
How would you interpret this?
Also, speculate on a way that you could use the red curve to predict whether or not a passenger survives.
The model suggests that older people tend to have lower survivability rates. However, the model suggests the means of survivorship, which is unusual because survivorship is a binary value (either you survive or you don't) but the model suggests a halfway value.
You may interpret this as a probability of the person surviving. So if the 'survived' value on the model is less than .5 that means that person has a higher chance of perishing than surviving. So if the red line is closer to .25 that means that person probably has a 1/4 chance of surviving but the closer the line is to .5 that means that it is a 1/2 chance that person will survive (so really it is left up to chance).
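The probability reading above suggests a simple decision rule: pick a cutoff and classify each passenger by which side of it their probability falls on. A minimal sketch, assuming a cutoff of 0.5; the probabilities are made-up values, not output from the model in this lab.

```r
# Hypothetical predicted survival probabilities for five passengers
probs <- c(0.12, 0.48, 0.50, 0.73, 0.91)

# Classify as survived (1) when the probability exceeds 0.5, else died (0)
predicted_class <- ifelse(probs > 0.5, 1, 0)
predicted_class
```

The cutoff itself is a modeling choice; 0.5 is a common default, but nothing forces it.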
Another method for building predictions is the decision tree. Decision trees are a lot like the `IF()` statement we used in Google Sheets in Lab Week 3, in that they have several nodes where you ask an if/then style question: if you answer "yes" you go down the yes path, and if you answer "no" you go down the no path. In fact, if you wanted to, you could construct a rudimentary decision tree by hand by nesting several `IF()` functions.
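In R, the analogue of nesting `IF()` functions is nesting `ifelse()` calls. The sketch below shows what a hand-built, two-question tree might look like. The function name `predict_by_hand` and the rules inside it are illustrative assumptions, not rules learned from the data.

```r
# A hand-built two-level "tree"; the rules below are illustrative
# assumptions, not rules learned from the dataset.
predict_by_hand <- function(sex, pclass) {
  ifelse(sex == "male",
         "died",                 # "yes" branch of: is the passenger male?
         ifelse(pclass == 3,     # "no" branch asks a second question
                "died",
                "survived"))
}

predict_by_hand(sex = c("male", "female", "female"), pclass = c(1, 3, 1))
```

Packages such as `rpart` do the same kind of thing, except that they choose the questions and their order from the data.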
However, we will not attempt to build a decision tree program ourselves, as there are several convenient packages available that will build decision trees for you. We use the `rpart()` function from the `rpart` package, which is distributed with R.
The code below builds a decision tree based on whether a passenger is male or female. Run the code to create the model and output a nice graphic that illustrates its meaning.
options(repr.plot.width=6, repr.plot.height=6)
fit <- rpart(survived ~ sex,
data=dataset,
method="class")
prediction <- predict(fit, dataset, type = "class")
fancyRpartPlot(fit)
The above graphic is read from the top down. It starts by assuming all people perish (the zero at the top), with 62% of the dataset matching this assumption (left decimal) and 38% of the dataset not matching this assumption (right decimal). 100% of the passengers start here during the sorting procedure. The tree then asks a question, "Are you male?" If the passenger data says yes, then the data point is put in the left bucket. If the data says no, then they go to the right. The left bucket predicts the passenger did not survive. 64% of the data is sorted into this bucket. 81% of the male passengers sorted here indeed perished, while 19% did not. In the right bucket, the remaining 36% of passengers, the female passengers, are predicted to survive. 73% of the women survive, 27% do not. For a simple sort, this isn't so bad a prediction model!
Now let's have you try a decision tree model. Copy the code from above, keep the `sex` variable, and add two more variables into it. You should choose from the list:
options(repr.plot.width=6, repr.plot.height=6)
fit <- rpart(survived ~ sex + pclass + fare,
data=dataset,
method="class")
prediction <- predict(fit, dataset, type = "class")
fancyRpartPlot(fit)
Question
How well does your decision tree seem to perform? Just from looking, does it do better or worse than the above model?
I suppose this model indicates which factors will affect survivorship more. It breaks down the demographics of passengers by probability and percentage of a certain demographic for survivorship. For the females 1/3 of them perished but 73% did not. For the passenger class equaling 3 and were female, 17% of the data fell into this class while it was nearly a 50:50 chance of passengers in this class would survive while passengers that were in 1st or 2nd class and female, most of them survived. This decision tree is easier to read than a histogram where there is more data as opposed to a decision tree where you can just follow the decision tree to find a demographic of interest and it gives you how many of this demographic survived and how many did not and the percentage of all the data that makes up this demographic.
Using single decision trees seems very nice because its behavior is easy to interpret, but it turns out that any given tree may not be the most accurate model. The rules for splitting the variables into different branches are not absolute, so in principle you could end up with many different types of decision trees. Instead of trying to iterate through many possible examples, we move on to a more systematic approach to our modeling.
The above examples were meant to illustrate the concepts of generalized linear models and decision trees. While we could attempt to add and remove variables as a way to manually optimize the model's predictive power, there are better, more powerful ways to address this problem using an ensemble version of decision trees called Random Forest and tools such as k-fold cross-validation. For this part, we will use the R package Caret, which standardizes much of the modeling syntax and automates a lot of the work of training and testing models of different classes.
Max Kuhn, the developer of Caret, describes the package as follows on the Caret introduction page:
The `caret` package (short for _C_lassification _A_nd _RE_gression _T_raining) is a set of functions that attempt to streamline the process for creating predictive models. The package contains tools for:
- data splitting
- pre-processing
- feature selection
- model tuning using resampling
- variable importance estimation
- as well as other functionality.
There are many different modeling functions in R. Some have different syntax for model training and/or prediction. The package started off as a way to provide a uniform interface the functions themselves, as well as a way to standardize common tasks (such parameter tuning and variable importance).
In the long term, if you decide to use machine learning in your future work, I wouldn't recommend only using something like Caret and never learning more. It's worth digging into how all of this works, and doing so will lead to a deeper appreciation and understanding of what's going on. But, for now, we use it for the sake of convenience and to show that you can still do useful things even as you're learning about the various models that are available.
In order to use the Caret framework, we need to clean up our dataset a little more. As we're going to use it, Caret requires that there are no `NA` values in the columns we use when creating a model.
The `age`, `embarked`, and `fare` columns all contain `NA` values. However, we do not want to just drop these rows from our dataset, as this would remove a significant amount of information. Instead, we handle each as follows.
For ages, we can either take the median or mean age from the `age` column, or we can use a decision tree created using `rpart` to assign ages to passengers. The decision tree route will be a little more robust, so let's use that. Run the following code to train the decision tree on the dataset to predict ages.
# Handle missing data (necessary for random forest)
Agefit <- rpart(age ~ pclass + sex + sibsp + parch + fare + embarked,
data=dataset[!is.na(dataset$age),],
method="anova")
To predict the ages for the missing cells, we simply run:
dataset$age[is.na(dataset$age)] <- predict(Agefit, dataset[is.na(dataset$age),])
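For comparison, the simpler median approach mentioned as an alternative would look roughly like the sketch below. The `ages` vector is made up for illustration; it is not taken from the dataset.

```r
# Toy ages with gaps; values are made up for illustration
ages <- c(22, NA, 35, 58, NA, 4)

# Median of the observed values, ignoring NA
median_age <- median(ages, na.rm = TRUE)

# Overwrite only the missing entries
ages[is.na(ages)] <- median_age
ages
```

Every missing passenger gets the same age under this approach, which is why the tree-based route, which uses the other columns to tailor each guess, tends to be preferred.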
For the fares, there's one missing value, so it's easier to just take the median value and use that as an approximate value:
# Fill in the single missing fare with the median of the observed fares
dataset$fare[is.na(dataset$fare)] <- median(dataset$fare, na.rm = TRUE)
For `embarked`, there are two missing values. Of the three departure locations, Southampton is the most common, so let's assign that to the two missing cells:
dataset$embarked[c(169, 285)] <- "S"
dataset$embarked <- factor(dataset$embarked)
Now the dataset is ready for model training in Caret.
For a proper machine learning study, the dataset should be split into two pieces: the training set, where you develop the model, and the testing set, which you try to predict as accurately as possible. We will do an 80/20 split using the code below. If you would like, you can change the number inside `set.seed()` to change the final splitting of the dataset.
# define an 80%/20% train/test split of the dataset
set.seed(381710)
split=0.80
trainIndex <- createDataPartition(dataset$passengerid, p=split, list=FALSE)
data_train <- dataset[trainIndex, ]
data_test <- dataset[-trainIndex, ]
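To see what an 80/20 split amounts to without any helper package, here is a minimal base R sketch on a stand-in data frame. The name `toy` and the seed value are assumptions made for this example. Caret's `createDataPartition()` goes a step further by also trying to balance the distribution of the variable it is given across the two pieces; the sketch below is plain random sampling.

```r
set.seed(42)                     # arbitrary seed for this sketch
toy <- data.frame(id = 1:100)    # stand-in for a 100-row dataset

# Randomly choose 80% of the row numbers for training
train_rows <- sample(nrow(toy), size = round(0.8 * nrow(toy)))

toy_train <- toy[train_rows, , drop = FALSE]
toy_test  <- toy[-train_rows, , drop = FALSE]   # everything not chosen

c(train = nrow(toy_train), test = nrow(toy_test))
```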
To get a reasonable estimation of our model's error, we use the k-fold cross-validation method. For this dataset, we set $k=10$ (so, ten-fold), meaning we will have ten testing sets to iterate through. In addition, to get better statistics, we repeat the k-fold cross-validation process 3 times with different random samples. The code below sets this up for us. We also choose the "Accuracy" metric for how we want to measure our model's performance.
Note: You will see that, before every model training run, we run `set.seed(seed)`. The input variable `seed` is defined below. It is important that the `seed` variable stays the same for all training runs. This has to do with the random number generator that does the data sampling when building up the k-fold cross-validation. For an accurate comparison, we want all models to train against the same k-fold cross-validation splits.
control <- trainControl(method="repeatedcv", number=10, repeats=3)
seed <- -4168
metric <- "Accuracy"
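The sketch below illustrates, in base R, why resetting the seed matters. It assigns rows to folds at random twice, resetting the seed each time, and confirms the assignments match. The helper `make_folds()` and the seed value are hypothetical names and numbers used only here; Caret builds its folds internally and its exact procedure may differ.

```r
# Assign each of n rows to one of k folds, in random order
# (hypothetical helper, for illustration only)
make_folds <- function(n, k) sample(rep(seq_len(k), length.out = n))

toy_seed <- 123                  # arbitrary value for this sketch

set.seed(toy_seed)
folds_a <- make_folds(n = 50, k = 10)

set.seed(toy_seed)
folds_b <- make_folds(n = 50, k = 10)

identical(folds_a, folds_b)      # same seed, same fold assignment
table(folds_a)                   # each fold holds 5 of the 50 rows
```

In k-fold cross-validation each fold takes one turn as the held-out set while the model is trained on the remaining folds.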
Now that we're all set up with a training set and a cross-validation method, we revisit the glm model and see if we can improve it. To train a more sophisticated glm model, we include all reasonable variables and run the following code. If you get a series of warning messages, that's okay.
set.seed(seed)
fit.glm <- train(survived~fare + age + sex + pclass + sibsp + parch + embarked,
data=data_train, method="glm", metric=metric, trControl=control)
To make predictions on the testing dataset and evaluate the model's accuracy, we run the following code. It will construct a confusion matrix, a truth table that counts the accurate predictions and the inaccurate ones.
x_test <- data_test %>% select(-survived)
y_test <- data_test %>% select(survived)
glm.predictions <- predict(fit.glm, x_test)
confusionMatrix(table(pred.glm = glm.predictions, truth = y_test$survived))
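To demystify the accuracy and confidence interval reported in that output, the sketch below computes both by hand from a small confusion matrix. The counts are hypothetical, made up for illustration, and are not this lab's results. As far as I am aware, Caret's documentation describes its interval as coming from `binom.test()`, but treat that as an assumption rather than a guarantee.

```r
# Hypothetical counts, made up for illustration (not the lab's results)
cm <- matrix(c(50, 10,
               15, 25),
             nrow = 2, byrow = TRUE,
             dimnames = list(pred = c("0", "1"), truth = c("0", "1")))

correct  <- sum(diag(cm))   # predictions on the diagonal are right
total    <- sum(cm)
accuracy <- correct / total
accuracy

# Exact binomial 95% confidence interval for the accuracy
accuracy_ci <- binom.test(correct, total)$conf.int
accuracy_ci
```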
Question
What is the prediction accuracy for the model? What is the confidence interval?
The prediction accuracy for the model is 78% and the confidence interval predicts the true value between 72.5% and 82.9%.
While the glm model is nice in that it ties in directly with the `lm()` models we've been using, it's not always the best model to use. An alternative is to start with decision trees, but instead of a single one, we can make hundreds of decision trees and then combine their predictions to get a reasonable prediction. One version of this process is called the Random Forest. We will use this and see if it performs better than the glm model.
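The sketch below shows the core idea of combining many trees on synthetic data: grow each tree on a bootstrap resample of the rows, then let the trees vote. This is a simplified version usually called bagging. A real random forest additionally considers only a random subset of the variables at each split, which this sketch does not do. All names and data here are made up for illustration.

```r
library(rpart)

# Synthetic data, made up for illustration
set.seed(7)
n <- 200
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
toy$y <- factor(ifelse(toy$x1 + toy$x2 + rnorm(n, sd = 0.5) > 0, "yes", "no"))

# Grow 25 trees, each on a bootstrap resample of the rows
n_trees <- 25
votes <- sapply(seq_len(n_trees), function(i) {
  boot_rows <- sample(n, replace = TRUE)
  tree <- rpart(y ~ x1 + x2, data = toy[boot_rows, ], method = "class")
  as.character(predict(tree, newdata = toy, type = "class"))
})

# Majority vote across the trees for each row
ensemble_pred <- apply(votes, 1, function(v) names(which.max(table(v))))

# Fraction of rows where the vote matches the known label
# (measured on the same data the trees were grown from)
mean(ensemble_pred == toy$y)
```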
The code below sets up and runs a random forest model for the same sets of parameters that we initially used for the glm model. Please note, this is a larger calculation compared with anything else we've done during the course. It will take a few minutes for this to run. It is important that you let the algorithm run to completion before you start doing additional work in the notebook, otherwise it might crash the kernel.
# Random Forest
set.seed(seed)
fit.rf <- train(survived~fare + age + sex + pclass + sibsp + parch + embarked,
data=data_train, method="rf", metric=metric, trControl=control)
For the random forest, if we want to evaluate which parameters are most important, we run the code below. The more a point slides to the right, the more important it is.
options(repr.plot.width=8, repr.plot.height=4)
varImpPlot(fit.rf$finalModel)
Question
Which variables are the most important for predicting whether or not a passenger survives?
It seems that sex, fare, and age are the most important predictors as to whether or not a passenger survives.
Just like with glm, we should make predictions on the test set to see how we've done, prediction-wise.
# make predictions
x_test <- data_test %>% select(-survived)
y_test <- data_test %>% select(survived)
rf.predictions <- predict(fit.rf, x_test)
confusionMatrix(table(pred.rf = rf.predictions, truth = y_test$survived))
Question
What is the prediction accuracy with the random forest? What is the confidence interval?
The prediction accuracy is 77.7% and the confidence interval predicts that the true value is between 72.14% and 82.6%.
Finally, we would like to know how the glm and random forest models compare with each other. The `resamples()` function from Caret does some consistency checking and computes summary statistics about each model so you can make a direct comparison. The code below handles this:
results <- resamples(list(glm=fit.glm, randomForest=fit.rf))
summary(results)
A convenient method for comparison is to create a "dot-plot" of the Accuracy metric. The closer the Accuracy measure is to 1, the better it is. The Kappa parameter is another quantity you can use to compare models. To interpret the outputs:
# Dot-plot comparison
options(repr.plot.width=8, repr.plot.height=4)
dotplot(results)
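For readers curious about the Kappa quantity, the sketch below computes Cohen's kappa by hand from a confusion matrix. Kappa compares the observed accuracy with the agreement you would expect by chance given the row and column totals. The counts are hypothetical, made up for illustration.

```r
# Hypothetical confusion matrix, made up for illustration
cm <- matrix(c(50, 10,
               15, 25),
             nrow = 2, byrow = TRUE,
             dimnames = list(pred = c("0", "1"), truth = c("0", "1")))

total    <- sum(cm)
observed <- sum(diag(cm)) / total                      # plain accuracy
expected <- sum(rowSums(cm) * colSums(cm)) / total^2   # agreement expected by chance

kappa_value <- (observed - expected) / (1 - expected)
kappa_value
```

A kappa of 0 means the model does no better than chance agreement, while a kappa of 1 means perfect agreement.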
Question
Is one model more accurate than the other, or are they statistically the same? Do both parameters let you draw the same conclusion?
It seems that they are statistically the same even though the mean accuracy seems to be greater for the randomForest model. Both parameters do allow me to draw the same conclusion.
This concludes this tour of doing machine-learning on a real dataset. There's plenty more analysis that you could do, including checking which passengers are failing the prediction test and trying to create a new column that could account for this issue, trying out additional models, and seeing if you can make stronger conclusions about which variables are the most important and why.