{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Embedding Wikipedia articles for search\n", "\n", "This notebook shows how we prepared a dataset of Wikipedia articles for search, used in [Question_answering_using_embeddings.ipynb](Question_answering_using_embeddings.ipynb).\n", "\n", "Procedure:\n", "\n", "0. Prerequisites: Import libraries, set API key (if needed)\n", "1. Collect: We download a few hundred Wikipedia articles about the 2022 Olympics\n", "2. Chunk: Documents are split into short, semi-self-contained sections to be embedded\n", "3. Embed: Each section is embedded with the OpenAI API\n", "4. Store: Embeddings are saved in a CSV file (for large datasets, use a vector database)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## 0. Prerequisites\n", "\n", "### Import libraries" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# imports\n", "import mwclient # for downloading example Wikipedia articles\n", "import mwparserfromhell # for splitting Wikipedia articles into sections\n", "from openai import OpenAI # for generating embeddings\n", "import os # for environment variables\n", "import pandas as pd # for DataFrames to store article sections and embeddings\n", "import re # for cutting links out of Wikipedia articles\n", "import tiktoken # for counting tokens\n", "\n", "client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"\"))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Install any missing libraries with `pip install` in your terminal. E.g.,\n", "\n", "```zsh\n", "pip install openai\n", "```\n", "\n", "(You can also do this in a notebook cell with `!pip install openai`.)\n", "\n", "If you install any libraries, be sure to restart the notebook kernel." 
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Set API key (if needed)\n", "\n", "Note that the OpenAI library will try to read your API key from the `OPENAI_API_KEY` environment variable. If you haven't already, set this environment variable by following [these instructions](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety)." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Collect documents\n", "\n", "In this example, we'll download a few hundred Wikipedia articles related to the 2022 Winter Olympics." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 179 article titles in Category:2022 Winter Olympics.\n" ] } ], "source": [ "# get Wikipedia pages about the 2022 Winter Olympics\n", "\n", "CATEGORY_TITLE = \"Category:2022 Winter Olympics\"\n", "WIKI_SITE = \"en.wikipedia.org\"\n", "\n", "\n", "def titles_from_category(\n", "    category: mwclient.listing.Category, max_depth: int\n", ") -> set[str]:\n", "    \"\"\"Return a set of page titles in a given Wiki category and its subcategories.\"\"\"\n", "    titles = set()\n", "    for cm in category.members():\n", "        if type(cm) == mwclient.page.Page:\n", "            # ^type() used instead of isinstance() to catch match w/ no inheritance\n", "            titles.add(cm.name)\n", "        elif isinstance(cm, mwclient.listing.Category) and max_depth > 0:\n", "            deeper_titles = titles_from_category(cm, max_depth=max_depth - 1)\n", "            titles.update(deeper_titles)\n", "    return titles\n", "\n", "\n", "site = mwclient.Site(WIKI_SITE)\n", "category_page = site.pages[CATEGORY_TITLE]\n", "titles = titles_from_category(category_page, max_depth=1)\n", "# ^note: max_depth=1 means we go one level deep in the category tree\n", "print(f\"Found {len(titles)} article titles in {CATEGORY_TITLE}.\")\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 
Chunk documents\n", "\n", "Now that we have our reference documents, we need to prepare them for search.\n", "\n", "Because GPT can only read a limited amount of text at once, we'll split each document into chunks short enough to be read.\n", "\n", "For this specific example on Wikipedia articles, we'll:\n", "- Discard less relevant-looking sections like External Links and Footnotes\n", "- Clean up the text by removing reference tags (e.g., <ref>), whitespace, and super short sections\n", "- Split each article into sections\n", "- Prepend titles and subtitles to each section's text, to help GPT understand the context\n", "- If a section is long (say, > 1,600 tokens), we'll recursively split it into smaller sections, trying to split along semantic boundaries like paragraphs" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# define functions to split Wikipedia pages into sections\n", "\n", "SECTIONS_TO_IGNORE = [\n", "    \"See also\",\n", "    \"References\",\n", "    \"External links\",\n", "    \"Further reading\",\n", "    \"Footnotes\",\n", "    \"Bibliography\",\n", "    \"Sources\",\n", "    \"Citations\",\n", "    \"Literature\",\n", "    \"Notes and references\",\n", "    \"Photo gallery\",\n", "    \"Works cited\",\n", "    \"Photos\",\n", "    \"Gallery\",\n", "    \"Notes\",\n", "    \"References and sources\",\n", "    \"References and notes\",\n", "]\n", "\n", "\n", "def all_subsections_from_section(\n", "    section: mwparserfromhell.wikicode.Wikicode,\n", "    parent_titles: list[str],\n", "    sections_to_ignore: set[str],\n", ") -> list[tuple[list[str], str]]:\n", "    \"\"\"\n", "    From a Wikipedia section, return a flattened list of all nested subsections.\n", "    Each subsection is a tuple, where:\n", "        - the first element is a list of parent subtitles, starting with the page title\n", "        - the second element is the text of the subsection (but not any children)\n", "    \"\"\"\n", "    headings = [str(h) for h in section.filter_headings()]\n", "    title = 
headings[0]\n", "    if title.strip(\"=\" + \" \") in sections_to_ignore:\n", "        # ^wiki headings are wrapped like \"== Heading ==\"\n", "        return []\n", "    titles = parent_titles + [title]\n", "    full_text = str(section)\n", "    section_text = full_text.split(title)[1]\n", "    if len(headings) == 1:\n", "        return [(titles, section_text)]\n", "    else:\n", "        first_subtitle = headings[1]\n", "        section_text = section_text.split(first_subtitle)[0]\n", "        results = [(titles, section_text)]\n", "        for subsection in section.get_sections(levels=[len(titles) + 1]):\n", "            results.extend(all_subsections_from_section(subsection, titles, sections_to_ignore))\n", "        return results\n", "\n", "\n", "def all_subsections_from_title(\n", "    title: str,\n", "    sections_to_ignore: set[str] = SECTIONS_TO_IGNORE,\n", "    site_name: str = WIKI_SITE,\n", ") -> list[tuple[list[str], str]]:\n", "    \"\"\"From a Wikipedia page title, return a flattened list of all nested subsections.\n", "    Each subsection is a tuple, where:\n", "        - the first element is a list of parent subtitles, starting with the page title\n", "        - the second element is the text of the subsection (but not any children)\n", "    \"\"\"\n", "    site = mwclient.Site(site_name)\n", "    page = site.pages[title]\n", "    text = page.text()\n", "    parsed_text = mwparserfromhell.parse(text)\n", "    headings = [str(h) for h in parsed_text.filter_headings()]\n", "    if headings:\n", "        summary_text = str(parsed_text).split(headings[0])[0]\n", "    else:\n", "        summary_text = str(parsed_text)\n", "    results = [([title], summary_text)]\n", "    for subsection in parsed_text.get_sections(levels=[2]):\n", "        results.extend(all_subsections_from_section(subsection, [title], sections_to_ignore))\n", "    return results\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 1838 sections in 179 pages.\n" ] } ], "source": [ "# split pages into sections\n", "# may take ~1 minute per 100 articles\n", 
"wikipedia_sections = []\n", "for title in titles:\n", "    wikipedia_sections.extend(all_subsections_from_title(title))\n", "print(f\"Found {len(wikipedia_sections)} sections in {len(titles)} pages.\")\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Filtered out 89 sections, leaving 1749 sections.\n" ] } ], "source": [ "# clean text\n", "def clean_section(section: tuple[list[str], str]) -> tuple[list[str], str]:\n", "    \"\"\"\n", "    Return a cleaned up section with:\n", "        - <ref>xyz</ref> patterns removed\n", "        - leading/trailing whitespace removed\n", "    \"\"\"\n", "    titles, text = section\n", "    text = re.sub(r\"<ref.*?</ref>\", \"\", text)\n", "    text = text.strip()\n", "    return (titles, text)\n", "\n", "\n", "wikipedia_sections = [clean_section(ws) for ws in wikipedia_sections]\n", "\n", "# filter out short/blank sections\n", "def keep_section(section: tuple[list[str], str]) -> bool:\n", "    \"\"\"Return True if the section should be kept, False otherwise.\"\"\"\n", "    titles, text = section\n", "    if len(text) < 16:\n", "        return False\n", "    else:\n", "        return True\n", "\n", "\n", "original_num_sections = len(wikipedia_sections)\n", "wikipedia_sections = [ws for ws in wikipedia_sections if keep_section(ws)]\n", "print(f\"Filtered out {original_num_sections-len(wikipedia_sections)} sections, leaving {len(wikipedia_sections)} sections.\")\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Concerns and controversies at the 2022 Winter Olympics']\n" ] }, { "data": { "text/plain": [ "'{{Short description|Overview of concerns and controversies surrounding the Ga...'" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "['Concerns and controversies at the 2022 Winter Olympics', '==Criticism of host selection==']\n" ] }, { "data": { "text/plain": [ "'American 
sportscaster [[Bob Costas]] criticized the [[International Olympic C...'" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "['Concerns and controversies at the 2022 Winter Olympics', '==Organizing concerns and controversies==', '===Cost and climate===']\n" ] }, { "data": { "text/plain": [ "'Several cities withdrew their applications during [[Bids for the 2022 Winter ...'" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "['Concerns and controversies at the 2022 Winter Olympics', '==Organizing concerns and controversies==', '===Promotional song===']\n" ] }, { "data": { "text/plain": [ "'Some commentators alleged that one of the early promotional songs for the [[2...'" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "['Concerns and controversies at the 2022 Winter Olympics', '== Diplomatic boycotts or non-attendance ==']\n" ] }, { "data": { "text/plain": [ "'
\\n[[File:2022 Winter Olympics (Beijing) diplomatic b...'" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# print example data\n", "for ws in wikipedia_sections[:5]:\n", "    print(ws[0])\n", "    display(ws[1][:77] + \"...\")\n", "    print()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we'll recursively split long sections into smaller sections.\n", "\n", "There's no perfect recipe for splitting text into sections.\n", "\n", "Some tradeoffs include:\n", "- Longer sections may be better for questions that require more context\n", "- Longer sections may be worse for retrieval, as they may have more topics muddled together\n", "- Shorter sections are better for reducing costs (which are proportional to the number of tokens)\n", "- Shorter sections allow more sections to be retrieved, which may help with recall\n", "- Overlapping sections may help prevent answers from being cut by section boundaries\n", "\n", "Here, we'll use a simple approach and limit sections to 1,600 tokens each, recursively halving any sections that are too long. To avoid cutting in the middle of useful sentences, we'll split along paragraph boundaries when possible." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "GPT_MODEL = \"gpt-4o-mini\"  # only matters insofar as it selects which tokenizer to use\n", "\n", "\n", "def num_tokens(text: str, model: str = GPT_MODEL) -> int:\n", "    \"\"\"Return the number of tokens in a string.\"\"\"\n", "    encoding = tiktoken.encoding_for_model(model)\n", "    return len(encoding.encode(text))\n", "\n", "\n", "def halved_by_delimiter(string: str, delimiter: str = \"\\n\") -> list[str]:\n", "    \"\"\"Split a string in two, on a delimiter, trying to balance tokens on each side.\"\"\"\n", "    chunks = string.split(delimiter)\n", "    if len(chunks) == 1:\n", "        return [string, \"\"]  # no delimiter found\n", "    elif len(chunks) == 2:\n", "        return chunks  # no need to search for halfway point\n", "    else:\n", "        total_tokens = num_tokens(string)\n", "        halfway = total_tokens // 2\n", "        best_diff = halfway\n", "        for i, chunk in enumerate(chunks):\n", "            left = delimiter.join(chunks[: i + 1])\n", "            left_tokens = num_tokens(left)\n", "            diff = abs(halfway - left_tokens)\n", "            if diff >= best_diff:\n", "                break\n", "            else:\n", "                best_diff = diff\n", "        left = delimiter.join(chunks[:i])\n", "        right = delimiter.join(chunks[i:])\n", "        return [left, right]\n", "\n", "\n", "def truncated_string(\n", "    string: str,\n", "    model: str,\n", "    max_tokens: int,\n", "    print_warning: bool = True,\n", ") -> str:\n", "    \"\"\"Truncate a string to a maximum number of tokens.\"\"\"\n", "    encoding = tiktoken.encoding_for_model(model)\n", "    encoded_string = encoding.encode(string)\n", "    truncated_string = encoding.decode(encoded_string[:max_tokens])\n", "    if print_warning and len(encoded_string) > max_tokens:\n", "        print(f\"Warning: Truncated string from {len(encoded_string)} tokens to {max_tokens} tokens.\")\n", "    return truncated_string\n", "\n", "\n", "def split_strings_from_subsection(\n", "    subsection: tuple[list[str], str],\n", "    max_tokens: int = 1000,\n", "    model: str = GPT_MODEL,\n", "    
max_recursion: int = 5,\n", ") -> list[str]:\n", "    \"\"\"\n", "    Split a subsection into a list of strings, each with no more than max_tokens.\n", "    The input subsection is a tuple of parent titles [H1, H2, ...] and text (str).\n", "    \"\"\"\n", "    titles, text = subsection\n", "    string = \"\\n\\n\".join(titles + [text])\n", "    num_tokens_in_string = num_tokens(string, model=model)\n", "    # if length is fine, return string\n", "    if num_tokens_in_string <= max_tokens:\n", "        return [string]\n", "    # if recursion hasn't found a split after X iterations, just truncate\n", "    elif max_recursion == 0:\n", "        return [truncated_string(string, model=model, max_tokens=max_tokens)]\n", "    # otherwise, split in half and recurse\n", "    else:\n", "        titles, text = subsection\n", "        for delimiter in [\"\\n\\n\", \"\\n\", \". \"]:\n", "            left, right = halved_by_delimiter(text, delimiter=delimiter)\n", "            if left == \"\" or right == \"\":\n", "                # if either half is empty, retry with a more fine-grained delimiter\n", "                continue\n", "            else:\n", "                # recurse on each half\n", "                results = []\n", "                for half in [left, right]:\n", "                    half_subsection = (titles, half)\n", "                    half_strings = split_strings_from_subsection(\n", "                        half_subsection,\n", "                        max_tokens=max_tokens,\n", "                        model=model,\n", "                        max_recursion=max_recursion - 1,\n", "                    )\n", "                    results.extend(half_strings)\n", "                return results\n", "    # otherwise no split was found, so just truncate (should be very rare)\n", "    return [truncated_string(string, model=model, max_tokens=max_tokens)]\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1749 Wikipedia sections split into 2052 strings.\n" ] } ], "source": [ "# split sections into chunks\n", "MAX_TOKENS = 1600\n", "wikipedia_strings = []\n", "for section in wikipedia_sections:\n", "    wikipedia_strings.extend(split_strings_from_subsection(section, max_tokens=MAX_TOKENS))\n", "\n", "print(f\"{len(wikipedia_sections)} Wikipedia sections 
split into {len(wikipedia_strings)} strings.\")\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Concerns and controversies at the 2022 Winter Olympics\n", "\n", "==Criticism of host selection==\n", "\n", "American sportscaster [[Bob Costas]] criticized the [[International Olympic Committee]]'s (IOC) decision to award the games to China saying \"The IOC deserves all of the disdain and disgust that comes their way for going back to China yet again\" referencing China's human rights record.\n", "\n", "After winning two gold medals and returning to his home country of Sweden skater [[Nils van der Poel]] criticized the IOC's selection of China as the host saying \"I think it is extremely irresponsible to give it to a country that violates human rights as blatantly as the Chinese regime is doing.\" He had declined to criticize China before leaving for the games saying \"I don't think it would be particularly wise for me to criticize the system I'm about to transition to, if I want to live a long and productive life.\"\n" ] } ], "source": [ "# print example data\n", "print(wikipedia_strings[1])\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 
Embed document chunks\n", "\n", "Now that we've split our library into shorter self-contained strings, we can compute embeddings for each.\n", "\n", "(For large embedding jobs, use a script like [api_request_parallel_processor.py](https://github.com/openai/openai-cookbook/blob/main/examples/api_request_parallel_processor.py) to parallelize requests while throttling to stay under rate limits.)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Batch 0 to 999\n", "Batch 1000 to 1999\n", "Batch 2000 to 2999\n" ] } ], "source": [ "EMBEDDING_MODEL = \"text-embedding-3-small\"\n", "BATCH_SIZE = 1000  # you can submit up to 2048 embedding inputs per request\n", "\n", "embeddings = []\n", "for batch_start in range(0, len(wikipedia_strings), BATCH_SIZE):\n", "    batch_end = batch_start + BATCH_SIZE\n", "    batch = wikipedia_strings[batch_start:batch_end]\n", "    print(f\"Batch {batch_start} to {batch_end-1}\")\n", "    response = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)\n", "    for i, be in enumerate(response.data):\n", "        assert i == be.index  # double check embeddings are in same order as input\n", "    batch_embeddings = [e.embedding for e in response.data]\n", "    embeddings.extend(batch_embeddings)\n", "\n", "df = pd.DataFrame({\"text\": wikipedia_strings, \"embedding\": embeddings})\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## 4. 
Store document chunks and embeddings\n", "\n", "Because this example only uses a few thousand strings, we'll store them in a CSV file.\n", "\n", "(For larger datasets, use a vector database, which will be more performant.)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# save document chunks and embeddings\n", "\n", "SAVE_PATH = \"data/winter_olympics_2022.csv\"\n", "\n", "df.to_csv(SAVE_PATH, index=False)\n" ] } ], "metadata": { "kernelspec": { "display_name": "openai", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.6" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }