{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "uUvsR-mWBoNS"
},
"source": [
"# Explainable AI with Synthetic Data
\n",
"\n",
"In this notebook, we demonstrate how a ML model, that was trained on real data, can be perfectly explored, reasoned around and validated in great detail with synthetic data. As synthetic data is not restricted by privacy, this allows to engage with far broader groups and communities, when it comes to algorithmic auditing, and ensuring the safety of developed ML-powered systems.\n",
"\n",
"For further background see also [this blog post](https://mostly.ai/blog/the-future-of-explainable-ai-rests-upon-synthetic-data/) on \"_The Future of Explainable AI rests upon Synthetic Data_\"."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train ML Model on Real Data\n",
"\n",
"Let's again use the UCI Adult [[1](#refs)] dataset, consisting of 48,842 records across 14 attributes. We will initially remove 4 attributes to make the analysis later on more insightful."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -U mostlyai # or: pip install -U 'mostlyai[local]'\n",
"%pip install scikit-learn seaborn lightgbm shap"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"# fetch original data and drop a couple of variables\n",
"df = pd.read_csv(\"https://github.com/mostly-ai/public-demo-data/raw/dev/census/census.csv.gz\")\n",
"df = df.drop(columns=[\"fnlwgt\", \"education_num\", \"race\", \"native_country\"])\n",
"df.head(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import lightgbm as lgb\n",
"from lightgbm import early_stopping\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"target_col = \"income\"\n",
"target_val = \">50K\"\n",
"\n",
"\n",
"def prepare_xy(df):\n",
" y = (df[target_col] == target_val).astype(int)\n",
" str_cols = [col for col in df.select_dtypes([\"object\", \"string\"]).columns if col != target_col]\n",
" for col in str_cols:\n",
" df[col] = pd.Categorical(df[col])\n",
" cat_cols = [col for col in df.select_dtypes(\"category\").columns if col != target_col]\n",
" num_cols = [col for col in df.select_dtypes(\"number\").columns if col != target_col]\n",
" for col in num_cols:\n",
" df[col] = df[col].astype(\"float\")\n",
" X = df[cat_cols + num_cols]\n",
" return X, y\n",
"\n",
"\n",
"def train_model(X, y):\n",
" cat_cols = list(X.select_dtypes(\"category\").columns)\n",
" X_trn, X_val, y_trn, y_val = train_test_split(X, y, test_size=0.2, random_state=1)\n",
" ds_trn = lgb.Dataset(X_trn, label=y_trn, categorical_feature=cat_cols, free_raw_data=False)\n",
" ds_val = lgb.Dataset(X_val, label=y_val, categorical_feature=cat_cols, free_raw_data=False)\n",
" model = lgb.train(\n",
" params={\"verbose\": -1, \"metric\": \"auc\", \"objective\": \"binary\"},\n",
" train_set=ds_trn,\n",
" valid_sets=[ds_val],\n",
" callbacks=[early_stopping(5)],\n",
" )\n",
" return model\n",
"\n",
"\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X, y = prepare_xy(df)\n",
"model = train_model(X, y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MKRa93uuSqZS"
},
"source": [
"## Synthesize Data via MOSTLY AI\n",
"\n",
"The code below will automatically create a synthetic dataset using the Synthetic Data SDK."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mostlyai.sdk import MostlyAI\n",
"\n",
"# initialize SDK\n",
"mostly = MostlyAI()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# train a generator on the original training data\n",
"g = mostly.train(\n",
" config={\n",
" \"name\": \"Explainable AI Tutorial Census\",\n",
" \"tables\": [\n",
" {\n",
" \"name\": \"data\",\n",
" \"data\": df,\n",
" \"tabular_model_configuration\": {\"max_training_time\": 1},\n",
" }\n",
" ],\n",
" }\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generate a synthetic dataset\n",
"syn = mostly.generate(g).data()\n",
"print(f\"Created synthetic data with {syn.shape[0]:,} records and {syn.shape[1]:,} attributes\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OgvJ0XoWTHoX"
},
"source": [
"## Evaluate ML Performance on Synthetic\n",
"\n",
"This is also known as a Train-Real-Test-Synthetic (TRTS) approach."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from sklearn.metrics import roc_auc_score, accuracy_score\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rcParams[\"figure.dpi\"] = 72\n",
"\n",
"X_syn, y_syn = prepare_xy(syn)\n",
"p_syn = model.predict(X_syn)\n",
"auc = roc_auc_score(y_syn, p_syn)\n",
"acc = accuracy_score(y_syn, (p_syn >= 0.5).astype(int))\n",
"probs_df = pd.concat(\n",
" [\n",
" pd.Series(p_syn, name=\"probability\").reset_index(drop=True),\n",
" pd.Series(y_syn, name=target_col).reset_index(drop=True),\n",
" ],\n",
" axis=1,\n",
")\n",
"fig = sns.displot(data=probs_df, x=\"probability\", hue=target_col, bins=20, multiple=\"stack\")\n",
"fig = plt.title(f\"Accuracy: {acc:.1%}, AUC: {auc:.1%}\", fontsize=20)\n",
"plt.show()"
]
},
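{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, we can score the same model on a holdout of the real data and compare. If the synthetic data is representative, the two AUC values should be close. This is a minimal sketch assuming the `X`, `y` and `model` objects from above; note that the split below mirrors the one inside `train_model`, whose validation fold was also used for early stopping, so the real AUC here is mildly optimistic."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sanity check (sketch): compare TRTS with performance on a real holdout\n",
"X_trn, X_val, y_trn, y_val = train_test_split(X, y, test_size=0.2, random_state=1)\n",
"p_val = model.predict(X_val)\n",
"print(f\"AUC on real holdout:   {roc_auc_score(y_val, p_val):.1%}\")\n",
"print(f\"AUC on synthetic data: {auc:.1%}\")"
]
},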
{
"cell_type": "markdown",
"metadata": {
"id": "OgvJ0XoWTHoX"
},
"source": [
"## Explain ML Model with Synthetic\n",
"\n",
"We will be using the **SHAP library**, a state-of-the-art Python package for Explainable AI. Learn more about SHAP and XAI at https://shap-lrjball.readthedocs.io/ and https://christophm.github.io/interpretable-ml-book/."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calculate SHAP values for trained model w/ synthetic data\n",
"\n",
"Important: For this step no access to the original (privacy-sensitive) training data is needed anymore. We only need access to the trained model for inference, as well as to representative synthetic data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import shap\n",
"\n",
"explainer = shap.TreeExplainer(model)\n",
"shap_values = explainer.shap_values(X_syn)\n",
"\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"### SHAP Feature Importance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"shap.initjs()\n",
"shap.summary_plot(shap_values, X_syn, plot_size=0.2, plot_type=\"bar\")"
]
},
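{
"cell_type": "markdown",
"metadata": {},
"source": [
"The bar plot above ranks features by their mean absolute SHAP value only. As a small optional addition, the default beeswarm variant of `shap.summary_plot` also shows the direction and spread of each feature's impact across the synthetic samples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# beeswarm view: one dot per synthetic sample, colored by feature value\n",
"shap.summary_plot(shap_values, X_syn)"
]
},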
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SHAP Dependency Plots\n",
"\n",
"Let's study the value-dependent impact of each model feature. SHAP dependency plots are a great way to do so, as they not only show the average lift, but also the level of variance of that lift, at the same time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def plot_shap_dependency(col):\n",
" col_idx = [i for i in range(X_syn.shape[1]) if X_syn.columns[i] == col][0]\n",
" # Check if shap_values is 1-dimensional\n",
" if len(shap_values.shape) == 1:\n",
" shp_vals = pd.Series(shap_values, name=\"shap_value\")\n",
" else:\n",
" shp_vals = pd.Series(shap_values[:, col_idx], name=\"shap_value\")\n",
"\n",
" col_vals = X_syn.iloc[:, col_idx].reset_index(drop=True)\n",
" df = pd.concat([shp_vals, col_vals], axis=1)\n",
"\n",
" if col_vals.dtype.name != \"category\":\n",
" q01 = df[col].quantile(0.01)\n",
" q99 = df[col].quantile(0.99)\n",
" df = df.loc[(df[col] >= q01) & (df[col] <= q99), :]\n",
" else:\n",
" sorted_cats = list(df.groupby(col)[\"shap_value\"].mean().sort_values().index)\n",
" df[col] = df[col].cat.reorder_categories(sorted_cats, ordered=True)\n",
"\n",
" fig, ax = plt.subplots(figsize=(8, 4))\n",
" plt.ylim(-3.2, 3.2)\n",
" plt.title(col)\n",
" plt.xlabel(\"\")\n",
" if col_vals.dtype.name == \"category\":\n",
" plt.xticks(rotation=90)\n",
" ax.tick_params(axis=\"both\", which=\"major\", labelsize=8)\n",
" ax.tick_params(axis=\"both\", which=\"minor\", labelsize=6)\n",
" sns.lineplot(x=df[col], y=df[\"shap_value\"], color=\"black\").axhline(0, color=\"gray\", alpha=1, lw=0.5)\n",
" sns.scatterplot(x=df[col], y=df[\"shap_value\"], alpha=0.1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_shap_dependency(\"age\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_shap_dependency(\"marital_status\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_shap_dependency(\"hours_per_week\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SHAP Values for Single (Synthetic) Samples\n",
"\n",
"Let's study the factors behind the model scores by inspecting individual samples. Note, that this level of reasoning at an individual-level would not be possible with real data, if that consists of privacy-sensitive information. However, synthetic data allows to reason around samples at any scale."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Rl6-YXB_e0Ac",
"tags": []
},
"outputs": [],
"source": [
"def show_idx(i):\n",
" shap.initjs()\n",
" df = X_syn.iloc[i : i + 1, :]\n",
" df.insert(0, \"actual\", y_syn.iloc[i])\n",
" df.insert(1, \"score\", p_syn[i])\n",
" display(df)\n",
"\n",
" # Determine the correct expected value and shap values based on the structure of explainer.expected_value\n",
" if isinstance(explainer.expected_value, np.ndarray) and explainer.expected_value.ndim > 0:\n",
" # For multi-class classifiers\n",
" expected_value = explainer.expected_value[1]\n",
" shap_val = shap_values[1][i] if isinstance(shap_values[1], np.ndarray) else shap_values[1]\n",
" else:\n",
" # For binary classifiers or regressors\n",
" expected_value = explainer.expected_value\n",
" shap_val = shap_values[i]\n",
"\n",
" return shap.force_plot(expected_value, shap_val, X_syn.iloc[i, :], link=\"logit\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Random Sample"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"rnd_idx = X_syn.sample().index[0]\n",
"show_idx(rnd_idx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Sample with lowest score"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"idx = np.argsort(p_syn)[0]\n",
"show_idx(idx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Sample with highest score"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"idx = np.argsort(p_syn)[-1]\n",
"show_idx(idx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Sample with a median score"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"idx = np.argsort(p_syn)[int(len(p_syn) / 2)]\n",
"show_idx(idx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Sample a young Female Doctor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"idx = syn[(syn.education == \"Doctorate\") & (syn.sex == \"Female\") & (syn.age <= 30)].sample().index[0]\n",
"show_idx(idx)"
]
},
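{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a further illustration (an addition, not part of the original flow), synthetic records can also be probed with simple what-if edits: change one attribute of the sampled record and observe how the model score reacts. The sketch below nudges the `age` feature of the record above by ten years."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# what-if sketch: copy the sampled synthetic record and nudge its age\n",
"row = X_syn.iloc[[idx]].copy()\n",
"print(f\"score at age {row['age'].iloc[0]:.0f}: {model.predict(row)[0]:.1%}\")\n",
"row[\"age\"] = row[\"age\"] + 10.0\n",
"print(f\"score at age {row['age'].iloc[0]:.0f}: {model.predict(row)[0]:.1%}\")"
]
},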
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Explore SHAP Values across a thousand samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"shap.initjs()\n",
"\n",
"# Determine the correct expected value\n",
"if isinstance(explainer.expected_value, np.ndarray):\n",
" expected_value = explainer.expected_value[1]\n",
"else:\n",
" expected_value = explainer.expected_value\n",
"\n",
"# Adjust how you access shap_values based on its structure\n",
"if isinstance(shap_values, list) and len(shap_values) > 1:\n",
" shap_vals_subset = shap_values[1][:1000, :]\n",
"else:\n",
" shap_vals_subset = shap_values[:1000, :]\n",
"\n",
"shap.force_plot(expected_value, shap_vals_subset, X.iloc[:1000, :], link=\"logit\")"
]
},
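{
"cell_type": "markdown",
"metadata": {},
"source": [
"The interactive force plot can get unwieldy for many samples. As an optional static alternative (an addition to the original tutorial), `shap.decision_plot` traces how each prediction builds up feature by feature; below is a sketch for the first 50 synthetic samples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# decision plot sketch: cumulative feature contributions for 50 synthetic samples\n",
"shap.decision_plot(expected_value, shap_values[:50, :], X_syn.iloc[:50, :], link=\"logit\")"
]
},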
{
"cell_type": "markdown",
"metadata": {
"id": "eQMOiU1Bv6W4"
},
"source": [
"## Conclusion\n",
"\n",
"This tutorial demonstrated how ML models, that have been trained on real data, can be safely tested and explained with synthetic data. As the latter is not privacy-sensitive, this allows to have these types of introspections and validations performed by a significantly larger group of stakeholders. This is a key part to build safe & smart algorithms, that will have a significant impact on individuals' lives."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Further exercises\n",
"\n",
"In addition to walking through the above instructions, we suggest..\n",
"* replicating the explainability section with real data and compare results\n",
"* using a different dataset, eg. the UCI bank-marketing dataset [[2](#refs)]\n",
"* using a different ML model, eg. a RandomForest model [[3](#refs)]"
]
},
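{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a starting point for the last exercise, here is a minimal sketch of swapping in a scikit-learn RandomForest. Note the assumptions: RandomForest cannot consume pandas categoricals directly, so the sketch one-hot encodes them; and depending on the shap version, `shap.TreeExplainer` returns per-class SHAP values for a RandomForestClassifier, either as a list or as a 3-dimensional array."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sketch: train a RandomForest on the real data, explain it with synthetic data\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"# one-hot encode categoricals, as RandomForest has no native categorical support;\n",
"# align the synthetic columns to the training columns\n",
"X_ohe = pd.get_dummies(X)\n",
"X_syn_ohe = pd.get_dummies(X_syn).reindex(columns=X_ohe.columns, fill_value=0)\n",
"\n",
"rf = RandomForestClassifier(n_estimators=100, random_state=1)\n",
"rf.fit(X_ohe, y)\n",
"\n",
"# explain on a 1,000-sample synthetic subset to keep the runtime modest\n",
"rf_explainer = shap.TreeExplainer(rf)\n",
"rf_shap = rf_explainer.shap_values(X_syn_ohe.iloc[:1000, :])\n",
"# take the positive class, whichever output structure this shap version uses\n",
"rf_shap = rf_shap[1] if isinstance(rf_shap, list) else rf_shap[:, :, 1]\n",
"shap.summary_plot(rf_shap, X_syn_ohe.iloc[:1000, :], plot_type=\"bar\")"
]
},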
{
"cell_type": "markdown",
"metadata": {
"id": "UW5ntiUB18yP"
},
"source": [
"## References\n",
"\n",
"1. https://archive.ics.uci.edu/ml/datasets/adult\n",
"1. https://archive.ics.uci.edu/ml/datasets/bank+marketing\n",
"1. https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
},
"vscode": {
"interpreter": {
"hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}