{ "cells": [ { "cell_type": "markdown", "id": "3fd5150a-d33e-435b-ac47-3fdf15fb05c2", "metadata": {}, "source": [ "# Create Fair Synthetic Data\n", "\n", "Fairness in machine learning aims to ensure that algorithms and models treat individuals and groups equitably, without introducing or perpetuating bias. The objective is to prevent discrimination and address societal inequities, particularly concerning protected attributes such as race, gender, age, or ethnicity.\n", "\n", "In this tutorial, we showcase how MOSTLY AI’s Fairness feature can help bridge fairness gaps in your data. Generating a fair synthetic dataset enables downstream models trained on it to produce fair and unbiased predictions.\n", "\n", "For further background, see also [this paper](https://arxiv.org/abs/2311.03000) on \"_Strong statistical parity through fair synthetic data_\"." ] }, { "cell_type": "markdown", "id": "27fadb9e-433c-43bf-b24a-e8cf1d0f5843", "metadata": {}, "source": [ "## Data Preparation\n", "\n", "Let's use the UCI Adult [1] dataset, consisting of 48,842 records across 14 attributes. In this dataset, ~30% of men have a high income (`>50K`) compared to only ~11% of women, resulting in a statistical parity difference of roughly 0.2." 
] }, { "cell_type": "code", "execution_count": null, "id": "688d1723-ec2f-4a1e-89b1-936a31dc1ebc", "metadata": {}, "outputs": [], "source": [ "%pip install -U mostlyai # or: pip install -U 'mostlyai[local]'\n", "%pip install matplotlib plotly scikit-learn lightgbm" ] }, { "cell_type": "code", "execution_count": null, "id": "64397119-e477-44a8-8c7a-500e09560aea", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "\n", "# fetch original data\n", "df = pd.read_csv(\"https://github.com/mostly-ai/public-demo-data/raw/dev/census/census.csv.gz\")\n", "df" ] }, { "cell_type": "code", "execution_count": null, "id": "4976ea54-2018-4a6e-8993-60b7f19efcb5", "metadata": {}, "outputs": [], "source": [ "# split 80/20 into training and holdout\n", "trn, hol = train_test_split(df, test_size=0.2, random_state=42)\n", "trn.reset_index(drop=True, inplace=True)\n", "hol.reset_index(drop=True, inplace=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "99b9a88c-aba7-4bdc-99c6-449ed053c985", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "\n", "def plot_income_by_gender(df, title):\n", "    # Create a stacked bar plot of the income distribution for males and females\n", "    plt.figure(figsize=(10, 6))\n", "    income_gender_distribution = df.groupby([\"sex\", \"income\"]).size().unstack()\n", "    income_gender_proportions = (\n", "        income_gender_distribution.div(income_gender_distribution.sum(axis=1), axis=0) * 100\n", "    )  # Convert to percentages\n", "\n", "    # Customizing the plot\n", "    ax = income_gender_proportions.plot(kind=\"bar\", stacked=True, ax=plt.gca())\n", "    ax.spines[\"top\"].set_visible(False)\n", "    ax.spines[\"right\"].set_visible(False)\n", "\n", "    # Adding title and labels\n", "    plt.title(f\"Distribution of Income by Gender - {title}\", fontsize=16, weight=\"bold\")\n", "    plt.xlabel(\"Gender\", fontsize=14)\n", "    plt.ylabel(\"Share (%)\", fontsize=14)\n", 
"    plt.xticks(rotation=0, fontsize=12)\n", "    plt.yticks(fontsize=12)\n", "    plt.legend(title=\"Income\", fontsize=12, title_fontsize=14, loc=\"upper right\")\n", "\n", "    # Adding data labels\n", "    for bar_group in ax.containers:\n", "        ax.bar_label(bar_group, fmt=\"%.1f%%\", label_type=\"center\", fontsize=10)\n", "\n", "    plt.show()\n", "\n", "\n", "plot_income_by_gender(df, title=\"Original\")" ] }, { "cell_type": "markdown", "id": "2ee9726e-4eb8-4176-9ca4-aa8df2cb2933", "metadata": {}, "source": [ "## Synthesize Data via MOSTLY AI\n", "\n", "The code below creates a Generator using the MOSTLY AI Synthetic Data SDK. We then use that Generator to create two datasets: a representative synthetic dataset, and a fair synthetic dataset with the Fairness feature turned on for the target column `income` and the sensitive column `sex`." ] }, { "cell_type": "code", "execution_count": null, "id": "125e6d793c1175cb", "metadata": {}, "outputs": [], "source": [ "from mostlyai.sdk import MostlyAI\n", "\n", "# initialize SDK\n", "mostly = MostlyAI()" ] }, { "cell_type": "code", "execution_count": null, "id": "aa76b44b-746a-4905-98c7-60a5f69bf713", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# train a generator on the original training data\n", "g = mostly.train(data=trn, name=\"Fairness Tutorial\")" ] }, { "cell_type": "code", "execution_count": null, "id": "0c7fcd0f", "metadata": {}, "outputs": [], "source": [ "# create representative synthetic data that preserves the bias present in the original data\n", "sd = mostly.generate(g, name=\"Fairness Tutorial - Representative Synthetic Data\")\n", "syn = sd.data()" ] }, { "cell_type": "code", "execution_count": null, "id": "99baee6f076612c", "metadata": {}, "outputs": [], "source": [ "# Define the fairness configuration\n", "fairness_config = {\n", "    \"name\": \"Fairness Tutorial - Fair Synthetic Data\",\n", "    \"tables\": [\n", "        {\n", "            \"name\": \"data\",\n", "            \"configuration\": {\n", "                \"fairness\": {\n", "                    \"target_column\": \"income\",  # define fairness target\n", 
"                    \"sensitive_columns\": [\"sex\"],  # define sensitive columns\n", "                }\n", "            },\n", "        }\n", "    ],\n", "}\n", "\n", "# create fair synthetic data with mitigated bias\n", "fair_sd = mostly.generate(g, config=fairness_config)\n", "fair_syn = fair_sd.data()" ] }, { "cell_type": "markdown", "id": "97058dc7-b791-4674-b028-9599384b9d7f", "metadata": {}, "source": [ "You can now examine the distributions using the Model QA and Data QA reports. These reports can be downloaded via `sd.reports()` for the representative synthetic data and `fair_sd.reports()` for the fair synthetic data. The Model QA report evaluates the accuracy and privacy performance of the trained generative AI model, demonstrating that the distributions are faithfully learned, including the original proportions of high-income men and women. The Data QA report visualizes how the income distributions in the delivered fair synthetic dataset have been adjusted to mitigate the statistical parity difference." ] }, { "cell_type": "code", "execution_count": null, "id": "ec3d89cb", "metadata": {}, "outputs": [], "source": [ "print(sd.reports(\"representative-synthetic-data-reports.zip\").absolute())\n", "print(fair_sd.reports(\"fair-synthetic-data-reports.zip\").absolute())" ] }, { "cell_type": "code", "execution_count": null, "id": "a675655b", "metadata": {}, "outputs": [], "source": [ "plot_income_by_gender(syn, \"Representative Synthetic Data\")" ] }, { "cell_type": "code", "execution_count": null, "id": "bc4716bd", "metadata": {}, "outputs": [], "source": [ "plot_income_by_gender(fair_syn, \"Fair Synthetic Data\")" ] }, { "cell_type": "markdown", "id": "10381448", "metadata": {}, "source": [ "The statistical parity difference is mitigated in the fair synthetic dataset, i.e. the proportions of females and males with high income are comparable." 
] }, { "cell_type": "markdown", "id": "25c4ac9f-8eef-4706-b033-6ea2d4e1d941", "metadata": {}, "source": [ "## Train a Downstream ML Model\n", "\n", "We can now compare the predictions of a downstream model trained on the original data, the representative synthetic data, and the fair synthetic data." ] }, { "cell_type": "code", "execution_count": null, "id": "0d03ace1", "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import OneHotEncoder, MinMaxScaler\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.compose import ColumnTransformer\n", "from lightgbm import LGBMClassifier\n", "from sklearn.metrics import roc_auc_score, f1_score, accuracy_score" ] }, { "cell_type": "code", "execution_count": null, "id": "c4a24fbd-c32d-4995-b9fa-c9127430c96d", "metadata": {}, "outputs": [], "source": [ "trn_fns = [\"original\", \"synthetic\", \"fair synthetic\"]\n", "\n", "# separate features and binary target of the holdout data, without mutating `hol`\n", "y_hol = (hol[\"income\"] == \">50K\").astype(int)\n", "X_hol = hol.drop(columns=\"income\")\n", "\n", "cat_cols = X_hol.columns[X_hol.dtypes == \"object\"].tolist()\n", "num_cols = X_hol.columns[X_hol.dtypes == \"int64\"].tolist()\n", "ct = ColumnTransformer([(\"c\", OneHotEncoder(handle_unknown=\"ignore\"), cat_cols), (\"n\", MinMaxScaler(), num_cols)])\n", "model = LGBMClassifier(n_estimators=100)\n", "pipe = Pipeline(steps=[(\"t\", ct), (\"m\", model)])\n", "\n", "res = []\n", "predicted_probs = pd.DataFrame()\n", "# use a distinct loop variable so that `trn` is not overwritten\n", "for trn_fn, trn_df in zip(trn_fns, [trn, syn, fair_syn]):\n", "    y_trn = (trn_df[\"income\"] == \">50K\").astype(int)\n", "    X_trn = trn_df.drop(columns=\"income\")\n", "    pipe.fit(X_trn, y_trn)\n", "    probs = pipe.predict_proba(X_hol)[:, 1]\n", "    predicted_probs[trn_fn] = probs\n", "    res.append(\n", "        {\n", "            \"AUC\": roc_auc_score(y_hol, probs),\n", "            \"Accuracy\": accuracy_score(y_hol, probs > 0.5),\n", "            \"F1\": f1_score(y_hol, probs > 0.5, average=\"macro\"),\n", "            \"N\": trn_df.shape[0],\n", "            \"fn\": trn_fn,\n", "        }\n", "    )\n", "\n", "predicted_probs[\"sex\"] = hol[\"sex\"]\n", "predicted_probs[\"income\"] = y_hol" ] }, { "cell_type": "code", 
"execution_count": null, "id": "087fecf3-276c-4bff-aff7-ef04ead0dd80", "metadata": {}, "outputs": [], "source": [ "# collect the results and add the statistical parity (SP) mean difference per training set\n", "res_sort = pd.DataFrame(res, index=list(range(len(res))))\n", "res_sort[\"SP mean difference\"] = (res_sort[\"fn\"]).map(\n", "    predicted_probs.groupby([\"sex\"])[trn_fns].mean().diff().iloc[1, :]\n", ")\n", "# sort by SP mean difference, smallest first\n", "res_sort = res_sort.sort_values(by=\"SP mean difference\", ascending=True)\n", "res_sort" ] }, { "cell_type": "markdown", "id": "06fa5c97", "metadata": {}, "source": [ "The model performance on synthetic data is comparable to that on the original data, with a similar statistical parity (SP) difference. While the fair synthetic data mitigates the SP difference, it does so at the expense of downstream model performance, reflected in a decreased AUC." ] }, { "cell_type": "code", "execution_count": null, "id": "2e1baaa3", "metadata": {}, "outputs": [], "source": [ "import plotly.express as px" ] }, { "cell_type": "code", "execution_count": null, "id": "d5e46bb2", "metadata": {}, "outputs": [], "source": [ "px.histogram(predicted_probs, x=\"original\", color=\"sex\", marginal=\"box\", title=\"Prediction_original\")" ] }, { "cell_type": "code", "execution_count": null, "id": "d53309dc", "metadata": {}, "outputs": [], "source": [ "px.histogram(predicted_probs, x=\"synthetic\", color=\"sex\", marginal=\"box\", title=\"Prediction_synthetic\")" ] }, { "cell_type": "code", "execution_count": null, "id": "d24db97b", "metadata": {}, "outputs": [], "source": [ "px.histogram(predicted_probs, x=\"fair synthetic\", color=\"sex\", marginal=\"box\", title=\"Prediction_fair_synthetic\")" ] }, { "cell_type": "markdown", "id": "660ef925", "metadata": {}, "source": [ "To evaluate the downstream model's predictions, we analyze the distribution of the prediction probabilities. 
To get fair predictions at any chosen classification threshold, it is crucial that the prediction distributions for males and females are comparable. This is best assessed using the box plots. For the model trained on the original data (and consequently for the one trained on the synthetic data), the probability distribution for females is shifted to the left, indicating that the model predicts high income with lower probability for females than for males. However, when the predictor is trained on the fair synthetic data, the distributions for males and females become more aligned, indicating improved fairness in predictions." ] }, { "cell_type": "markdown", "id": "804b7f3a-b107-4dac-8a19-fd2854d1ba88", "metadata": {}, "source": [ "## Conclusion\n", "\n", "As we can see, the fair synthetic data mitigates the sex bias present in the original data. Moreover, the downstream model trained on fair synthetic data produces predictions that are fair with respect to statistical parity, even when inferring from real-world, biased data.\n", "\n", "## Further Reading\n", "\n", "* For a demo within the MOSTLY AI platform, please see https://www.youtube.com/watch?v=Uxq_1t2_NCk\n", "* For theoretical background and further analysis, see https://arxiv.org/abs/2311.03000" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" } }, "nbformat": 4, "nbformat_minor": 5 }