{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "uUvsR-mWBoNS" }, "source": [ "# Synthetic Text Generation using a fast LSTM model trained from scratch \"Run\n", "\n", "In this notebook, we demonstrate how to synthesize free text columns, and will furthermore explore its quality.\n", "\n", "Note, that the usage of a GPU, with 24GB or more, is strongly recommended for running this tutorial.\n", "\n", "For further background see also [this blog post](https://mostly.ai/blog/synthetic-data-for-text-annotation/) on \"How To Scale Up Your Text Annotation Initiatives with Synthetic Text\". We will be using a trimmed down version of a dataset containing AirBnB listings in London. This dataset can be downloaded in our public data repository [here](https://github.com/mostly-ai/public-demo-data/raw/dev/airbnb/london.csv.gz). The original data was downloaded from [Inside AirBnB](http://insideairbnb.com/get-the-data)." ] }, { "cell_type": "markdown", "metadata": { "id": "MKRa93uuSqZS" }, "source": [ "## Synthesize Data via MOSTLY AI" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": "%pip install -U mostlyai # or: pip install 'mostlyai[local-gpu]'" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "# fetch original data\n", "tgt = pd.read_csv(\"https://github.com/mostly-ai/public-demo-data/raw/dev/airbnb/london.csv.gz\")\n", "tgt" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 89 }, "id": "W4IDNOIyPW7L", "outputId": "413106b9-de23-441b-ae27-f6efe45085a6", "tags": [] }, "outputs": [], "source": [ "from mostlyai.sdk import MostlyAI\n", "\n", "# initialize SDK\n", "mostly = MostlyAI()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Train a generator on the pre-processed AirBnB data\n", "config = {\n", " \"name\": \"Synthetic Text Tutorial AirBnB\",\n", " \"tables\": [\n", " {\n", " \"name\": \"airBnB\",\n", " \"data\": tgt,\n", " \"tabular_model_configuration\": {\n", " \"max_training_time\": 1, # the tabular model should anyways finish in less than 1 min\n", " },\n", " \"language_model_configuration\": {\n", " \"max_training_time\": 20, # we recommend at least ~20 minutes if training on an A10G or similar GPU\n", " },\n", " \"columns\": [\n", " {\"name\": \"host_name\", \"model_encoding_type\": \"LANGUAGE_TEXT\"},\n", " {\"name\": \"title\", \"model_encoding_type\": \"LANGUAGE_TEXT\"},\n", " {\"name\": \"property_type\", \"model_encoding_type\": \"TABULAR_CATEGORICAL\"},\n", " {\"name\": \"room_type\", \"model_encoding_type\": \"TABULAR_CATEGORICAL\"},\n", " {\"name\": \"neighbourhood\", \"model_encoding_type\": \"TABULAR_CATEGORICAL\"},\n", " {\"name\": \"price\", \"model_encoding_type\": \"TABULAR_NUMERIC_AUTO\"},\n", " ],\n", " }\n", " ],\n", "}\n", "\n", "g_airbnb = mostly.train(config=config)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# generate a synthetic dataset\n", "syn = mostly.generate(generator=g_airbnb, size=1_000).data()\n", "print(f\"Created synthetic data with {syn.shape[0]:,} records and {syn.shape[1]:,} attributes\")" ] }, { "cell_type": "markdown", "metadata": { "id": "OgvJ0XoWTHoX" }, "source": [ "## Explore Synthetic Text\n", "\n", "Show 10 randomly sampled synthetic records. Note, that you can execute the following cell multiple times, to see different samples." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "syn.sample(n=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compare this to 10 randomly sampled original records." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "tgt.sample(n=10)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "### Inspect Character Set\n", "\n", "You will note, that the character set of the synthetic data is shorter. This is due to the privacy mechanism within the MOSTLY AI platform, where very rare tokens are being removed, to prevent that their presence give away information on the existence of individual records." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Concatenate 'title' strings and remove duplicates by converting to a set, then back to a list\n", "tgt_chars = \"\".join(sorted(list(set(tgt[\"title\"].str.cat(sep=\" \")))))\n", "syn_chars = \"\".join(sorted(list(set(syn[\"title\"].str.cat(sep=\" \")))))\n", "\n", "# Display the concatenated strings and their lengths\n", "print(\"## ORIGINAL ##\\n\", tgt_chars, \"\\n\")\n", "print(\"Length of ORIGINAL characters:\", len(tgt_chars), \"\\n\")\n", "print(\"## SYNTHETIC ##\\n\", syn_chars, \"\\n\")\n", "print(\"Length of SYNTHETIC characters:\", len(syn_chars), \"\\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Inspect Character Frequency" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Get character frequencies for 'tgt'\n", "tgt_chars = tgt[\"title\"].str.split(\"\").explode()\n", "tgt_freq = tgt_chars.value_counts(normalize=True).rename_axis(\"char\").reset_index(name=\"tgt\")\n", "\n", "# Get character frequencies for 'syn'\n", "syn_chars = syn[\"title\"].str.split(\"\").explode()\n", "syn_freq = syn_chars.value_counts(normalize=True).rename_axis(\"char\").reset_index(name=\"syn\")\n", "\n", "# Merge the frequencies and sort\n", "title_char_freq = pd.merge(tgt_freq, syn_freq, on=\"char\", how=\"outer\").round(5)\n", "title_char_freq.sort_values(by=\"tgt\", ascending=False, inplace=True)\n", "\n", "# Display the frequencies\n", "title_char_freq.head(10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# Set 'char' column as the index\n", "title_char_freq_indexed = title_char_freq.set_index(\"char\")\n", "\n", "# Plot the first 50 characters using the new index\n", "ax = title_char_freq_indexed.head(50).plot.line()\n", "plt.title(\"Distribution of 50 most common characters\")\n", "\n", "# Set x-axis labels with no rotation for better readability\n", "plt.xticks(\n", " ticks=range(len(title_char_freq_indexed.head(50))), labels=title_char_freq_indexed.head(50).index, rotation=0\n", ")\n", "\n", "plt.xlabel(\"Character\")\n", "plt.ylabel(\"Frequency\")\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that Character Frequencies are perfectly retained." 
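] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To summarize the visual match in a single number, the following is a minimal sketch computing the total variation distance between the two character distributions; it only assumes the `title_char_freq` frame built above." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# total variation distance: 0 = identical distributions, 1 = completely disjoint\n", "freq = title_char_freq.fillna(0)\n", "tv_distance = 0.5 * (freq[\"tgt\"] - freq[\"syn\"]).abs().sum()\n", "print(f\"Total variation distance of character frequencies: {tv_distance:.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A value close to zero indicates that the two character distributions are nearly identical."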
] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "### Inspect Term Frequency" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "import re\n", "\n", "\n", "def sanitize(s):\n", " s = str(s).lower()\n", " s = re.sub('[\\\\,\\\\.\\\\)\\\\(\\\\!\\\\\"\\\\:\\\\/]', \" \", s)\n", " s = re.sub(\"[ ]+\", \" \", s)\n", " return s\n", "\n", "\n", "# Apply the sanitize function and split the titles into terms\n", "tgt[\"terms\"] = tgt[\"title\"].apply(lambda x: sanitize(x)).str.split(\" \")\n", "syn[\"terms\"] = syn[\"title\"].apply(lambda x: sanitize(x)).str.split(\" \")\n", "\n", "# Explode 'terms' and create a DataFrame with explicit column names before applying value_counts\n", "tgt_terms_df = tgt[\"terms\"].explode().to_frame(name=\"term\")\n", "syn_terms_df = syn[\"terms\"].explode().to_frame(name=\"term\")\n", "\n", "# Calculate the normalized value counts and reset the index\n", "tgt_freq = tgt_terms_df[\"term\"].value_counts(normalize=True).reset_index(name=\"tgt\").rename(columns={\"index\": \"term\"})\n", "syn_freq = syn_terms_df[\"term\"].value_counts(normalize=True).reset_index(name=\"syn\").rename(columns={\"index\": \"term\"})\n", "\n", "# Merge the frequencies and sort by 'tgt' in descending order\n", "title_term_freq = pd.merge(tgt_freq, syn_freq, on=\"term\", how=\"outer\").round(5)\n", "title_term_freq = title_term_freq.sort_values(by=\"tgt\", ascending=False)\n", "\n", "# Display the top and bottom rows\n", "display(title_term_freq.head(10))\n", "display(title_term_freq.head(200).tail(10))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Set 'term' column as the index\n", "title_term_freq_indexed = title_term_freq.set_index(\"term\")\n", "\n", "# Plot the first 50 terms using the new index\n", "ax = title_term_freq_indexed.head(50).plot.line()\n", "plt.title(\"Distribution of 50 most common terms\")\n", "\n", "# Set x-axis labels with a 90-degree rotation for better readability\n", "plt.xticks(\n", " ticks=range(len(title_term_freq_indexed.head(50))), labels=title_term_freq_indexed.head(50).index, rotation=90\n", ")\n", "\n", "plt.xlabel(\"Term\")\n", "plt.ylabel(\"Frequency\")\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that Term Frequencies are perfectly retained." 
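] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Beyond eyeballing the curves, we can check how much of the most frequent vocabulary the two corpora share. Below is a small sketch; it only assumes the `title_term_freq` frame from above." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# overlap between the 25 most frequent terms in original and synthetic titles\n", "top_tgt = set(title_term_freq.nlargest(25, \"tgt\")[\"term\"])\n", "top_syn = set(title_term_freq.nlargest(25, \"syn\")[\"term\"])\n", "print(f\"Top-25 term overlap: {len(top_tgt & top_syn) / 25:.0%}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A high overlap confirms that the most characteristic vocabulary of the original listings carries over to the synthetic ones."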
] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "### Inspect Term Co-occurrence" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "def calc_conditional_probability(term1, term2):\n", " # Ensure no NaN values in 'title' before applying str.contains\n", " tgt_beds = tgt[\"title\"].fillna(\"\").str.lower().str.contains(term1)\n", " syn_beds = syn[\"title\"].fillna(\"\").str.lower().str.contains(term1)\n", "\n", " # Use the boolean Series to filter 'title' containing term1 and then check for term2\n", " tgt_beds_double = tgt[\"title\"][tgt_beds].str.lower().str.contains(term2).mean()\n", " syn_beds_double = syn[\"title\"][syn_beds].str.lower().str.contains(term2).mean()\n", "\n", " print(f\"{tgt_beds_double:.0%} of original Listings, that contain `{term1}`, also contain `{term2}`\")\n", " print(f\"{syn_beds_double:.0%} of synthetic Listings, that contain `{term1}`, also contain `{term2}`\")\n", " print(\"\")\n", "\n", "\n", "calc_conditional_probability(\"bed\", \"double\")\n", "calc_conditional_probability(\"bed\", \"king\")\n", "calc_conditional_probability(\"heart\", \"london\")\n", "calc_conditional_probability(\"london\", \"heart\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that Term Co-occurrences are almost perfectly retained.\n", "\n", "Now you might be asking yourself: if all of these characteristics are maintained, what are the chances that we'll end up with exact matches, i.e. synthetic records with the exact same `title` value as a record in the original dataset? Or even a synthetic record with the exact same values for all the columns?\n", "\n", "Let's start by trying to find an exact match for 1 specific synthetic `title` value:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# find exact match for 1 specific synthetic title value. Copy a `title` value from a synthetic record into the `title_value` field below and run the cell to find an exact match in the original dataset\n", "title_value = \"Airy large double room\"\n", "tgt.loc[tgt[\"title\"].str.contains(title_value, case=False, na=False)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Depending on your chosen value, you may or may not find an exact match. This row-by-row validation process doesn't indicate very much and, more importantly, doesn't scale very well to the 71K rows in the dataset." ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "### Inspect Privacy via Exact Matches\n", "\n", "Let's perform a simplified check for privacy, by looking for exact matches between the synthetic and the original.\n", "\n", "For that we first split the original data into two equally-sized sets, and measure the number of matches between those two sets." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "n = int(tgt.shape[0] / 2)\n", "pd.merge(tgt[[\"title\"]][:n].drop_duplicates(), tgt[[\"title\"]][n:].drop_duplicates())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we take an equally-sized subset of the synthetic data, and again measure the number of matches between that set and the original data." 
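, { "cell_type": "markdown", "metadata": {}, "source": [ "The merge above displays the matching titles themselves. To keep the number of matches around as a baseline for comparison, here is a small sketch; it assumes the split point `n` from the previous cell." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# holdout baseline: exact title matches between the two original halves\n", "n_within_tgt = len(pd.merge(tgt[[\"title\"]][:n].drop_duplicates(), tgt[[\"title\"]][n:].drop_duplicates()))\n", "print(f\"Exact title matches within the original data: {n_within_tgt:,}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we take an equally-sized subset of the synthetic data, and again measure the number of matches between that set and the original data."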
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "pd.merge(tgt[[\"title\"]][:n].drop_duplicates(), syn[[\"title\"]][:n].drop_duplicates())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that exact matches between original and synthetic data can occur. However, they occur only for the most commonly used descriptions, and they do not occur more often than they occur in the original data itself.\n", "\n", "Thus, it's important to note, that matchinig values or matching complete records are by themselves not a sign of privacy leak. They are only an issue if they occur more frequently than we would expect based on the original dataset. Also note that removing those exact matches via post-processing would have a detrimental contrary effect. The absence of a value like \"Lovely single room\" in a sufficiently large synthetic text corpus would in this case actually give away the fact that this sentence was present in the original. See [[1](#refs)] respectively [[2](#refs)] for more background info on this aspect." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Analyze Price vs. Text correlation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "tgt_term_price = tgt[[\"terms\", \"price\"]].explode(column=\"terms\").groupby(\"terms\")[\"price\"].median()\n", "syn_term_price = syn[[\"terms\", \"price\"]].explode(column=\"terms\").groupby(\"terms\")[\"price\"].median()\n", "\n", "\n", "def print_term_price(term):\n", " print(f\"Median Price of original Listings, that contain `{term}`: ${tgt_term_price[term]:.0f}\")\n", " print(f\"Median Price of synthetic Listings, that contain `{term}`: ${syn_term_price[term]:.0f}\")\n", " print(\"\")\n", "\n", "\n", "print_term_price(\"luxury\")\n", "print_term_price(\"stylish\")\n", "print_term_price(\"cozy\")\n", "print_term_price(\"small\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that correlations between Term occurence and the price per night, are also very well retained." ] }, { "cell_type": "markdown", "metadata": { "id": "eQMOiU1Bv6W4" }, "source": [ "## Conclusion\n", "\n", "This tutorial demonstrated how synthetic text can be generated wihtin the context of an otherwise structured dataset. We analyzed the generated texts, and validated that characters and terms occur with the same frequency, while exact matches do not occur anymore likely than within the actual text itself.\n", "\n", "This feature thus allows to retain valuable statistical insights, typically burried away in free text columns, that remain inaccessible due to their privacy sensitive nature." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Further exercises\n", "\n", "In addition to walking through the above instructions, we suggest..\n", "* analyzing further correlations, also for `host_name`\n", "* using a different generation mood, eg. conservative sampling\n", "* using a different dataset, eg. the Austrian First Name [[3](#refs)]" ] }, { "cell_type": "markdown", "metadata": { "id": "UW5ntiUB18yP" }, "source": [ "## References\n", "\n", "1. https://github.com/mostly-ai/public-demo-data/blob/dev/firstnames_at/firstnames_at.csv.gz\n", "1. https://www.frontiersin.org/articles/10.3389/fdata.2021.679939/full\n", "1. 
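, { "cell_type": "markdown", "metadata": {}, "source": [ "As a possible starting point for the first exercise, the following sketch compares the relative frequencies of the most common host names between the original and the synthetic data; it only assumes the `tgt` and `syn` frames from above." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# relative frequencies of host names in original vs. synthetic data\n", "host_freq = pd.concat(\n", "    [\n", "        tgt[\"host_name\"].value_counts(normalize=True).rename(\"tgt\"),\n", "        syn[\"host_name\"].value_counts(normalize=True).rename(\"syn\"),\n", "    ],\n", "    axis=1,\n", ")\n", "\n", "# show the 10 most common original host names side by side\n", "host_freq.sort_values(\"tgt\", ascending=False).head(10).round(4)" ] }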
, { "cell_type": "markdown", "metadata": { "id": "UW5ntiUB18yP" }, "source": [ "## References <a id=\"refs\"></a>\n", "\n", "1. https://www.frontiersin.org/articles/10.3389/fdata.2021.679939/full\n", "1. https://mostly.ai/blog/truly-anonymous-synthetic-data-legal-definitions-part-ii/\n", "1. https://github.com/mostly-ai/public-demo-data/blob/dev/firstnames_at/firstnames_at.csv.gz" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" } }, "nbformat": 4, "nbformat_minor": 4 }