# Explainable AI with Synthetic Data <a href="https://colab.research.google.com/github/mostly-ai/mostlyai/blob/main/docs/tutorials/explainable-ai/explainable-ai.ipynb" target="_blank"><img src="https://img.shields.io/badge/Open%20in-Colab-blue?logo=google-colab" alt="Run on Colab"></a>

In this notebook, we demonstrate how an ML model trained on real data can be explored, reasoned about, and validated in great detail using synthetic data. Because synthetic data is not subject to the same privacy restrictions as the original data, it allows far broader groups and communities to take part in algorithmic auditing and in ensuring the safety of ML-powered systems.

For further background see also [this blog post](https://mostly.ai/blog/the-future-of-explainable-ai-rests-upon-synthetic-data/) on "_The Future of Explainable AI rests upon Synthetic Data_".

## Train ML Model on Real Data

We use the UCI Adult [[1](#refs)] dataset, consisting of 48,842 records across 14 attributes. We first remove 4 attributes to make the later analysis more insightful.

In [None]:
%pip install -U mostlyai  # or: pip install -U 'mostlyai[local]'
%pip install scikit-learn seaborn lightgbm shap

In [None]:
import pandas as pd
import numpy as np

# fetch original data and drop a couple of variables
df = pd.read_csv("https://github.com/mostly-ai/public-demo-data/raw/dev/census/census.csv.gz")
df = df.drop(columns=["fnlwgt", "education_num", "race", "native_country"])
df.head(5)

In [None]:
import lightgbm as lgb
from lightgbm import early_stopping
from sklearn.model_selection import train_test_split

target_col = "income"
target_val = ">50K"


def prepare_xy(df):
    y = (df[target_col] == target_val).astype(int)
    str_cols = [col for col in df.select_dtypes(["object", "string"]).columns if col != target_col]
    for col in str_cols:
        df[col] = pd.Categorical(df[col])
    cat_cols = [col for col in df.select_dtypes("category").columns if col != target_col]
    num_cols = [col for col in df.select_dtypes("number").columns if col != target_col]
    for col in num_cols:
        df[col] = df[col].astype("float")
    X = df[cat_cols + num_cols]
    return X, y


def train_model(X, y):
    cat_cols = list(X.select_dtypes("category").columns)
    X_trn, X_val, y_trn, y_val = train_test_split(X, y, test_size=0.2, random_state=1)
    ds_trn = lgb.Dataset(X_trn, label=y_trn, categorical_feature=cat_cols, free_raw_data=False)
    ds_val = lgb.Dataset(X_val, label=y_val, categorical_feature=cat_cols, free_raw_data=False)
    model = lgb.train(
        params={"verbose": -1, "metric": "auc", "objective": "binary"},
        train_set=ds_trn,
        valid_sets=[ds_val],
        callbacks=[early_stopping(5)],
    )
    return model


import warnings

warnings.filterwarnings("ignore")

In [None]:
X, y = prepare_xy(df)
model = train_model(X, y)

## Synthesize Data via MOSTLY AI

The code below will automatically create a synthetic dataset using the Synthetic Data SDK.

In [None]:
from mostlyai.sdk import MostlyAI

# initialize SDK
mostly = MostlyAI()

In [None]:
# train a generator on the original training data
g = mostly.train(
    config={
        "name": "Explainable AI Tutorial Census",
        "tables": [
            {
                "name": "data",
                "data": df,
                "tabular_model_configuration": {"max_training_time": 1},
            }
        ],
    }
)

In [None]:
# generate a synthetic dataset
syn = mostly.generate(g).data()
print(f"Created synthetic data with {syn.shape[0]:,} records and {syn.shape[1]:,} attributes")

## Evaluate ML Performance on Synthetic

We now score the model, which was trained on real data, against the synthetic data. This is also known as a Train-Real-Test-Synthetic (TRTS) approach.

In [None]:
from sklearn.metrics import roc_auc_score, accuracy_score
import seaborn as sns
import matplotlib.pyplot as plt

plt.rcParams["figure.dpi"] = 72

X_syn, y_syn = prepare_xy(syn)
p_syn = model.predict(X_syn)
auc = roc_auc_score(y_syn, p_syn)
acc = accuracy_score(y_syn, (p_syn >= 0.5).astype(int))
probs_df = pd.concat(
    [
        pd.Series(p_syn, name="probability").reset_index(drop=True),
        pd.Series(y_syn, name=target_col).reset_index(drop=True),
    ],
    axis=1,
)
fig = sns.displot(data=probs_df, x="probability", hue=target_col, bins=20, multiple="stack")
plt.title(f"Accuracy: {acc:.1%}, AUC: {auc:.1%}", fontsize=20)
plt.show()
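To make the two metrics in the plot title more concrete, here is a minimal, dependency-free sketch of how they can be read. AUC is the share of (positive, negative) pairs in which the positive record receives the higher score, and accuracy is the share of records classified correctly at a 0.5 threshold. The helper names and the toy labels and scores are made up for illustration; the cell above relies on scikit-learn's `roc_auc_score` and `accuracy_score` instead.

```python
def pairwise_auc(labels, scores):
    """AUC as the share of (positive, negative) pairs ranked correctly; ties count half."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def threshold_accuracy(labels, scores, threshold=0.5):
    """Share of records whose thresholded score matches the label."""
    preds = [int(s >= threshold) for s in scores]
    return sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)


# toy labels and scores, invented for this example
y_toy = [0, 0, 1, 1, 0, 1]
p_toy = [0.1, 0.4, 0.35, 0.8, 0.2, 0.9]

auc_toy = pairwise_auc(y_toy, p_toy)
acc_toy = threshold_accuracy(y_toy, p_toy)
print(f"Accuracy: {acc_toy:.1%}, AUC: {auc_toy:.1%}")
```

The pairwise loop is quadratic in the number of records, so it is only meant for building intuition on small samples.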

## Explain ML Model with Synthetic

We will be using the **SHAP library**, a state-of-the-art Python package for Explainable AI. Learn more about SHAP and XAI at https://shap-lrjball.readthedocs.io/ and https://christophm.github.io/interpretable-ml-book/.

### Calculate SHAP values for trained model with synthetic data

Important: This step no longer requires access to the original (privacy-sensitive) training data. We only need access to the trained model for inference, as well as to representative synthetic data.

In [None]:
import shap

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_syn)

import warnings

warnings.filterwarnings("ignore")
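As a rough intuition for what a SHAP value is, the sketch below computes exact Shapley values by brute force for a tiny, hypothetical scoring function. It is not the algorithm used by `shap.TreeExplainer`, which exploits the tree structure and treats absent features differently. Here we assume, purely for illustration, that an "absent" feature is replaced by the value from a single background row. The names `toy_model`, `exact_shapley`, `x_toy` and `background_toy` are invented for this example.

```python
from itertools import combinations
from math import factorial


def toy_model(x):
    """Hypothetical scoring function with an interaction term (not the LightGBM model)."""
    return 2.0 * x[0] + 1.0 * x[1] + 3.0 * x[0] * x[2]


def exact_shapley(f, x, background):
    """Exact Shapley values for one sample; absent features take the background value."""
    n = len(x)

    def value(subset):
        z = [x[j] if j in subset else background[j] for j in range(n)]
        return f(z)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for k in range(n):
            for s in combinations(others, k):
                # weight of a coalition of size k that does not contain feature i
                weight = factorial(k) * factorial(n - k - 1) / factorial(n)
                total += weight * (value(set(s) | {i}) - value(set(s)))
        phi.append(total)
    return phi


x_toy = [1.0, 2.0, 1.0]
background_toy = [0.0, 0.0, 0.0]
phi_toy = exact_shapley(toy_model, x_toy, background_toy)
base_toy = toy_model(background_toy)

# the contributions add up to the difference between the sample's score and the base value
print(phi_toy, base_toy + sum(phi_toy), toy_model(x_toy))
```

The interaction term between the first and third feature is split evenly between them, and the contributions sum to the model output minus the base value. This additivity is what makes the per-feature values in the following plots comparable.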

### SHAP Feature Importance

In [None]:
shap.initjs()
shap.summary_plot(shap_values, X_syn, plot_size=0.2, plot_type="bar")
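The bar variant of the summary plot ranks features by their mean absolute SHAP value across all samples. A minimal sketch of that aggregation, using a small made-up matrix of SHAP values (rows are samples, columns are features), could look as follows. The function and variable names are hypothetical.

```python
# rows = samples, columns = features; values invented for illustration
shap_toy = [
    [0.8, -0.1, 0.3],
    [-0.6, 0.2, 0.1],
    [0.4, -0.3, -0.5],
]
feature_names_toy = ["age", "hours_per_week", "education"]


def mean_abs_importance(shap_rows, names):
    """Return (feature, mean |SHAP|) pairs, most important feature first."""
    n = len(shap_rows)
    scores = {
        name: sum(abs(row[j]) for row in shap_rows) / n
        for j, name in enumerate(names)
    }
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


ranking_toy = mean_abs_importance(shap_toy, feature_names_toy)
for name, score in ranking_toy:
    print(f"{name:>15}: {score:.2f}")
```

Taking absolute values matters here: a feature that pushes some predictions up and others down would otherwise average out to roughly zero and look unimportant.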

### SHAP Dependency Plots

Let's study the value-dependent impact of each model feature. SHAP dependency plots are well suited to this, as they show both the average lift of a feature value and the variance of that lift at the same time.

In [None]:
def plot_shap_dependency(col):
    col_idx = X_syn.columns.get_loc(col)
    # Check if shap_values is 1-dimensional
    if len(shap_values.shape) == 1:
        shp_vals = pd.Series(shap_values, name="shap_value")
    else:
        shp_vals = pd.Series(shap_values[:, col_idx], name="shap_value")

    col_vals = X_syn.iloc[:, col_idx].reset_index(drop=True)
    df = pd.concat([shp_vals, col_vals], axis=1)

    if col_vals.dtype.name != "category":
        q01 = df[col].quantile(0.01)
        q99 = df[col].quantile(0.99)
        df = df.loc[(df[col] >= q01) & (df[col] <= q99), :]
    else:
        sorted_cats = list(df.groupby(col)["shap_value"].mean().sort_values().index)
        df[col] = df[col].cat.reorder_categories(sorted_cats, ordered=True)

    fig, ax = plt.subplots(figsize=(8, 4))
    plt.ylim(-3.2, 3.2)
    plt.title(col)
    plt.xlabel("")
    if col_vals.dtype.name == "category":
        plt.xticks(rotation=90)
    ax.tick_params(axis="both", which="major", labelsize=8)
    ax.tick_params(axis="both", which="minor", labelsize=6)
    sns.lineplot(x=df[col], y=df["shap_value"], color="black").axhline(0, color="gray", alpha=1, lw=0.5)
    sns.scatterplot(x=df[col], y=df["shap_value"], alpha=0.1)
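For categorical features, the function above orders the categories by their mean SHAP value, so that the plot reads from the most negative to the most positive average impact. A dependency-free sketch of that ordering step, using made-up (category, SHAP value) pairs and a hypothetical helper name, is shown below.

```python
from collections import defaultdict
from statistics import mean

# (category, SHAP value) pairs, invented for illustration
pairs_toy = [
    ("Married", 1.1),
    ("Never-married", -0.9),
    ("Divorced", -0.4),
    ("Married", 0.7),
    ("Divorced", -0.2),
    ("Never-married", -1.1),
]


def order_by_mean_shap(pairs):
    """Return the categories sorted by ascending mean SHAP value."""
    groups = defaultdict(list)
    for category, shap_value in pairs:
        groups[category].append(shap_value)
    return sorted(groups, key=lambda c: mean(groups[c]))


ordered_toy = order_by_mean_shap(pairs_toy)
print(ordered_toy)
```

Numeric features are handled differently in the function above: values outside the 1st to 99th percentile range are dropped so that a few extreme records do not stretch the x-axis.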

In [None]:
plot_shap_dependency("age")

In [None]:
plot_shap_dependency("marital_status")

In [None]:
plot_shap_dependency("hours_per_week")

### SHAP Values for Single (Synthetic) Samples

Let's study the factors behind the model scores by inspecting individual samples. Note that this kind of individual-level reasoning would not be possible with real data if it contains privacy-sensitive information. Synthetic data, however, allows us to reason about individual samples at any scale.

In [None]:
def show_idx(i):
    shap.initjs()
    df = X_syn.iloc[i : i + 1, :].copy()  # copy, so that insert() does not modify a slice of X_syn
    df.insert(0, "actual", y_syn.iloc[i])
    df.insert(1, "score", p_syn[i])
    display(df)

    # The structure of explainer.expected_value and shap_values depends on the SHAP version
    if isinstance(explainer.expected_value, np.ndarray) and explainer.expected_value.ndim > 0:
        # Per-class output (one entry per class): use the positive class (index 1)
        expected_value = explainer.expected_value[1]
        shap_val = shap_values[1][i] if isinstance(shap_values[1], np.ndarray) else shap_values[1]
    else:
        # Single output (scalar expected value, one row of SHAP values per sample)
        expected_value = explainer.expected_value
        shap_val = shap_values[i]

    return shap.force_plot(expected_value, shap_val, X_syn.iloc[i, :], link="logit")
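The `link="logit"` argument matters because, for a binary LightGBM model explained on its raw output, the base value and the SHAP values live in log-odds space rather than in probability space. A minimal sketch of the conversion, with made-up numbers and hypothetical variable names, could look like this.

```python
import math


def sigmoid(z):
    """Map a log-odds value to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


# hypothetical base value and per-feature contributions, all in log-odds
base_value_toy = -1.5
contribs_toy = {"age": 0.9, "marital_status": 1.2, "hours_per_week": -0.3}

# contributions are additive in log-odds space ...
log_odds_toy = base_value_toy + sum(contribs_toy.values())
# ... and only the final sum is converted into a probability
prob_toy = sigmoid(log_odds_toy)

print(f"log-odds: {log_odds_toy:.2f}, probability: {prob_toy:.3f}")
```

Because the sigmoid is non-linear, the same SHAP value shifts the probability by a different amount depending on where the other features have already moved the score.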

#### Random Sample

In [None]:
rnd_idx = X_syn.sample().index[0]
show_idx(rnd_idx)

#### Sample with lowest score

In [None]:
idx = np.argsort(p_syn)[0]
show_idx(idx)

#### Sample with highest score

In [None]:
idx = np.argsort(p_syn)[-1]
show_idx(idx)

#### Sample with a median score

In [None]:
idx = np.argsort(p_syn)[int(len(p_syn) / 2)]
show_idx(idx)

#### Sample of a young female with a doctorate

In [None]:
idx = syn[(syn.education == "Doctorate") & (syn.sex == "Female") & (syn.age <= 30)].sample().index[0]
show_idx(idx)

### Explore SHAP Values across a thousand samples

In [None]:
shap.initjs()

# Determine the correct expected value
if isinstance(explainer.expected_value, np.ndarray) and explainer.expected_value.ndim > 0:
    expected_value = explainer.expected_value[1]
else:
    expected_value = explainer.expected_value

# Adjust how you access shap_values based on its structure
if isinstance(shap_values, list) and len(shap_values) > 1:
    shap_vals_subset = shap_values[1][:1000, :]
else:
    shap_vals_subset = shap_values[:1000, :]

# the SHAP values were computed on X_syn, so the feature values must come from X_syn as well
shap.force_plot(expected_value, shap_vals_subset, X_syn.iloc[:1000, :], link="logit")

## Conclusion

This tutorial demonstrated how ML models that have been trained on real data can be safely tested and explained with synthetic data. As the latter is not privacy-sensitive, these kinds of introspection and validation can be performed by a significantly larger group of stakeholders. This is a key part of building safe and smart algorithms that will have a significant impact on individuals' lives.

## Further exercises

In addition to walking through the above instructions, we suggest:
* replicating the explainability section with real data and comparing the results
* using a different dataset, e.g. the UCI bank-marketing dataset [[2](#refs)]
* using a different ML model, e.g. a RandomForest model [[3](#refs)]

## References<a class="anchor" name="refs"></a>

1. https://archive.ics.uci.edu/ml/datasets/adult
2. https://archive.ics.uci.edu/ml/datasets/bank+marketing
3. https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html