add contextual embedding with claude prompt caching (#11)

* add context augmentation while chunking

* add contextual embeddings

* default config should be combined

* fix comments on PR

* update example environment

* update config and api to support env-variable optionality
This commit is contained in:
Arnav Agrawal 2024-12-31 06:58:34 -05:00 committed by GitHub
parent 367dc079e8
commit abccf99974
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 157 additions and 23 deletions

View File

@ -1,8 +1,9 @@
MONGODB_URI="..."
OPENAI_API_KEY="..."
# Optional: Only needed if using AWS S3 storage
AWS_ACCESS_KEY="..."
AWS_SECRET_ACCESS_KEY="..."
UNSTRUCTURED_API_KEY="..."
ASSEMBLYAI_API_KEY="..."
JWT_SECRET_KEY="..."
MONGODB_URI="..."
UNSTRUCTURED_API_KEY="..."
OPENAI_API_KEY="..." # Optional: Needed for OpenAI embeddings and completions
ASSEMBLYAI_API_KEY="..." # Optional: Needed for combined parser
ANTHROPIC_API_KEY="..." # Optional: Needed for contextual parser
AWS_ACCESS_KEY="..." # Optional: Needed for AWS S3 storage
AWS_SECRET_ACCESS_KEY="..." # Optional: Needed for AWS S3 storage

View File

@ -10,7 +10,7 @@ database = "mongodb"
vector_store = "mongodb"
embedding = "openai" # "ollama"
completion = "openai" # "ollama"
parser = "combined" # "unstructured"
parser = "combined" # "unstructured", "contextual"
# Storage Configuration
[storage.local]

View File

@ -27,6 +27,7 @@ from core.storage.s3_storage import S3Storage
from core.storage.local_storage import LocalStorage
from core.embedding.openai_embedding_model import OpenAIEmbeddingModel
from core.completion.ollama_completion import OllamaCompletionModel
from core.parser.contextual_parser import ContextualParser
# Initialize FastAPI app
app = FastAPI(title="DataBridge API")
@ -78,6 +79,8 @@ match settings.STORAGE_PROVIDER:
case "local":
storage = LocalStorage(storage_path=settings.STORAGE_PATH)
case "aws-s3":
if not settings.AWS_ACCESS_KEY or not settings.AWS_SECRET_ACCESS_KEY:
raise ValueError("AWS credentials are required for S3 storage")
storage = S3Storage(
aws_access_key=settings.AWS_ACCESS_KEY,
aws_secret_key=settings.AWS_SECRET_ACCESS_KEY,
@ -90,6 +93,8 @@ match settings.STORAGE_PROVIDER:
# Initialize parser
match settings.PARSER_PROVIDER:
case "combined":
if not settings.ASSEMBLYAI_API_KEY:
raise ValueError("AssemblyAI API key is required for combined parser")
parser = CombinedParser(
unstructured_api_key=settings.UNSTRUCTURED_API_KEY,
assemblyai_api_key=settings.ASSEMBLYAI_API_KEY,
@ -103,21 +108,34 @@ match settings.PARSER_PROVIDER:
chunk_size=settings.CHUNK_SIZE,
chunk_overlap=settings.CHUNK_OVERLAP,
)
case "contextual":
if not settings.ANTHROPIC_API_KEY:
raise ValueError("Anthropic API key is required for contextual parser")
parser = ContextualParser(
unstructured_api_key=settings.UNSTRUCTURED_API_KEY,
assemblyai_api_key=settings.ASSEMBLYAI_API_KEY,
chunk_size=settings.CHUNK_SIZE,
chunk_overlap=settings.CHUNK_OVERLAP,
frame_sample_rate=settings.FRAME_SAMPLE_RATE,
anthropic_api_key=settings.ANTHROPIC_API_KEY,
)
case _:
raise ValueError(f"Unsupported parser provider: {settings.PARSER_PROVIDER}")
# Initialize embedding model
match settings.EMBEDDING_PROVIDER:
case "ollama":
embedding_model = OllamaEmbeddingModel(
base_url=settings.OLLAMA_BASE_URL,
model_name=settings.EMBEDDING_MODEL,
)
case "openai":
if not settings.OPENAI_API_KEY:
raise ValueError("OpenAI API key is required for OpenAI embedding model")
embedding_model = OpenAIEmbeddingModel(
api_key=settings.OPENAI_API_KEY,
model_name=settings.EMBEDDING_MODEL,
)
case "ollama":
embedding_model = OllamaEmbeddingModel(
model_name=settings.EMBEDDING_MODEL,
base_url=settings.OLLAMA_BASE_URL,
)
case _:
raise ValueError(f"Unsupported embedding provider: {settings.EMBEDDING_PROVIDER}")
@ -129,6 +147,8 @@ match settings.COMPLETION_PROVIDER:
base_url=settings.OLLAMA_BASE_URL,
)
case "openai":
if not settings.OPENAI_API_KEY:
raise ValueError("OpenAI API key is required for OpenAI completion model")
completion_model = OpenAICompletionModel(
model_name=settings.COMPLETION_MODEL,
)
@ -184,7 +204,7 @@ async def ingest_text(
operation_type="ingest_text",
user_id=auth.entity_id,
tokens_used=len(request.content.split()), # Approximate token count
metadata=request.metadata.model_dump() if request.metadata else None,
metadata=request.metadata if request.metadata else None,
):
return await document_service.ingest_text(request, auth)
except PermissionError as e:

View File

@ -1,3 +1,4 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
from functools import lru_cache
@ -9,13 +10,15 @@ class Settings(BaseSettings):
"""DataBridge configuration settings."""
# Required environment variables (referenced in config.toml)
MONGODB_URI: str = Field(..., env="MONGODB_URI")
OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY")
UNSTRUCTURED_API_KEY: str = Field(..., env="UNSTRUCTURED_API_KEY")
ASSEMBLYAI_API_KEY: str = Field(..., env="ASSEMBLYAI_API_KEY")
AWS_ACCESS_KEY: str = Field(None, env="AWS_ACCESS_KEY")
AWS_SECRET_ACCESS_KEY: str = Field(None, env="AWS_SECRET_ACCESS_KEY")
JWT_SECRET_KEY: str = Field(..., env="JWT_SECRET_KEY")
MONGODB_URI: str = Field(..., env="MONGODB_URI")
UNSTRUCTURED_API_KEY: str = Field(..., env="UNSTRUCTURED_API_KEY")
AWS_ACCESS_KEY: Optional[str] = Field(None, env="AWS_ACCESS_KEY")
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(None, env="AWS_SECRET_ACCESS_KEY")
ASSEMBLYAI_API_KEY: Optional[str] = Field(None, env="ASSEMBLYAI_API_KEY")
OPENAI_API_KEY: Optional[str] = Field(None, env="OPENAI_API_KEY")
ANTHROPIC_API_KEY: Optional[str] = Field(None, env="ANTHROPIC_API_KEY")
# Service settings
HOST: str = "localhost"

View File

@ -79,7 +79,7 @@ class CombinedParser(BaseParser):
async def parse_file(
self, file: bytes, content_type: str
) -> Tuple[Dict[str, Any], List[Chunk]]:
"""Parse file content into text chunks"""
"""Parse file content into text chunks. Returns document metadata and a list of chunks"""
is_video = self._is_video_file(file_bytes=file)
if is_video:

View File

@ -0,0 +1,110 @@
import logging
import time
from typing import Any, Dict, List, Tuple
import anthropic
from core.models.chunk import Chunk
from core.parser.base_parser import BaseParser
from core.parser.combined_parser import CombinedParser
logger = logging.getLogger(__name__)
DOCUMENT_CONTEXT_PROMPT = """
<document>
{doc_content}
</document>
"""
CHUNK_CONTEXT_PROMPT = """
Here is the chunk we want to situate within the whole document
<chunk>
{chunk_content}
</chunk>
Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
Answer only with the succinct context and nothing else.
"""
class ContextualParser(BaseParser):
def __init__(
self,
unstructured_api_key: str,
assemblyai_api_key: str,
chunk_size: int,
chunk_overlap: int,
frame_sample_rate: int,
anthropic_api_key: str,
):
self.combined_parser = CombinedParser(
unstructured_api_key=unstructured_api_key,
assemblyai_api_key=assemblyai_api_key,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
frame_sample_rate=frame_sample_rate,
)
self.llm = anthropic.Anthropic(api_key=anthropic_api_key)
def situate_context(self, doc: str, chunk: str) -> str:
response = self.llm.messages.create(
model="claude-3-haiku-20240307",
max_tokens=1024,
temperature=0.0,
system=[
{
"type": "text",
"text": "You are an AI assistant that situates a chunk within a document for the purposes of improving search retrieval of the chunk.",
},
{
"type": "text",
"text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
"cache_control": {"type": "ephemeral"},
},
],
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
}
],
}
],
)
context = response.content[0]
if context.type == "text":
return context.text
else:
message = f"Anthropic client returned non-text response when situating context for chunk: {chunk} \n Response: {response}"
logger.error(message)
raise ValueError(message)
def situate_all_chunks(self, text: str, chunks: List[Chunk]) -> List[Chunk]:
new_chunks = []
chunks_situated = 0
for chunk in chunks:
context = self.situate_context(text, chunk.content)
content = f"{context}; {chunk.content}"
new_chunk = Chunk(content=content, metadata=chunk.metadata)
new_chunks.append(new_chunk)
logger.info(f"Situating the {chunks_situated}th chunk:\n {new_chunk.content[:100]}")
logger.info("Sleeping to avoid rate limiting...")
time.sleep(1.25)
chunks_situated += 1
return new_chunks
async def parse_file(
self, file: bytes, content_type: str
) -> Tuple[Dict[str, Any], List[Chunk]]:
document_metadata, chunks = await self.combined_parser.parse_file(file, content_type)
document_text = "\n".join([chunk.content for chunk in chunks])
new_chunks = self.situate_all_chunks(document_text, chunks)
return document_metadata, new_chunks
async def split_text(self, text: str) -> List[Chunk]:
chunks = await self.combined_parser.split_text(text)
new_chunks = self.situate_all_chunks(text, chunks)
return new_chunks

View File

@ -91,7 +91,7 @@ class DocumentService:
query: str,
auth: AuthContext,
filters: Optional[Dict[str, Any]] = None,
k: int = 4,
k: int = 20, # from contextual embedding paper
min_score: float = 0.0,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Tuple, Optional, Union, BinaryIO
from typing import Tuple, Optional
class BaseStorage(ABC):