Interpretable AI: LIME
Introduction
AI has been transforming industries and is increasingly becoming part of modern applications, from recommendation systems in healthcare and retail to heavy industrial production. Despite advances in machine learning, a large share of models remain black boxes: the factors and reasoning behind a prediction are hidden inside a neural network, an ensemble, or a tangle of mathematical equations. As AI moves further into our personal lives, there is a growing need for interpretability: how do these models make decisions, and should we trust them?
Being able to see how a model makes its predictions sheds light on the matter and enables better decision making, especially in high-stakes areas such as finance or healthcare. If the model uses input features that make sense, or bases its prediction on the relevant parts of an image, we can place more trust in it. If the model just happens to latch onto spurious inputs and still scores well, that should be a warning calling for further investigation before we rely on it. Inspecting predictions this way can expose bugs in the reasoning we built into the model, and it can uncover hidden biases so that we can fix them and ensure fairness.
There are cases where a model finds highly correlated features that predict the training set correctly but do not really answer the question. For example, suppose that in the training set most wolves are photographed against snow while most huskies are photographed indoors. A model trained on that dataset may well use snow as the cue to decide between wolf and husky. Such a shortcut may or may not hold up in reality, because the model has not actually learned to distinguish the two species: one wild, one domesticated.
Some models are more readily interpretable than others. Linear regression, for example, provides a simple, intuitive equation with a coefficient for each input feature, together with its statistics, so users can see how much each variable contributes to the final prediction. For more complex models, such as neural networks, interpreting millions of weights remains a challenge. The difficulty comes from several sources: model capability, computing resources, and our ability to explain. First, simpler, easier-to-explain models often do not perform as well as complex models in practice, because the problems we face are intricate and high-dimensional and demand models capable of handling that complexity. Second, producing interpretations requires computational resources too, sometimes as much as the original model itself: DeepVis, a tool built to visualise what a neural network "sees" at each of its layers, took as much effort and innovation as building a neural network in the first place. Third, our ability to make use of the coefficients and weights a model returns after crunching the numbers is itself a limiting factor: sometimes we can approximate the truth, sometimes not.
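As a quick illustration of the first point, the coefficients of a linear model can simply be read off. The sketch below uses scikit-learn's diabetes dataset purely as an example; the dataset choice and the plain least-squares fit are illustrative assumptions, not something tied to the rest of this post.

```python
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression

# Fit an ordinary linear regression and inspect one coefficient per feature
X, y = load_diabetes(return_X_y=True, as_frame=True)
lr = LinearRegression().fit(X, y)
for name, coef in zip(X.columns, lr.coef_):
    print(f"{name:>5}: {coef:8.1f}")  # how much each feature moves the prediction
```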
Despite these difficulties, several methods have been developed to give us a peek into why a specific decision was made, for example why an image was assigned to a class or why a loan application was rejected. One such method is LIME.
LIME
LIME, short for Local Interpretable Model-agnostic Explanations, is an effort to make AI more interpretable. It aims to explain the prediction of any machine learning model in a way humans can easily grasp. At its heart, LIME approximates the complex model with a simpler, interpretable model just around the prediction point, the point we ask it to explain. Doing so tells us why the model arrived at that prediction: which features contributed and by how much. The simpler model LIME fits is usually a linear one, trained with a loss function that keeps it faithful to the complex model in that neighbourhood. By explaining individual predictions, LIME gives pointers on where to investigate further when the model seems off; it can also help validate models and deepen our understanding of their inner workings.
Before unpacking each word in the method's name, let's consider the characteristics we would want from an explainer of a more complex model. First and foremost, the explainer must be interpretable, and this is not just a technical matter: it has to account for the needs of the users. A Bayesian network might be easy for a machine learning expert to read, but a typical user would probably prefer a simple linear regression with a few features. This even implies that the explaining model can use different input features from the original model it is trying to explain. Second, there is a minimum requirement of local fidelity: if the explainer is explaining one point or prediction, it should replicate the behaviour of the original complex model in the neighbourhood of that point. A stricter requirement is global fidelity, which remains a challenge, since the only completely faithful model is the original model itself; we can still aim for a global perspective, though. Third, a desirable characteristic is that the explainer be model-agnostic, meaning the method works across many types of models and does not rely on any specific property of the original model.
Now the name Local Interpretable Model-agnostic Explanations makes sense: a method that works across many models and provides an explanation simply by fitting a linear model in the neighbourhood of the point of interest. That simple linear model serves as a local approximation of the complex model. Here is an intuitive graph provided by the authors of the method.
In the image, the complex model we trained has a complex, non-linear decision function (unknown to LIME), shown by the blue-pink background. We don't need the entire decision function; we only care about the bold red cross, the point under investigation. LIME samples points around that point, weights them by their distance to it, runs them through the complex model to get their predictions, and finally fits its own linear decision boundary that approximates the complex model in that neighbourhood. Concretely, it works in the following steps:
- Step 1: Choose the point to be explained
- Step 2: Perturb samples around that point.
- Step 3: Use the original complex model to label those perturbed samples.
- Step 4: Weight those data points according to their distance from the point chosen in step 1. This agrees with the premise that LIME aims for a good local approximation. This set of weighted points is the dataset used to train the new interpretable model.
- Step 5: Train a sparse linear model on the dataset created in step 4. Here is the loss function of LIME:
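\[
\xi(x) = \operatorname*{arg\,min}_{g \in G} \; L(f, g, \pi_x) + \Omega(g)
\]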
Here \(x\) is the point to be explained, \(f\) is the model we need to explain, and \(g\) is the approximating explainer, chosen from a class \(G\) of interpretable models (for example, sparse linear models). \(\pi_x\) is the proximity measure with respect to \(x\); it defines the locality around \(x\). \(L(f, g, \pi_x)\) measures how unfaithful \(g\) is in approximating \(f\) in the neighbourhood of \(x\), as governed by \(\pi_x\). \(\Omega(g)\) is the complexity of \(g\). Our goal is thus to ensure both interpretability (keeping the complexity small) and local fidelity. We call this the locality-aware loss.
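Before turning to the lime library itself, here is a minimal sketch of the recipe above for tabular data. It is purely illustrative (not the library's actual implementation) and assumes a fitted scikit-learn classifier `model`, its training data `X_train` as a NumPy array, and an instance `x` to explain.

```python
import numpy as np
from sklearn.linear_model import Ridge

def lime_sketch(model, X_train, x, num_samples=5000):
    """Illustrative LIME-style local explanation for one instance x."""
    mu, sigma = X_train.mean(axis=0), X_train.std(axis=0) + 1e-8
    # Step 2: perturb around x by sampling features from the training distribution
    Z = np.random.normal(mu, sigma, size=(num_samples, X_train.shape[1]))
    Z[0] = x                                      # keep the original point
    # Step 3: label the perturbed samples with the complex model f
    y = model.predict_proba(Z)[:, 1]              # probability of one class
    # Step 4: weight samples by an exponential kernel on the distance to x
    dist = np.linalg.norm((Z - x) / sigma, axis=1)
    kernel_width = 0.75 * np.sqrt(X_train.shape[1])
    w = np.exp(-dist ** 2 / kernel_width ** 2)
    # Step 5: fit a regularised linear model g on the weighted dataset
    g = Ridge(alpha=1.0)
    g.fit(Z, y, sample_weight=w)
    return g.coef_                                # local feature attributions
```

The coefficients of \(g\) are the local feature attributions; the actual library additionally discretizes features and enforces sparsity, but the spirit is the same.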
Code example
In this example, we use the Iris flower dataset. You can train any classifier you want to experiment on; here we use a random forest. After running the data through LIME, an explanation is printed out. For the first datapoint in the test set, the prediction is 0.99 for versicolor (colour-coded orange). Of the four features, the two that contribute to this prediction are the petal length and the petal width, with petal length contributing more. The datapoint itself is also printed, so you can see, for example, that this flower's petal length is 4.7 cm and its petal width is 1.2 cm.
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from lime import lime_tabular
# Load dataset
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=42)
# Train a RandomForest classifier
rf = RandomForestClassifier(random_state=42)
rf.fit(X_train, y_train)
# Use LIME to explain predictions
explainer = lime_tabular.LimeTabularExplainer(X_train, feature_names=iris.feature_names, class_names=iris.target_names, discretize_continuous=True)
exp = explainer.explain_instance(X_test[0], rf.predict_proba, num_features=4)
exp.show_in_notebook()
Let's come back to the Home Credit default risk dataset. It has about 300,000 observations, of which 91.93% do not actually default. We first train a random forest on it; an accuracy of 92.00% is not very impressive against that baseline. The dataset has 121 features, so we use the random forest's feature importances to keep only the top 10, which include the credit ratings from three external sources, the applicant's date of birth, the registration and ID-publication dates, the annuity amount, and the number of days employed. This keeps the example simple.
import pandas as pd
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import LabelEncoder
np.set_printoptions(suppress=True)
labels = pd.read_csv('home-credit-default-risk/HomeCredit_columns_description.csv',encoding='ISO-8859-1')
data = pd.read_csv('home-credit-default-risk/application_train.csv')
y_train_orig = data['TARGET']
X_train = data.drop(['TARGET'], axis=1)
y_train_orig = y_train_orig.to_frame()
# PREPROCESS
numerics = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
categoricals = ['object']
X_train_categorical = X_train.select_dtypes(include=categoricals)
X_train_numerical = X_train.select_dtypes(include=numerics)
categorical_columns = X_train_categorical.columns
numerical_columns = X_train_numerical.columns
categorical_imputer = SimpleImputer(missing_values=np.nan, strategy='most_frequent')
numerical_imputer = SimpleImputer(missing_values=np.nan, strategy='mean')
X_train_categorical = categorical_imputer.fit_transform(X_train_categorical)
X_train_categorical = pd.DataFrame(data=X_train_categorical, columns=categorical_columns)
X_train_numerical = numerical_imputer.fit_transform(X_train_numerical)
X_train_numerical = pd.DataFrame(data=X_train_numerical, columns=numerical_columns)
X_train_categorical = X_train_categorical.apply(LabelEncoder().fit_transform)
X_train_new = pd.concat([X_train_numerical, X_train_categorical], axis=1)
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
X_train, X_test, y_train, y_test = train_test_split(X_train_new, y_train_orig, test_size=0.33, random_state=42)
# TRAIN
rf = RandomForestClassifier()
rf.fit(X_train, y_train.values.ravel())  # ravel() gives the 1-d target array sklearn expects
predicted = rf.predict(X_test)
print(accuracy_score(y_test, predicted))
0.9200228618728998
importances = rf.feature_importances_
# Get the top 10 features
feature_importances = pd.DataFrame({'feature': list(X_train.columns), 'importance': importances})
feature_importances = feature_importances.sort_values('importance', ascending=False).reset_index(drop=True)
print(feature_importances.head(10))
feature importance
0 EXT_SOURCE_2 0.056362
1 EXT_SOURCE_3 0.054895
2 DAYS_BIRTH 0.035585
3 DAYS_ID_PUBLISH 0.034736
4 DAYS_REGISTRATION 0.030249
5 EXT_SOURCE_1 0.029757
6 SK_ID_CURR 0.029627
7 AMT_ANNUITY 0.029011
8 DAYS_LAST_PHONE_CHANGE 0.028513
9 DAYS_EMPLOYED 0.028438
features=feature_importances['feature'][:10]
X_train_10 = X_train[features]
X_test_10 = X_test[features]
# TRAIN with less features
rf = RandomForestClassifier()
rf.fit(X_train_10, y_train.values.ravel())
predicted = rf.predict(X_test_10)
print(accuracy_score(y_test, predicted))
0.9199735905950985
# EXPLAIN
from lime import lime_tabular
explainer = lime_tabular.LimeTabularExplainer(X_train_10.values, feature_names=X_train_10.columns, class_names=['Paid', 'Default'], discretize_continuous=True)
i = 0 # index of the instance to be explained
exp = explainer.explain_instance(X_train_10.values[i], rf.predict_proba, num_features=10)
exp.show_in_notebook()
Here we ask LIME to analyse the random forest's (RF) prediction for the first datapoint. The model predicts a non-default probability of 0.97, even though EXT_SOURCE_3, the annuity, and the days employed all push towards predicting that this customer would default. If we look into the data, we see that this first datapoint actually defaulted. For this simplified example we accept that our RF is not a very good model to begin with, but we keep it for the sake of simplicity: the main goal is to see how LIME analyses this model and this dataset. The analysis is quite informative, giving a colour-coded explanation for each feature.
LIME doesn't just work on tabular data; it works on images too. In a classic example, LIME evaluates Google's pretrained Inception model on an image of a cat. The green areas contribute to the prediction, while the red areas count against it:
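For reference, here is a minimal sketch of how such an image explanation can be produced with LIME's image explainer. It is not the exact setup behind the figure: the use of Keras' InceptionV3 and a local file `cat.jpg` are illustrative assumptions, and it requires TensorFlow, lime, scikit-image, and matplotlib.

```python
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from tensorflow.keras.preprocessing import image
from lime import lime_image
from skimage.segmentation import mark_boundaries

model = InceptionV3(weights='imagenet')  # pretrained Inception model

def predict_fn(images):
    # LIME passes a batch of perturbed images; return class probabilities
    return model.predict(preprocess_input(np.array(images)))

# Load and resize the image to Inception's expected 299x299 input
img = np.array(image.load_img('cat.jpg', target_size=(299, 299))).astype('double')

explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(img, predict_fn, top_labels=5,
                                         hide_color=0, num_samples=1000)

# Overlay the superpixels for and against the top predicted label
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0],
                                            positive_only=False,
                                            num_features=10, hide_rest=False)
plt.imshow(mark_boundaries(temp / 255.0, mask))
plt.show()
```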
Submodular pick
Apart from explaining individual datapoints, the authors also provide a submodular pick method for explaining the model as a whole, giving what they call a global perspective. The intuition is that, to explain the model, we need to pick a diverse and representative set of points, and we pick those points using the explanations LIME produces. The following image provides a simple illustration of the method:
To choose features, we give each one an importance score. The scoring should reflect the fact that feature f2 explains much more of the data than feature f1 (f2 contributes to the first four rows of the dataset, while f1 contributes to only one row), so I(f2) > I(f1). To reduce the dataset, we then screen the rows of the table. Notice that choosing row 2 and row 3 adds the same information, since they have similar explanations (both are explained by features f2 and f3); we shouldn't pick datapoints that are too similar to each other. The procedure ends up choosing rows 2 and 5, since together they contribute the most diversity, covering 4 of the 5 features. A small sketch of this greedy selection follows.
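The sketch below illustrates the greedy coverage idea (not the library's exact implementation). `W[i, j]` holds the weight of feature `j` in the explanation of instance `i`, and the toy matrix is constructed to match the illustration described above; both are assumptions for demonstration.

```python
import numpy as np

def greedy_pick(W, budget):
    """Greedily pick `budget` rows (instances) that together cover the
    most important features, given an explanation-weight matrix W."""
    importance = np.sqrt(np.abs(W).sum(axis=0))   # global feature importance I(f)
    covered = np.zeros(W.shape[1], dtype=bool)
    picked = []
    for _ in range(budget):
        # coverage gain of adding each not-yet-picked instance
        gains = [importance[(W[i] != 0) & ~covered].sum() if i not in picked else -1.0
                 for i in range(W.shape[0])]
        best = int(np.argmax(gains))
        picked.append(best)
        covered |= (W[best] != 0)
    return picked

# Toy matrix mirroring the illustration: 5 instances (rows), 5 features (columns)
W_toy = np.array([[1, 1, 0, 0, 0],
                  [0, 1, 1, 0, 0],
                  [0, 1, 1, 0, 0],
                  [0, 1, 0, 0, 0],
                  [0, 0, 0, 1, 1]], dtype=float)
print(greedy_pick(W_toy, budget=2))   # -> [1, 4], i.e. rows 2 and 5
```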
Applying this submodular pick method to the Home Credit default risk data above, it ends up with 57 features and 20 instances in its explanation; 20 is a remarkably small number of datapoints to represent the whole dataset.
from lime import submodular_pick
explainer = lime_tabular.LimeTabularExplainer(np.array(X_train), feature_names=X_train.columns, class_names=['default'], verbose=False, mode='regression', discretize_continuous=False)
sp_obj = submodular_pick.SubmodularPick(explainer, np.array(X_train), rf.predict, sample_size=20, num_features=14, num_exps_desired=5)
import pandas as pd
W=pd.DataFrame([dict(this.as_list()) for this in sp_obj.explanations])
W.head()
|   | SK_ID_CURR | CNT_CHILDREN | AMT_INCOME_TOTAL | AMT_CREDIT | AMT_ANNUITY | AMT_GOODS_PRICE | REGION_POPULATION_RELATIVE | DAYS_BIRTH | DAYS_EMPLOYED | DAYS_REGISTRATION | ... | FLOORSMAX_MODE | LIVE_CITY_NOT_WORK_CITY | NAME_TYPE_SUITE | AMT_REQ_CREDIT_BUREAU_QRT | OBS_60_CNT_SOCIAL_CIRCLE | OBS_30_CNT_SOCIAL_CIRCLE | FLAG_OWN_CAR | NAME_INCOME_TYPE | AMT_REQ_CREDIT_BUREAU_YEAR | AMT_REQ_CREDIT_BUREAU_WEEK |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.000000 | 0.00000 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.00000 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 1 | NaN | NaN | NaN | 0.000189 | -0.000081 | 0.000204 | 0.000327 | -0.000167 | NaN | -0.00019 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 2 | 0.000000 | 0.00000 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.00000 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 3 | 0.000403 | 0.00019 | NaN | NaN | NaN | NaN | NaN | 0.000131 | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
| 4 | 0.000000 | 0.00000 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.00000 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
5 rows × 45 columns
W.shape
(20, 57)
Conclusion
In an era where AI is increasingly woven into the fabric of our daily lives, interpretability becomes even more critical. From recommendation engines to credit risk modelling, autonomous driving to healthcare diagnostics, AI systems are making decisions with profound implications for individuals and societies. The requirement for these systems to be 'friendly', transparent, and understandable therefore keeps growing.
This blog post has not only discussed the theoretical foundations of LIME but has also provided a practical demonstration of its application in deciphering a machine learning model trained on the home credit default risk dataset. The post has highlighted how LIME can serve as a valuable tool in the machine learning practitioner’s toolbox, contributing significantly to the understanding of complex models.
The exercise illustrated here underscores the fact that LIME is not just an adjunct to the model, but in many ways, an essential complement to it. By providing local interpretability, it adds a layer of transparency that helps build trust in the model’s decisions.
In conclusion, while the path to fully interpretable AI may still be long and fraught with challenges, tools like LIME are helping us navigate this path by making our models a little less like a black box and a little more like a glass box. It is my hope that this post has added value to your understanding of LIME and its applications, and that it will aid in your future machine learning modeling work. As we continue to embrace AI and machine learning, striving for clarity and interpretability is more than a luxury - it is a necessity.