Introduction to machine learning with tidymodels

Predicting the age of bats

Author

Dr Jamie Soul

Let’s look at a genomics example!

Let’s try to predict the age of bats from their skin DNA methylation data. The data is taken from:

Loading the metapackage

library(tidymodels)
── Attaching packages ────────────────────────────────────── tidymodels 1.1.0 ──
✔ broom        1.0.4     ✔ recipes      1.0.6
✔ dials        1.2.0     ✔ rsample      1.1.1
✔ dplyr        1.1.2     ✔ tibble       3.2.1
✔ ggplot2      3.4.2     ✔ tidyr        1.3.0
✔ infer        1.0.4     ✔ tune         1.1.1
✔ modeldata    1.1.0     ✔ workflows    1.1.3
✔ parsnip      1.1.0     ✔ workflowsets 1.0.1
✔ purrr        1.0.1     ✔ yardstick    1.2.0
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ purrr::discard() masks scales::discard()
✖ dplyr::filter()  masks stats::filter()
✖ dplyr::lag()     masks stats::lag()
✖ recipes::step()  masks stats::step()
• Use tidymodels_prefer() to resolve common conflicts.

Prepare the data

The GEOquery package lets us download the normalised methylation beta values directly from NCBI GEO.

library(GEOquery)
library(tidyverse)
library(skimr)

#retrieve the dataset - getGEO() always returns a list with one element per platform, even if there is only one platform
geo <- getGEO("GSE164127")[[1]]

Genomics datasets for machine learning tend to have many variables/features (e.g. CpGs, genes, proteins) and relatively few observations.

#rows are CpGs and columns are samples: far more CpGs than samples
dim(exprs(geo))
[1] 37554   908

Beta values represent the proportion of the measured signal at a CpG site that comes from the methylated probe. Roughly, this is the fraction of copies of that site that are methylated. Beta values range from 0 (completely unmethylated) to 1 (completely methylated). Note that the data are pre-normalised for us. Best practice would be to pre-process the training and test data completely independently, i.e. not normalise them together at all.
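To build intuition, the sketch below computes beta values from methylated (M) and unmethylated (U) probe intensities. It uses the commonly quoted definition beta = M / (M + U + offset) with an offset of 100. The intensities are invented, and the formula and offset are assumptions about how array intensities are usually summarised. The GEO series used here already supplies normalised beta values, so this step is not part of the workflow.

```r
# Hypothetical methylated (M) and unmethylated (U) intensities for three CpGs
methylated   <- c(cpgA = 9000, cpgB = 500,  cpgC = 4000)
unmethylated <- c(cpgA = 400,  cpgB = 8500, cpgC = 4100)

# Beta value: fraction of the total signal that comes from the methylated probe.
# The offset (assumed here to be 100) keeps the ratio stable when both
# intensities are low.
beta_value <- function(m, u, offset = 100) {
  m / (m + u + offset)
}

betas <- beta_value(methylated, unmethylated)
round(betas, 3)
```

A mostly methylated site (cpgA) ends up close to 1, a mostly unmethylated one (cpgB) close to 0, and a half-methylated one (cpgC) near 0.5.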

We can extract the matrix of beta values, which has one row per CpG and one column per sample. Here are the first six CpGs for the first six samples.

head(exprs(geo[,1:6]))
           GSM4997095 GSM4997096 GSM4997097 GSM4997098 GSM4997099 GSM4997100
cg00000165 0.09199332 0.07726343 0.10521394 0.12867552 0.07843872 0.10718643
cg00001209 0.94353508 0.95044196 0.94180974 0.94778998 0.95106877 0.95109927
cg00001364 0.91619625 0.92543450 0.91720269 0.92829502 0.91140442 0.92322497
cg00001582 0.04977606 0.05756138 0.05402703 0.06135690 0.05762858 0.05055837
cg00002920 0.88561694 0.93060932 0.93808918 0.94929825 0.93766111 0.93322898
cg00003994 0.05124363 0.05017108 0.04621871 0.03492938 0.03834341 0.04810011

We can also extract the corresponding sample metadata, which includes the age we are trying to predict.

skim(pData(geo))
Data summary
Name                     pData(geo)
Number of rows           908
Number of columns        47
_______________________
Column type frequency:
  character              47
________________________
Group variables          None

Variable type: character

skim_variable                n_missing complete_rate min max empty n_unique whitespace
title                                0          1.00  37  62     0      908          0
geo_accession                        0          1.00  10  10     0      908          0
status                               0          1.00  21  21     0        2          0
submission_date                      0          1.00  11  11     0        2          0
last_update_date                     0          1.00  11  11     0        2          0
type                                 0          1.00   7   7     0        1          0
channel_count                        0          1.00   1   1     0        1          0
source_name_ch1                      0          1.00   4   6     0        4          0
organism_ch1                         0          1.00  13  25     0       28          0
characteristics_ch1                  0          1.00  29  29     0      861          0
characteristics_ch1.1                0          1.00   6  14     0       25          0
characteristics_ch1.2                0          1.00  27  29     0        5          0
characteristics_ch1.3                0          1.00   9  18     0      175          0
characteristics_ch1.4                0          1.00   0  28    48        6          0
characteristics_ch1.5                0          1.00   0  11    48        3          0
treatment_protocol_ch1               0          1.00  26  26     0        1          0
growth_protocol_ch1                  0          1.00  26  26     0        1          0
molecule_ch1                         0          1.00  11  11     0        1          0
extract_protocol_ch1                 0          1.00  25  25     0        1          0
label_ch1                            0          1.00  11  11     0        1          0
label_protocol_ch1                   0          1.00  26  26     0        1          0
taxid_ch1                            0          1.00   4   6     0       28          0
hyb_protocol                         0          1.00 144 144     0        1          0
scan_protocol                        0          1.00  93  93     0        1          0
description                          0          1.00  39  81     0       42          0
data_processing                      0          1.00  71  71     0        1          0
platform_id                          0          1.00   8   8     0        1          0
contact_name                         0          1.00  14  14     0        1          0
contact_email                        0          1.00  22  22     0        1          0
contact_laboratory                   0          1.00   7   7     0        1          0
contact_department                   0          1.00  14  14     0        1          0
contact_institute                    0          1.00  37  37     0        1          0
contact_address                      0          1.00  44  44     0        1          0
contact_city                         0          1.00  11  11     0        1          0
contact_state                        0          1.00   2   2     0        1          0
contact_zip/postal_code              0          1.00  10  10     0        1          0
contact_country                      0          1.00   3   3     0        1          0
supplementary_file                   0          1.00 109 109     0      908          0
supplementary_file.1                 0          1.00 109 109     0      908          0
data_row_count                       0          1.00   5   5     0        1          0
age (years):ch1                     48          0.95   1   5     0      173          0
age:ch1                            860          0.05   1   5     0       21          0
basename:ch1                        48          0.95  19  19     0      860          0
canbeusedforagingstudies:ch1         0          1.00   2   3     0        2          0
confidenceinageestimate:ch1          0          1.00   2   3     0        5          0
Sex:ch1                              0          1.00   4   6     0        2          0
tissue:ch1                          48          0.95   4   6     0        4          0

Let’s keep only the skin samples that are flagged as usable for ageing studies and that have a known age, so that we can use them for modelling.

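#convert age to numeric; non-numeric entries become NA, hence the warning
geo$`age (years):ch1` <- as.numeric(geo$`age (years):ch1`)
Warning: NAs introduced by coercion
#keep skin samples usable for ageing studies that have a known age
geo <- geo[ , geo$`canbeusedforagingstudies:ch1` == "yes" & geo$`tissue:ch1` == "Skin" & !is.na(geo$`age (years):ch1`)]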

To make this faster to run, and to show that we can do ML on smaller datasets, let’s train on just one of the bat species.

Let’s train a model using data from the greater spear-nosed bat (Phyllostomus hastatus).

To test how generalisable the model is, we then apply it to another species and predict the age of the big brown bat (Eptesicus fuscus).

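#helper function to extract a data frame (samples x CpGs) for a particular bat species
processData <- function(species, geo){
  
  #keep only the samples from this species
  geo_filtered <- geo[, geo$organism_ch1 == species]
  
  #transpose so that rows are samples and columns are CpGs
  methyl_filtered <- as.data.frame(t(exprs(geo_filtered)))
  
  #the outcome is sqrt(age + 1), so the errors reported later are on this scale, not in years
  methyl_filtered$age <- sqrt(as.numeric(geo_filtered$`age (years):ch1`) + 1)
  
  return(methyl_filtered)
}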

#Let's keep only 1,000 random CpGs to speed up training for this workshop
set.seed(42)
keep <- sample.int(nrow(geo),1000)
geo <- geo[keep,]

#get the data for model building and testing
methyl_spearbat <- processData("Phyllostomus hastatus",geo)
methyl_bigbrownbat <- processData("Eptesicus fuscus",geo)

Create the training split

We keep 20% of the data back for a final test. The remaining 80% is used to train the model, including tuning its parameters.

#Split the data into train and test
methyl_spearbat_split <- initial_split(methyl_spearbat,prop = 0.8,strata=age)
Warning: The number of observations in each quantile is below the recommended threshold of 20.
• Stratification will use 3 breaks instead.
methyl_spearbat_train <- training(methyl_spearbat_split)
methyl_spearbat_test <- testing(methyl_spearbat_split)
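The sketch below illustrates what the stratified 80/20 split guarantees. It runs on a simulated stand-in for methyl_spearbat with 100 rows, 20 random columns and a made-up sqrt(age + 1) outcome, all hypothetical. Every row lands in exactly one of the two sets, and roughly 80% of rows go to training. With 100 rows each of the four default age strata holds 25 samples, which avoids the small-sample warning seen above.

```r
library(rsample)

set.seed(1)
# Simulated stand-in for methyl_spearbat (all values are made up)
n <- 100
toy <- as.data.frame(matrix(runif(n * 20), nrow = n))
toy$age <- sqrt(runif(n, 0, 20) + 1)
toy$sample_id <- seq_len(n)

# 80/20 split, stratified on binned age so both sets span the age range
toy_split <- initial_split(toy, prop = 0.8, strata = age)
toy_train <- training(toy_split)
toy_test  <- testing(toy_split)

c(train = nrow(toy_train), test = nrow(toy_test))
```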

Create the recipe

As before, we define age as the outcome, and centre and scale all the remaining predictors (the CpG beta values).

#define the recipe
methyl_recipe <- 
  recipe(methyl_spearbat_train) %>%
  update_role(everything()) %>%
  update_role(age,new_role = "outcome")  %>%
  step_center(all_predictors()) %>%
  step_scale(all_predictors())
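The means and standard deviations used by step_center() and step_scale() are estimated once, from the training data, when the recipe is prepped. They are then reused unchanged on any new data, which keeps test-set information out of the preprocessing. The minimal sketch below demonstrates this on simulated data. The two predictor columns and their values are hypothetical, and the recipe is written with a formula for brevity rather than with update_role().

```r
library(recipes)

set.seed(2)
# Hypothetical training and test sets: two "CpG" predictors and an outcome
train <- data.frame(cg1 = rnorm(50, 0.8, 0.05), cg2 = rnorm(50, 0.2, 0.05),
                    age = runif(50, 1, 5))
test  <- data.frame(cg1 = rnorm(10, 0.8, 0.05), cg2 = rnorm(10, 0.2, 0.05),
                    age = runif(10, 1, 5))

# Same preprocessing steps as methyl_recipe
rec <- recipe(age ~ ., data = train) %>%
  step_center(all_predictors()) %>%
  step_scale(all_predictors())

# prep() estimates the means and SDs from the training data only
rec_prepped <- prep(rec, training = train)

# bake() applies those training-set statistics to the new data
baked_test <- bake(rec_prepped, new_data = test)
baked_test
```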

Select the model

Let’s use a glmnet model, which penalises the inclusion of variables to prevent overfitting and keep the model sparse. This is useful if we want to identify the minimal panel of biological features that is sufficient to get a good prediction, e.g. for a biomarker panel.
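The sketch below shows the sparsity in action by calling glmnet directly on simulated data. Only 5 of 200 hypothetical features actually affect the outcome. As the penalty grows, more coefficients are set exactly to zero, which is what makes the lasso useful for picking out a small feature panel.

```r
library(glmnet)

set.seed(3)
# Simulated data: 100 samples, 200 features, only the first 5 are informative
n <- 100
p <- 200
x <- matrix(rnorm(n * p), nrow = n)
y <- drop(x[, 1:5] %*% c(2, -1.5, 1, -1, 0.5)) + rnorm(n)

# alpha = 1 is the lasso (the engine-level counterpart of mixture = 1)
fit <- glmnet(x, y, alpha = 1)

# Count non-zero coefficients (excluding the intercept) at a given penalty
n_nonzero <- function(lambda) {
  sum(as.matrix(coef(fit, s = lambda))[-1, 1] != 0)
}

sapply(c(0.05, 0.2, 1), n_nonzero)
```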

Setting mixture = 1 gives a lasso model. For this model we need to tune the penalty (lambda), which controls how strongly the coefficients are shrunk towards zero (regularisation).
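For reference, the glmnet documentation gives the penalised least-squares objective it minimises for a continuous outcome, reproduced below. Here α corresponds to mixture and λ to penalty, so mixture = 1 leaves only the L1 (lasso) term. A larger λ pushes more coefficients exactly to zero.

```latex
\min_{\beta_0,\,\beta}\;
\frac{1}{2n}\sum_{i=1}^{n}\left(y_i - \beta_0 - x_i^{\top}\beta\right)^2
+ \lambda\left[\frac{1-\alpha}{2}\,\lVert\beta\rVert_2^2 + \alpha\,\lVert\beta\rVert_1\right]
```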

tune() marks the penalty parameter as one that needs to be optimised.

#use glmnet model
glmn_fit <- 
  linear_reg( mixture = 1, penalty = tune()) %>% 
  set_engine("glmnet") 
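To check which arguments have been marked with tune(), you can ask for the model's parameter set. The small sketch below assumes a reasonably recent tidymodels installation; the spec mirrors glmn_fit.

```r
library(tidymodels)

# Same specification as glmn_fit: lasso with the penalty left to be tuned
spec <- linear_reg(mixture = 1, penalty = tune()) %>%
  set_engine("glmnet")

# Lists every argument flagged with tune(); here that is only the penalty
params <- extract_parameter_set_dials(spec)
params
```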

Let’s cross-validate within the training dataset so that we can tune the penalty.

#5-fold cross validation
folds <- vfold_cv(methyl_spearbat_train, v = 5, strata = age, breaks= 2)
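In v-fold cross-validation each sample is held out (assessed) exactly once, and the model is refitted v times on the remaining rows. The sketch below checks this with rsample's analysis() and assessment() accessors on a made-up data frame. It has 57 rows to match the training-set size implied by the fold splits in the tuning output, but the values are simulated.

```r
library(rsample)

set.seed(4)
# Made-up stand-in for the training set: 57 rows with a sqrt(age + 1) outcome
toy_train <- data.frame(id = 1:57, age = sqrt(runif(57, 0, 20) + 1))

toy_folds <- vfold_cv(toy_train, v = 5, strata = age, breaks = 2)

# Row ids used for fitting (analysis) and for scoring (assessment) in each fold
held_out <- lapply(toy_folds$splits, function(s) assessment(s)$id)
fit_on   <- lapply(toy_folds$splits, function(s) analysis(s)$id)

lengths(held_out)
```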

Create the workflow

We build the workflow from the model and the recipe.

#build the workflow
methyl_wf <- workflow() %>%
    add_model(glmn_fit) %>%
    add_recipe(methyl_recipe)

Define the tuning search space

Here we’ll check the performance of the model as we vary the penalty over 50 values, spaced evenly on a log scale between 0.001 and 1.

#50 penalty values from 0.001 to 1, evenly spaced on a log10 scale
lasso_grid <- tibble(penalty = 10^seq(-3, 0, length.out = 50))
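The same log-spaced grid can also be built with dials, whose penalty() parameter is defined on the log10 scale, so range = c(-3, 0) means 0.001 to 1. The sketch below assumes that grid_regular() spaces values evenly on that transformed scale and compares the two grids.

```r
library(dials)
library(tibble)

# Hand-built grid, as in lasso_grid
manual_grid <- tibble(penalty = 10^seq(-3, 0, length.out = 50))

# dials equivalent: 50 levels between 10^-3 and 10^0
dials_grid <- grid_regular(penalty(range = c(-3, 0)), levels = 50)

range(dials_grid$penalty)
```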

Run the tuning workflow

We can use multiple CPU cores to speed up the tuning.

#Using 6 cores
library(doParallel)
cl <- makeCluster(6)
registerDoParallel(cl)

#tune the model
methyl_res <- methyl_wf %>% 
    tune_grid(resamples = folds,
              grid = lasso_grid,
              control = control_grid(save_pred = TRUE),
              metrics = metric_set(rmse))
methyl_res
# Tuning results
# 5-fold cross-validation using stratification 
# A tibble: 5 × 5
  splits          id    .metrics          .notes           .predictions      
  <list>          <chr> <list>            <list>           <list>            
1 <split [45/12]> Fold1 <tibble [50 × 5]> <tibble [0 × 3]> <tibble [600 × 5]>
2 <split [45/12]> Fold2 <tibble [50 × 5]> <tibble [0 × 3]> <tibble [600 × 5]>
3 <split [46/11]> Fold3 <tibble [50 × 5]> <tibble [0 × 3]> <tibble [550 × 5]>
4 <split [46/11]> Fold4 <tibble [50 × 5]> <tibble [0 × 3]> <tibble [550 × 5]>
5 <split [46/11]> Fold5 <tibble [50 × 5]> <tibble [0 × 3]> <tibble [550 × 5]>
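The cluster created for tuning keeps its worker processes running until it is explicitly stopped. The minimal sketch below shows the full start/use/stop pattern with doParallel and foreach. The choice of two workers and the toy computation are arbitrary.

```r
library(doParallel)

# Start a small cluster and register it as the foreach backend
cl <- makeCluster(2)
registerDoParallel(cl)

# foreach %dopar% loops now run on the registered workers
squares <- foreach(i = 1:4, .combine = c) %dopar% i^2

# Release the workers and fall back to sequential execution
stopCluster(cl)
registerDoSEQ()

squares
```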

How does the regularisation affect the performance?

We can find the penalty value that minimises the error in our age prediction, measured as the root mean squared error (RMSE).

autoplot(methyl_res)

Finalise the model

Get the best model parameters

best_mod <- methyl_res %>% select_best(metric = "rmse")
best_mod
# A tibble: 1 × 2
  penalty .config              
    <dbl> <chr>                
1   0.139 Preprocessor1_Model36

Get the final model

#fit on the training data using the best parameters
final_fitted <- finalize_workflow(methyl_wf, best_mod) %>%
    fit(data = methyl_spearbat_train)

Test the performance

Look at the performance on the held-out test set, then see how well the clock works on a different species. The outcome is sqrt(age + 1), so the RMSE values are on that scale rather than in years.

#get the test performance
methyl_spearbat_aug <- augment(final_fitted, methyl_spearbat_test)
rmse(methyl_spearbat_aug,truth = age, estimate = .pred)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard       0.535
plot(methyl_spearbat_aug$age, methyl_spearbat_aug$.pred, xlab = "Observed sqrt(age + 1)", ylab = "Predicted sqrt(age + 1)")

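#try on the different species (stored separately so the input data is not overwritten)
methyl_bigbrownbat_aug <- augment(final_fitted, methyl_bigbrownbat)
plot(methyl_bigbrownbat_aug$age, methyl_bigbrownbat_aug$.pred, xlab = "Observed sqrt(age + 1)", ylab = "Predicted sqrt(age + 1)")

rmse(methyl_bigbrownbat_aug, truth = age, estimate = .pred)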
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard       0.823
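Because processData() models sqrt(age + 1) rather than age, the RMSE values above are not in years. The sketch below converts predictions back to years before computing the error. The helper name to_years() and the example values are hypothetical.

```r
library(yardstick)

# Inverse of the transform used in processData(): sqrt(age + 1) -> age in years
to_years <- function(x) x^2 - 1

# Hypothetical true ages (years) and predictions on the transformed scale
true_age  <- c(1, 3, 8, 15)
pred_sqrt <- c(1.5, 2.1, 3.0, 3.8)

pred_years <- to_years(pred_sqrt)

# RMSE on the transformed scale, then in years
rmse_vec(truth = sqrt(true_age + 1), estimate = pred_sqrt)
rmse_vec(truth = true_age, estimate = pred_years)
```

The transform compresses older ages, so the same error on the sqrt scale corresponds to a larger error in years for older bats.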

What CpGs are important?

We can use the model coefficients to determine which CpGs are driving the prediction. We can also see how many variables are retained in the model at our tuned penalty value.

library(vip)
library(cowplot)

#get the importance (model coefficients) from glmnet at the selected penalty
importance <- final_fitted %>%
  extract_fit_parsnip() %>%
  vi(lambda = best_mod$penalty) %>%
  mutate(
    Importance = abs(Importance),
    Variable = fct_reorder(Variable, Importance)
  )

#how many CpGs are retained (non-zero coefficients)
table(importance$Importance>0)

FALSE  TRUE 
  987    13 
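Instead of using vip, you can pull the coefficients straight out of a fitted parsnip glmnet model with tidy(), which reports them at the model's penalty. The sketch below does this on simulated data; the column names and the penalty of 0.1 are hypothetical.

```r
library(tidymodels)

set.seed(5)
# Simulated data: 80 samples, 30 predictors, only x1 and x2 are informative
dat <- as.data.frame(matrix(rnorm(80 * 30), nrow = 80))
names(dat) <- paste0("x", 1:30)
dat$age <- 2 * dat$x1 - dat$x2 + rnorm(80, sd = 0.5)

lasso_fit <- linear_reg(mixture = 1, penalty = 0.1) %>%
  set_engine("glmnet") %>%
  fit(age ~ ., data = dat)

# Coefficients at penalty = 0.1; features with a zero estimate were dropped
coefs <- tidy(lasso_fit)
kept <- coefs %>% filter(term != "(Intercept)", estimate != 0)
kept
```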

Plot the importance of the top CpGs and their direction

#plot the top 10 CpGs
importance %>% slice_max(Importance,n=10) %>%
  ggplot(aes(x = Importance, y = Variable, fill = Sign)) +
  geom_col() +
  scale_x_continuous(expand = c(0, 0)) +
  labs(y = NULL) + theme_cowplot()

Plot the top predictive CpG beta values versus age

This highlights how machine learning can be used to identify a small number of discriminative features.

#helper function to plot a CpG's beta values against (transformed) age
plotCpG <- function(cpg, dat){
  
  ggplot(dat, aes(x = !!sym(cpg), y = age)) +
    geom_point() +
    labs(y = "sqrt(age + 1)") +
    theme_cowplot()
  
}

#plot the most important CpGs
importance %>% 
  slice_max(Importance,n=4) %>%
  pull(Variable) %>% 
  as.character() %>% 
  map(plotCpG,methyl_spearbat) %>%
  plot_grid(plotlist = .)