{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "uUvsR-mWBoNS"
},
"source": [
"# Size vs. Accuracy Trade-Off
\n",
"\n",
"In this exercise, we are going to explore the relationship between the number of training samples, that are being used for the synthesis, and the corresponding accuracy of the generated synthetic data. We expect to see a higher accuracy for an increasing number of training samples. But along with a larger number of training samples, we will also see an increase in computational effort, i.e. overall runtime.\n",
"\n",
"
\n",
"\n",
"Note, that we shall not expect synthetic data to perfectly match the original data. This would only be satisfied by a copy of the data, which obviously would neither satisfy any privacy requirements nor would provide any novel samples. That being said, we shall expect that due to sampling variance the synthetic data can deviate. Ideally, just as much, and not more than the deviation that we would observe by analyzing an actual holdout data."
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"## Synthesize Data via MOSTLY AI\n",
"\n",
"For this tutorial, we will be using the same UCI Adult Income [[1](#refs)] dataset, that was used in the Train-Synthetic-Test-Real tutorial. Thus, we have in total 48,842 records across 15 attributes, and will be using up to 39,073 (=80%) of those records for the creation of Generators.\n",
"\n",
"The following code creates different Generators, each time with a different number of maximum training samples. E.g. 100, 400, 1,600, 6,400, 25,600. Feel free to adjust these numbers as you are experimenting. Subsequently different Synthetic Datasets based on the Generators are created."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -U mostlyai # or: pip install -U 'mostlyai[local]'\n",
"%pip install scikit-learn seaborn lightgbm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# fetch original data\n",
"df = pd.read_csv(\"https://github.com/mostly-ai/public-demo-data/raw/dev/census/census.csv.gz\")\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# split into training and validation\n",
"df_trn, df_hol = train_test_split(df, test_size=0.2, random_state=1)\n",
"\n",
"print(f\"training data with {df_trn.shape[0]:,} records and {df_trn.shape[1]} attributes\")\n",
"print(f\"holdout data with {df_hol.shape[0]:,} records and {df_hol.shape[1]} attributes\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mostlyai.sdk import MostlyAI\n",
"\n",
"# initialize SDK\n",
"mostly = MostlyAI()\n",
"\n",
"# create Generators with different sample sizes\n",
"g_200 = mostly.train(data=df_trn.sample(200), name=\"census_200\")\n",
"g_400 = mostly.train(data=df_trn.sample(400), name=\"census_400\")\n",
"g_1600 = mostly.train(data=df_trn.sample(1600), name=\"census_1600\")\n",
"g_6400 = mostly.train(data=df_trn.sample(6400), name=\"census_6400\")\n",
"g_25600 = mostly.train(data=df_trn.sample(25600), name=\"census_25600\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate synthetic data\n",
"synthetic_data = {\n",
" \"syn_200\": mostly.probe(g_200, size=10_000),\n",
" \"syn_400\": mostly.probe(g_400, size=10_000),\n",
" \"syn_1600\": mostly.probe(g_1600, size=10_000),\n",
" \"syn_6400\": mostly.probe(g_6400, size=10_000),\n",
" \"syn_25600\": mostly.probe(g_25600, size=10_000),\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now go to the UI of MOSTLY AI, look at the created Generators and take notes of the reported runtime of each training step, and update the following DataFrame accordingly. The overall accuracy of the created Generators is loaded automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"results = pd.DataFrame(\n",
" [\n",
" {\"samples\": 200, \"accuracy\": g_200.accuracy, \"trainingtime\": 2},\n",
" {\"samples\": 400, \"accuracy\": g_400.accuracy, \"trainingtime\": 5},\n",
" {\"samples\": 1600, \"accuracy\": g_1600.accuracy, \"trainingtime\": 24},\n",
" {\"samples\": 6400, \"accuracy\": g_6400.accuracy, \"trainingtime\": 56},\n",
" {\"samples\": 25600, \"accuracy\": g_25600.accuracy, \"trainingtime\": 73},\n",
" ]\n",
")\n",
"results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"\n",
"sns.catplot(data=results, y=\"accuracy\", x=\"samples\", kind=\"point\", color=\"black\")\n",
"plt.xticks(rotation=45)\n",
"plt.xlabel(\"\")\n",
"plt.title(\"QA Report - Overall Accuracy\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"## Explore Synthetic Data\n",
"\n",
"Show 3 randomly sampled synthetic records for each of the datasets. Note, that you can execute the following cell multiple times, to see different samples. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"for generator, df in synthetic_data.items():\n",
" print(\"===\", generator, \"===\")\n",
" display(df.sample(n=3))"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"## Quality Assessment\n",
"\n",
"Concatenate all datasets together to ease comparions across these."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# combine synthetics\n",
"df = pd.concat([d.assign(split=k) for k, d in synthetic_data.items()], axis=0)\n",
"df[\"split\"] = pd.Categorical(df[\"split\"], categories=df[\"split\"].unique())\n",
"df.insert(0, \"split\", df.pop(\"split\"))\n",
"# combine synthetics and original\n",
"dataset = synthetic_data | {\"training\": df_trn, \"holdout\": df_hol}\n",
"df_all = pd.concat([d.assign(split=k) for k, d in dataset.items()], axis=0)\n",
"df_all[\"split\"] = pd.Categorical(df_all[\"split\"], categories=df_all[\"split\"].unique())\n",
"df_all.insert(0, \"split\", df_all.pop(\"split\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compare single statistic"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"The more training samples have been used for the synthesis, the closer the synthetic distributions are expected to be to the original ones.\n",
"\n",
"Note, that we can also see deviations within statistics between the target and the holdout data. This is expected due to the sampling variance. The smaller the dataset, the larger the sampling variance will be."
]
},
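{
"cell_type": "markdown",
"metadata": {},
"source": [
"To build intuition for this, here is a quick sketch, using only the training data already loaded above: we repeatedly draw subsamples of different sizes and measure how much the mean of `hours_per_week` fluctuates across draws. The spread shrinks as the sample size grows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# standard deviation of the subsample mean of `hours_per_week`,\n",
"# across 20 random draws per sample size\n",
"sizes = [200, 400, 1600, 6400, 25600]\n",
"spread = {\n",
"    n: np.std([df_trn[\"hours_per_week\"].sample(n, random_state=s).mean() for s in range(20)])\n",
"    for n in sizes\n",
"}\n",
"pd.Series(spread, name=\"std of subsample means\").round(2)"
]
},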
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Average number of Hours-Per-Week, split by Gender"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"stats = (\n",
" df_all.groupby([\"split\", \"sex\"], observed=True)[\"hours_per_week\"].mean().round(1).to_frame().reset_index(drop=False)\n",
")\n",
"stats = stats.pivot_table(index=\"split\", columns=[\"sex\"], observed=True).reset_index(drop=False)\n",
"stats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Average Age, split by Marital Status"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"stats = (\n",
" df_all.groupby([\"split\", \"marital_status\"], observed=True)[\"age\"].mean().round().to_frame().reset_index(drop=False)\n",
")\n",
"stats = stats.loc[~stats[\"marital_status\"].isin([\"_RARE_\", \"Married-AF-spouse\", \"Married-spouse-absent\", \"Separated\"])]\n",
"stats = stats.pivot_table(index=\"split\", columns=\"marital_status\", values=\"age\", observed=True).reset_index()\n",
"stats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Age distribution, split by Income"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"sns.catplot(data=df_all, x=\"age\", y=\"split\", hue=\"income\", kind=\"violin\", split=True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check rule adherence\n",
"\n",
"The original data has a 1:1 relationship between `education` and `education_num`. Let's check in how many cases the generated synthetic data has correctly retained that specific rule between these two columns."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# display unique combinations of `education` and `education_num`\n",
"df_trn[[\"education\", \"education_num\"]].drop_duplicates().sort_values(\"education_num\").reset_index(drop=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Convert `education` to Categorical with proper sort order\n",
"df[\"education\"] = pd.Categorical(df[\"education\"], categories=df_trn.sort_values(\"education_num\")[\"education\"].unique())\n",
"\n",
"# Calculate the correct match, explicitly excluding the group keys from the apply operation\n",
"stats = (\n",
" df.groupby(\"split\", observed=True)\n",
" .apply(lambda x: ((x[\"education\"].cat.codes + 1) == x[\"education_num\"]).mean())\n",
" .reset_index(name=\"matches\")\n",
")\n",
"\n",
"stats"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"sns.catplot(data=stats, y=\"matches\", x=\"split\", kind=\"point\", color=\"black\")\n",
"plt.xticks(rotation=45)\n",
"plt.xlabel(\"\")\n",
"plt.title(\"Share of Matches\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OgvJ0XoWTHoX"
},
"source": [
"### Compare ML performance via TSTR\n",
"\n",
"Let's perform a Train-Synthetic-Test-Real evaluation via a downstream LightGBM classifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Rl6-YXB_e0Ac"
},
"outputs": [],
"source": [
"import lightgbm as lgb\n",
"from lightgbm import early_stopping\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import roc_auc_score\n",
"\n",
"target_col = \"income\"\n",
"target_val = \">50K\"\n",
"\n",
"\n",
"# prepare data, and split into features `X` and target `y`\n",
"def prepare_xy(df: pd.DataFrame):\n",
" y = (df[target_col] == target_val).astype(int)\n",
" str_cols = [col for col in df.select_dtypes([\"object\", \"string\"]).columns if col != target_col]\n",
" for col in str_cols:\n",
" df[col] = pd.Categorical(df[col])\n",
" cat_cols = [col for col in df.select_dtypes(\"category\").columns if col != target_col]\n",
" num_cols = [col for col in df.select_dtypes(\"number\").columns if col != target_col]\n",
" for col in num_cols:\n",
" df[col] = df[col].astype(\"float\")\n",
" X = df[cat_cols + num_cols]\n",
" return X, y\n",
"\n",
"\n",
"# train ML model with early stopping\n",
"def train_model(X, y):\n",
" cat_cols = list(X.select_dtypes(\"category\").columns)\n",
" X_trn, X_val, y_trn, y_val = train_test_split(X, y, test_size=0.2, random_state=1)\n",
" ds_trn = lgb.Dataset(X_trn, label=y_trn, categorical_feature=cat_cols, free_raw_data=False)\n",
" ds_val = lgb.Dataset(X_val, label=y_val, categorical_feature=cat_cols, free_raw_data=False)\n",
" model = lgb.train(\n",
" params={\"verbose\": -1, \"metric\": \"auc\", \"objective\": \"binary\"},\n",
" train_set=ds_trn,\n",
" valid_sets=[ds_val],\n",
" callbacks=[early_stopping(5)],\n",
" )\n",
" return model\n",
"\n",
"\n",
"# apply ML Model to some holdout data, report key metrics, and visualize scores\n",
"def evaluate_model(model, hol):\n",
" X_hol, y_hol = prepare_xy(hol)\n",
" probs = model.predict(X_hol)\n",
" auc = roc_auc_score(y_hol, probs)\n",
" return auc\n",
"\n",
"\n",
"def train_and_evaluate(df):\n",
" X, y = prepare_xy(df)\n",
" model = train_model(X, y)\n",
" auc = evaluate_model(model, df_hol)\n",
" return auc\n",
"\n",
"\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZtQiGZdyCB72",
"outputId": "fb0fa7ad-c3f3-4d8a-d907-a7248e770e6d"
},
"outputs": [],
"source": [
"aucs = {k: train_and_evaluate(df) for k, df in synthetic_data.items()}\n",
"aucs = pd.Series(aucs).round(3).to_frame(\"auc\").reset_index()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "14CMIhcvgQ77",
"outputId": "f8f46533-b092-4ec7-8ecd-c2f5c51961ec"
},
"outputs": [],
"source": [
"sns.catplot(data=aucs, y=\"auc\", x=\"index\", kind=\"point\", color=\"black\")\n",
"plt.xticks(rotation=45)\n",
"plt.xlabel(\"\")\n",
"plt.title(\"Predictive Performance (AUC) on Holdout\")\n",
"plt.show()"
]
},
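{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an additional point of reference, we can also train the same downstream classifier on the actual training data, i.e. Train-Real-Test-Real. Its AUC indicates the ceiling that the synthetic variants should approach as the number of training samples grows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# baseline: train on the real training data, evaluate on the same holdout\n",
"auc_real = train_and_evaluate(df_trn.copy())\n",
"print(f\"AUC when trained on real data: {auc_real:.3f}\")"
]
},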
{
"cell_type": "markdown",
"metadata": {
"id": "UMGNussThvys"
},
"source": [
"## Conclusion\n",
"\n",
"For the given dataset and the given synthesizer we can indeed observe an increase in synthetic data quality with a growing number of training samples. This can be measured with respect to accuracy, as well as ML utility.\n",
"\n",
"As we can also observe, is that a holdout dataset will exhibit deviations from the training data due to the sampling noise as well. With the holdout data being actual data, that hasn't been seen before, it serves us as a north star in terms of maximum acchievable accuracy for synthetic data. See our paper on this subject [[2](#refs)].\n",
"\n",
"## Further exercises\n",
"\n",
"In addition to walking through the above instructions, we suggest..\n",
"* to limit model training to a few epochs, e.g. by setting the maximum number of epochs to 1 or 5 and study its impact on runtime and quality.\n",
"* to synthesize with different model_sizes: Small, Medium and Large, and study its impact on runtime and quality.\n",
"* to synthesize with the same settings several times, and with that study the variability in quality across several runs.\n",
"* to calculate and compare your own statistics, and then compare the deviations between synthetic and training. The deviations between holdout and training can serve as a benchmark ."
]
},
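{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following cell is a minimal sketch for the first two exercises. It assumes that `mostly.train` accepts a configuration dict with a `tabular_model_configuration` section exposing `max_epochs` and `model_size`; these parameter names may differ across SDK versions, so please verify them against the MOSTLY AI SDK documentation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sketch: limit training to 1 epoch and use a Small model\n",
"# (assumed configuration keys; verify against the SDK docs)\n",
"g_quick = mostly.train(\n",
"    config={\n",
"        \"name\": \"census_quick\",\n",
"        \"tables\": [\n",
"            {\n",
"                \"name\": \"census\",\n",
"                \"data\": df_trn.sample(1600),\n",
"                \"tabular_model_configuration\": {\"max_epochs\": 1, \"model_size\": \"S\"},\n",
"            }\n",
"        ],\n",
"    }\n",
")"
]
},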
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References\n",
"\n",
"1. https://archive.ics.uci.edu/ml/datasets/adult\n",
"1. https://www.frontiersin.org/articles/10.3389/fdata.2021.679939/full"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}