The tidymodels package loads a set of modular packages that we will use to build a machine learning workflow, from preparing the data to assessing the performance.
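Loading the meta-package attaches the core packages and reports any functions that are masked:

library(tidymodels)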
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ purrr::discard() masks scales::discard()
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag() masks stats::lag()
✖ recipes::step() masks stats::step()
• Learn how to get started at https://www.tidymodels.org/start/
Example: a small classification problem
Let’s cover the basic principles with an example medical dataset, asking whether we can predict which patients have had a stroke from lifestyle variables.
Note
Exploratory data analysis is a critical step in any data science project.
Here we use the skimr package to get an overview of the data frame, which quickly highlights that the bmi column has missing values.
#load the needed libraries
library(tidyverse)
library(janitor)
library(skimr)
library(MLDataR)

#explicitly call the built in data
#warning - this dataset is only chosen for illustration purposes
#see bmi, smoking status versus age
data("stroke_classification")

#janitor is useful to make the column names tidy
stroke_classification <- clean_names(stroke_classification)

stroke_classification <- stroke_classification[
  stroke_classification$gender %in% c("Male", "Female"), ]

#make the primary outcome a factor
stroke_classification$stroke <- as.factor(stroke_classification$stroke)

#Good idea to take a look at the data!
skim(stroke_classification)
── Data Summary ────────────────────────
Name                        stroke_classification
Number of rows              5109
Number of columns           11
Column type frequency:
  character                 1
  factor                    1
  numeric                   9
Group variables             None

── Variable type: character ──
skim_variable  n_missing  complete_rate  min  max  empty  n_unique  whitespace
gender                 0              1    4    6      0         2           0

── Variable type: factor ──
skim_variable  n_missing  complete_rate  ordered  n_unique  top_counts
stroke                 0              1  FALSE           2  0: 4860, 1: 249

── Variable type: numeric ──
skim_variable        n_missing  complete_rate     mean       sd     p0      p25      p50      p75     p100  hist
pat_id                       0           1.00  2555.39  1475.40   1.00  1278.00  2555.00  3833.00  5110.00  ▇▇▇▇▇
age                          0           1.00    43.23    22.61   0.08    25.00    45.00    61.00    82.00  ▅▆▇▇▆
hypertension                 0           1.00     0.10     0.30   0.00     0.00     0.00     0.00     1.00  ▇▁▁▁▁
heart_disease                0           1.00     0.05     0.23   0.00     0.00     0.00     0.00     1.00  ▇▁▁▁▁
work_related_stress          0           1.00     0.16     0.37   0.00     0.00     0.00     0.00     1.00  ▇▁▁▁▂
urban_residence              0           1.00     0.51     0.50   0.00     0.00     1.00     1.00     1.00  ▇▁▁▁▇
avg_glucose_level            0           1.00   106.14    45.29  55.12    77.24    91.88   114.09   271.74  ▇▃▁▁▁
bmi                        201           0.96    28.89     7.85  10.30    23.50    28.10    33.10    97.60  ▇▇▁▁▁
smokes                       0           1.00     0.63     0.48   0.00     0.00     1.00     1.00     1.00  ▅▁▁▁▇
Split into test and training
We want the model to generalise to new unseen data, so we split our dataset into a training and a test dataset. We’ll fit the model on the training data and then evaluate its performance on the unseen test data.
#Need to set the seed to be reproducible
set.seed(42)

#save 25% of the data for testing the performance of the model
data_split <- initial_split(stroke_classification, prop = 0.75)

#get the train and test datasets
stroke_train <- training(data_split)
stroke_test <- testing(data_split)

head(stroke_train)
library(recipes)

#set the base recipe - use stroke as the outcome and the rest of the data as predictors
stroke_rec <- recipe(stroke ~ ., data = stroke_train)

stroke_rec
This is a fundamental example of data leakage: there is a numeric patient ID column that is sufficient on its own to distinguish our outcome of interest. Often leakage is more subtle; see (Whalen et al. 2021).
Having spotted the problem, we can now specify that this column should be used only as an ID column, as sketched below. We could simply have removed the column, but it is useful for keeping track of individual observations through the modelling steps.
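A minimal sketch using update_role() from recipes to mark pat_id as an ID variable, so it is carried through the workflow but not used as a predictor (this rebuilds the stroke_rec object defined above):

stroke_rec <- recipe(stroke ~ ., data = stroke_train) %>%
  update_role(pat_id, new_role = "ID")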
We may have missing data in one or more of our predictors. This can be a big problem in fields such as proteomics, where the missingness may relate to the outcome of interest itself.
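Here the missing bmi values can be imputed inside the recipe, so the imputation is estimated from the training data only. A minimal sketch using a median imputation step (the choice of step_impute_median() is an illustration; other step_impute_* functions are available):

stroke_rec <- stroke_rec %>%
  step_impute_median(bmi)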
The choice of model depends on your application and type of data. Starting with a simple model is usually a good option to set a baseline for performance.
Here we’ll choose glm as we have binary outcome data with a handful of predictors. Later we’ll talk about what to do in genomics applications where we may have thousands of predictors.
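A minimal sketch of specifying a logistic regression with the glm engine and bundling it with the recipe in a workflow; the object names glm_spec, stroke_wf and stroke_fit are assumptions for illustration:

#model specification: logistic regression via glm
glm_spec <- logistic_reg() %>%
  set_engine("glm")

#combine the recipe and model specification in a workflow
stroke_wf <- workflow() %>%
  add_recipe(stroke_rec) %>%
  add_model(glm_spec)

#fit on the training data only
stroke_fit <- fit(stroke_wf, data = stroke_train)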
extract_fit_parsnip() allows us to get the underlying fitted model from a workflow, and tidy() from the broom package gives us a nicely formatted tibble of the coefficients.
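For example, assuming the fitted workflow is called stroke_fit:

#pull out the underlying glm fit and format the coefficients as a tibble
stroke_fit %>%
  extract_fit_parsnip() %>%
  tidy()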
Different models have different ways of interpreting the importance of the variables. Here we can look at the significance of the coefficients and see that age and avg_glucose_level are positively associated with a stroke in the training set.
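Before measuring performance we need predictions on the held-out test set. A minimal sketch, again assuming the fitted workflow is called stroke_fit; augment() appends the predicted class and class probabilities to the test data, giving the stroke_aug object used below:

#add .pred_class and class probability columns to the test data
stroke_aug <- augment(stroke_fit, new_data = stroke_test)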
The yardstick package has lots of functions for assessing how well a model is performing. To calculate the accuracy of the model on the test data we use the known outcome (stroke or not) as the truth and the predicted class from the model as the estimate.
library(yardstick)

#The accuracy is really high!
accuracy(stroke_aug, truth = stroke, estimate = .pred_class)
It is useful to understand where the model is making mistakes. What does the confusion matrix look like?
#ah!
conf_mat(stroke_aug, stroke, .pred_class)
          Truth
Prediction    0    1
         0 1217   61
         1    0    0
We’ve created a model which has predicted that every patient hasn’t had a stroke! This is likely because the number of observed strokes in the dataset is very low, so a model that simply predicts no one has had a stroke performs very well when judged by accuracy alone.
Note
Accuracy is a poor metric to use on datasets with class imbalance.
Look at the AUC
Instead of using the class predictions, we can use a metric, the ROC AUC, that looks at the ranking of the predicted probabilities of having a stroke.
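A sketch of computing the ROC AUC with yardstick, assuming stroke_aug holds the test-set predictions; by default yardstick treats the first factor level (here 0, no stroke) as the event, so we pass the .pred_0 probability column:

#area under the ROC curve
roc_auc(stroke_aug, truth = stroke, .pred_0)

#the full ROC curve can also be plotted
roc_curve(stroke_aug, truth = stroke, .pred_0) %>%
  autoplot()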
The probably package allows us to iterate through many thresholds of the class boundary and look at the trade-off between sensitivity and specificity.
library(probably)

threshold_data <- stroke_aug %>%
  threshold_perf(stroke, .pred_0, thresholds = seq(0.7, 1, by = 0.01))
The J index is one way of choosing a threshold and is defined as specificity + sensitivity - 1. We can plot the data to see the relationship.
max_j_index_threshold <- threshold_data %>%
  filter(.metric == "j_index") %>%
  filter(.estimate == max(.estimate)) %>%
  pull(.threshold)

ggplot(threshold_data, aes(x = .threshold, y = .estimate, color = .metric)) +
  geom_line() +
  geom_vline(xintercept = max_j_index_threshold, alpha = .6, color = "grey30") +
  theme_cowplot()
Take homes
Check your data, particularly if it is not your own
Split into training and testing appropriately
Watch out for data leakage and class imbalance
Choose the appropriate metric for performance, thinking about what the model will be used for.
References
Whalen, Sean, Jacob Schreiber, William S. Noble, and Katherine S. Pollard. 2021. “Navigating the Pitfalls of Applying Machine Learning in Genomics.” Nature Reviews Genetics 23 (3): 169–81. https://doi.org/10.1038/s41576-021-00434-9.