Exploring Models with lime

Recently at work I’ve been asked to help some clinicians understand why my risk model classifies specific patients as high risk. Just prior to this work I stumbled across lime, a tool developed by some data scientists at the University of Washington. LIME stands for “Local Interpretable Model-Agnostic Explanations”. The idea is that I can answer those questions from clinicians for a specific patient by locally fitting a linear (i.e. “interpretable”) model in the parameter space just around that patient’s data point. I decided to pursue lime as a solution, and for the last few months I’ve been focusing on implementing this explainer for my risk model. Happily, I also discovered an R package that implements this solution, which originated in Python.

Sample Data

The first step for this post was to find some public data for illustration. I remembered an example used in An Introduction to Statistical Learning by James, Witten, Hastie and Tibshirani, so I will use the Heart.csv data, which can be downloaded using the link in the code below:

library(readr)
library(ranger)
library(tidyverse)
library(lime)

dat <- read_csv("http://www-bcf.usc.edu/~gareth/ISL/Heart.csv")
dat$X1 <- NULL  # drop the row-index column that read_csv picks up from the file

Now let’s take a quick look at the data:

Hmisc::describe(dat)
## dat 
## 
## 14 Variables 303 Observations
## ---------------------------------------------------------------------------
## Age 
## n missing distinct Info Mean Gmd .05 .10 
## 303 0 41 0.999 54.44 10.3 40 42 
## .25 .50 .75 .90 .95 
## 48 56 61 66 68 
## 
## lowest : 29 34 35 37 38, highest: 70 71 74 76 77
## ---------------------------------------------------------------------------
## Sex 
## n missing distinct Info Sum Mean Gmd 
## 303 0 2 0.653 206 0.6799 0.4367 
## 
## ---------------------------------------------------------------------------
## ChestPain 
## n missing distinct 
## 303 0 4 
## 
## Value asymptomatic nonanginal nontypical typical
## Frequency 144 86 50 23
## Proportion 0.475 0.284 0.165 0.076
## ---------------------------------------------------------------------------
## RestBP 
## n missing distinct Info Mean Gmd .05 .10 
## 303 0 50 0.995 131.7 19.41 108 110 
## .25 .50 .75 .90 .95 
## 120 130 140 152 160 
## 
## lowest : 94 100 101 102 104, highest: 174 178 180 192 200
## ---------------------------------------------------------------------------
## Chol 
## n missing distinct Info Mean Gmd .05 .10 
## 303 0 152 1 246.7 55.91 175.1 188.8 
## .25 .50 .75 .90 .95 
## 211.0 241.0 275.0 308.8 326.9 
## 
## lowest : 126 131 141 149 157, highest: 394 407 409 417 564
## ---------------------------------------------------------------------------
## Fbs 
## n missing distinct Info Sum Mean Gmd 
## 303 0 2 0.379 45 0.1485 0.2538 
## 
## ---------------------------------------------------------------------------
## RestECG 
## n missing distinct Info Mean Gmd 
## 303 0 3 0.76 0.9901 1.003 
## 
## Value 0 1 2
## Frequency 151 4 148
## Proportion 0.498 0.013 0.488
## ---------------------------------------------------------------------------
## MaxHR 
## n missing distinct Info Mean Gmd .05 .10 
## 303 0 91 1 149.6 25.73 108.1 116.0 
## .25 .50 .75 .90 .95 
## 133.5 153.0 166.0 176.6 181.9 
## 
## lowest : 71 88 90 95 96, highest: 190 192 194 195 202
## ---------------------------------------------------------------------------
## ExAng 
## n missing distinct Info Sum Mean Gmd 
## 303 0 2 0.66 99 0.3267 0.4414 
## 
## ---------------------------------------------------------------------------
## Oldpeak 
## n missing distinct Info Mean Gmd .05 .10 
## 303 0 40 0.964 1.04 1.225 0.0 0.0 
## .25 .50 .75 .90 .95 
## 0.0 0.8 1.6 2.8 3.4 
## 
## lowest : 0.0 0.1 0.2 0.3 0.4, highest: 4.0 4.2 4.4 5.6 6.2
## ---------------------------------------------------------------------------
## Slope 
## n missing distinct Info Mean Gmd 
## 303 0 3 0.798 1.601 0.6291 
## 
## Value 1 2 3
## Frequency 142 140 21
## Proportion 0.469 0.462 0.069
## ---------------------------------------------------------------------------
## Ca 
## n missing distinct Info Mean Gmd 
## 299 4 4 0.783 0.6722 0.9249 
## 
## Value 0 1 2 3
## Frequency 176 65 38 20
## Proportion 0.589 0.217 0.127 0.067
## ---------------------------------------------------------------------------
## Thal 
## n missing distinct 
## 301 2 3 
## 
## Value fixed normal reversable
## Frequency 18 166 117
## Proportion 0.060 0.551 0.389
## ---------------------------------------------------------------------------
## AHD 
## n missing distinct 
## 303 0 2 
## 
## Value No Yes
## Frequency 164 139
## Proportion 0.541 0.459
## ---------------------------------------------------------------------------

Our target variable in this data is AHD. This flag identifies whether or not a patient has coronary artery disease. If we can predict this accurately, clinicians could probably treat these patients more effectively and hopefully help them avoid symptoms of AHD such as chest pain or, worse, heart attacks.
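A quick frequency table confirms the roughly balanced split between the two classes seen in the describe() output above:

# Outcome distribution (matches the AHD counts above)
table(dat$AHD)
prop.table(table(dat$AHD))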

Data Wrangling

For a predictive model I’ve opted for a random forest using the ranger implementation, which parallelizes the random forest algorithm in R. But first, some data cleaning is necessary. After replacing missing values, I’ll split the data into training and test data frames.

# Replace missing values
dat$Ca[is.na(dat$Ca)] <- -1
dat$Thal[is.na(dat$Thal)] <- "missing"

## 75% of the sample size
smp_size <- floor(0.75 * nrow(dat))

## set the seed to make your partition reproducible
set.seed(123)
train_ind <- sample(seq_len(nrow(dat)), size = smp_size)

train <- dat[train_ind, ]
test <- dat[-train_ind, ]

# Fit a probability forest and track permutation variable importance
mod <- ranger(AHD ~ ., data = train, probability = TRUE, importance = "permutation")

# Out-of-bag prediction error (the Brier score for a probability forest)
mod$prediction.error
## [1] 0.1326235

Our quick and dirty check of the OOB prediction error tells us that the model appears to do a reasonable job of predicting AHD. Now the trick is to describe to our physicians and nurses why we believe someone is at high risk for AHD. Before I learned of lime, I would probably have done something like the code below, starting by looking at which variables were most important across the trees.

# Plot permutation variable importance from a ranger model
plot_importance <- function(mod){
  tmp <- mod$variable.importance
  dat <- data.frame(variable = names(tmp), importance = tmp)
  ggplot(dat, aes(x = reorder(variable, importance), y = importance)) +
    geom_bar(stat = "identity", position = "dodge") +
    coord_flip() +
    ylab("Variable Importance") +
    xlab("")
}

# Plot the variable importance
plot_importance(mod)

After this, I probably would have looked at some partial dependence plots to get an idea of how the prediction changes over the range of each important variable. However, the weakness of this approach is that I need to hold all other variables constant. And if I truly believe there are interactions between my variables, the partial dependence plot could change dramatically when those other variables change.
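For illustration, a rough, hand-rolled partial dependence curve can be built by fixing one variable at each value of a grid, predicting for every training row, and averaging the predicted probabilities. A minimal sketch for MaxHR (the choice of MaxHR and the 25-point grid are arbitrary):

# Manual partial dependence for MaxHR: fix MaxHR at each grid value,
# predict for every training row, and average the predicted P(AHD = "Yes")
grid <- seq(min(train$MaxHR), max(train$MaxHR), length.out = 25)
pd <- sapply(grid, function(v) {
  tmp <- train
  tmp$MaxHR <- v
  mean(predict(mod, data = tmp)$predictions[, "Yes"])
})
plot(grid, pd, type = "l", xlab = "MaxHR", ylab = "Mean predicted P(AHD = Yes)")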

Explain the model with LIME

Enter lime. As discussed above, the entire purpose of lime is to provide a local interpretable model to help us understand how our prediction would change if we tweaked the input variables slightly across many permutations. The first step to using lime in this specific case is to add a couple of methods so that the lime package knows how to deal with the output of the ranger package. Once I have these, I can use the combination of the lime() and explain() functions to get what I need. As in all multivariate linear models, we still have an issue: correlated explanatory variables. And depending on the number of variables in our original model, we may need to pare the explanation down to only the most “influential” or “important” variables. By default lime will either use forward selection or pick the variables with the largest coefficients from a ridge regression (L2 penalization). As seen below, you can also select variables for the explanation using the lasso (L1 penalization) or use the most important variables from an xgboost model via the "tree" method.

# Train LIME Explainer
expln <- lime(train, model = mod)

preds <- predict(mod, train, type = "response")

# Add ranger to LIME
predict_model.ranger <- function(x, newdata, type, ...) {
  res <- predict(x, data = newdata, ...)
  switch(
    type,
    raw = data.frame(Response = ifelse(res$predictions[, "Yes"] >= 0.5, "Yes", "No"),
                     stringsAsFactors = FALSE),
    # return a probability column for each class, as lime expects
    prob = as.data.frame(res$predictions, check.names = FALSE)
  )
}

model_type.ranger <- function(x, ...) 'classification'

# Build explanations with each of the feature-selection methods
reasons.forward <- explain(x = test[, names(test) != "AHD"], explainer = expln, n_labels = 1, n_features = 4)
reasons.ridge   <- explain(x = test[, names(test) != "AHD"], explainer = expln, n_labels = 1, n_features = 4, feature_select = "highest_weights")
reasons.lasso   <- explain(x = test[, names(test) != "AHD"], explainer = expln, n_labels = 1, n_features = 4, feature_select = "lasso_path")
reasons.tree    <- explain(x = test[, names(test) != "AHD"], explainer = expln, n_labels = 1, n_features = 4, feature_select = "tree")

Note: Using the current version of lime you may have issues with the feature_select = "lasso_path" option. To get the above code to run, you can install my tweaked version of lime here.
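Once the explanations are built, lime’s plot_features() is a handy way to show a clinician which variables pushed an individual patient’s prediction up or down. A minimal sketch using the ridge-based explanations from above:

# With n_labels = 1 and n_features = 4, each case occupies 4 rows of the
# explanation data frame, so the first 8 rows cover the first two test patients
plot_features(reasons.ridge[1:8, ])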
