{ "cells": [ { "cell_type": "markdown", "id": "cb1537e6", "metadata": {}, "source": [ "# Vector Database Introduction\n", "\n", "This notebook takes you through a simple flow to download some data, embed it, and then index and search it using a selection of vector databases. This is a common requirement for customers who want to store and search our embeddings with their own data in a secure environment to support production use cases such as chatbots, topic modelling and more.\n", "\n", "The demo flow is:\n", "- **Setup**: Import packages and set any required variables\n", "- **Load data**: Load a dataset and embed it using OpenAI embeddings\n", "- **Pinecone**\n", "    - *Setup*: Here we set up the Python client for Pinecone. For more details go [here](https://docs.pinecone.io/docs/quickstart)\n", "    - *Index Data*: We'll create an index with namespaces for __titles__ and __content__\n", "    - *Search Data*: We'll test out both namespaces with search queries to confirm they work\n", "- **Weaviate**\n", "    - *Setup*: Here we set up the Python client for Weaviate. 
For more details go [here](https://weaviate.io/developers/weaviate/current/client-libraries/python.html)\n", "    - *Index Data*: We'll create an index with __title__ search vectors in it\n", "    - *Search Data*: We'll run a few searches to confirm it works\n", "\n", "Once you've run through this notebook you should have a basic understanding of how to set up and use vector databases, and can move on to more complex use cases making use of our embeddings." ] },
{ "cell_type": "markdown", "id": "e2b59250", "metadata": {}, "source": [ "## Setup\n", "\n", "Here we import the required libraries and set the embedding model that we'd like to use." ] },
{ "cell_type": "code", "execution_count": 98, "id": "5be94df6", "metadata": {}, "outputs": [], "source": [ "import openai\n", "\n", "import tiktoken\n", "from tenacity import retry, wait_random_exponential, stop_after_attempt\n", "from typing import List, Iterator\n", "# Import the submodule explicitly; a bare `import concurrent` does not guarantee concurrent.futures is loaded\n", "import concurrent.futures\n", "from tqdm import tqdm\n", "import pandas as pd\n", "from datasets import load_dataset\n", "import numpy as np\n", "import os\n", "\n", "# Pinecone's client library for Python\n", "import pinecone\n", "\n", "# Weaviate's client library for Python\n", "import weaviate\n", "\n", "# I've set this to our new embeddings model; this can be changed to the embedding model of your choice\n", "MODEL = \"text-embedding-ada-002\"\n", "\n", "# Ignore unclosed SSL socket warnings - optional in case you get these errors\n", "import warnings\n", "\n", "warnings.filterwarnings(action=\"ignore\", message=\"unclosed\", category=ResourceWarning)\n", "warnings.filterwarnings(\"ignore\", category=DeprecationWarning) " ] },
{ "cell_type": "markdown", "id": "e5d9d2e1", "metadata": {}, "source": [ "## Load data\n", "\n", "In this section we'll source the data for this task, embed it and format it for insertion into a vector database.\n", "\n", "*Thanks to Ryan Greene for the template used for the batch ingestion.*" ] },
{ "cell_type": "code", "execution_count": 116, "id": "bd99e08e", "metadata": {}, "outputs": [], "source": [ "# Simple function to take in a list of text objects and return them as a list of embeddings\n", "# Retried with exponential backoff so a transient API error only repeats the failed batch\n", "@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))\n", "def get_embeddings(input: List):\n", "    response = openai.Embedding.create(\n", "        input=input,\n", "        model=MODEL,\n", "    )[\"data\"]\n", "    return [data[\"embedding\"] for data in response]\n", "\n", "# Split a sequence into consecutive batches of at most n items\n", "def batchify(iterable, n=1):\n", "    l = len(iterable)\n", "    for ndx in range(0, l, n):\n", "        yield iterable[ndx : min(ndx + n, l)]\n", "\n", "# Function for batching and parallel processing the embeddings\n", "def embed_corpus(\n", "    corpus: List[str],\n", "    batch_size=64,\n", "    num_workers=8,\n", "    max_context_len=8191,\n", "):\n", "\n", "    # Encode the corpus, truncating to max_context_len\n", "    encoding = tiktoken.get_encoding(\"cl100k_base\")\n", "    encoded_corpus = [\n", "        encoded_article[:max_context_len] for encoded_article in encoding.encode_batch(corpus)\n", "    ]\n", "\n", "    # Calculate corpus statistics: the number of inputs, the total number of tokens, and the estimated cost to embed\n", "    num_tokens = sum(len(article) for article in encoded_corpus)\n", "    cost_to_embed_tokens = num_tokens / 1_000 * 0.0004\n", "    print(\n", "        f\"num_articles={len(encoded_corpus)}, num_tokens={num_tokens}, est_embedding_cost={cost_to_embed_tokens:.2f} USD\"\n", "    )\n", "\n", "    # Embed the corpus\n", "    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:\n", "\n", "        try:\n", "            futures = [\n", "                executor.submit(get_embeddings, text_batch)\n", "                for text_batch in batchify(encoded_corpus, batch_size)\n", "            ]\n", "\n", "            with tqdm(total=len(encoded_corpus)) as pbar:\n", "                for _ in concurrent.futures.as_completed(futures):\n", "                    pbar.update(batch_size)\n", "\n", "            # Collect results in submission order so embeddings line up with the corpus\n", "            embeddings = []\n", "            for future in futures:\n", "                data = future.result()\n", "                embeddings.extend(data)\n", "\n", "            return embeddings\n", "\n", "        except Exception:\n", "            # Re-raise instead of returning the exception, so callers never receive an error object in place of embeddings\n", "            print('Get embeddings failed, re-raising exception')\n", "            raise\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "0c1c73cb", "metadata": {}, "outputs": [], "source": [ "# We'll use the datasets library to pull the Simple Wikipedia dataset for embedding\n", "dataset = list(load_dataset(\"wikipedia\", \"20220301.simple\")[\"train\"])\n", "# Limited to 50k articles for demo purposes\n", "dataset = dataset[:50_000] " ] },
{ "cell_type": "code", "execution_count": 118, "id": "e6ee90ce", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "num_articles=50000, num_tokens=18272526, est_embedding_cost=7.31 USD\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "50048it [02:30, 332.26it/s] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "num_articles=50000, num_tokens=202363, est_embedding_cost=0.08 USD\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "50048it [00:53, 942.94it/s] " ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 48.7 s, sys: 1min 19s, total: 2min 7s\n", "Wall time: 5min 53s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "%%time\n", "# Embed the article text\n", "dataset_embeddings = embed_corpus([article[\"text\"] for article in dataset])\n", "# Embed the article titles separately\n", "title_embeddings = embed_corpus([article[\"title\"] for article in dataset])" ] },
{ "cell_type": "code", "execution_count": 119, "id": "1410daaa", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | id | \n", "url | \n", "title | \n", "text | \n", "title_vector | \n", "content_vector | \n", "vector_id | \n", "
---|---|---|---|---|---|---|---|
0 | \n", "1 | \n", "https://simple.wikipedia.org/wiki/April | \n", "April | \n", "April is the fourth month of the year in the J... | \n", "[0.00107035250402987, -0.02077057771384716, -0... | \n", "[-0.011253940872848034, -0.013491976074874401,... | \n", "0 | \n", "
1 | \n", "2 | \n", "https://simple.wikipedia.org/wiki/August | \n", "August | \n", "August (Aug.) is the eighth month of the year ... | \n", "[0.0010461278725415468, 0.0008924593566916883,... | \n", "[0.0003609954728744924, 0.007262262050062418, ... | \n", "1 | \n", "
2 | \n", "6 | \n", "https://simple.wikipedia.org/wiki/Art | \n", "Art | \n", "Art is a creative activity that expresses imag... | \n", "[0.0033627033699303865, 0.006122018210589886, ... | \n", "[-0.004959689453244209, 0.015772193670272827, ... | \n", "2 | \n", "
3 | \n", "8 | \n", "https://simple.wikipedia.org/wiki/A | \n", "A | \n", "A or a is the first letter of the English alph... | \n", "[0.015406121499836445, -0.013689860701560974, ... | \n", "[0.024894846603274345, -0.022186409682035446, ... | \n", "3 | \n", "
4 | \n", "9 | \n", "https://simple.wikipedia.org/wiki/Air | \n", "Air | \n", "Air refers to the Earth's atmosphere. Air is a... | \n", "[0.022219523787498474, -0.020443666726350784, ... | \n", "[0.021524671465158463, 0.018522677943110466, -... | \n", "4 | \n", "