mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
add contextual embedding with claude prompt caching (#11)
* add context augmentation while chunking
* add contextual embeddings
* default config should be combined
* fix comments on PR
* update example environment
* update config and api to support env-variable optionality
This commit is contained in:
parent
367dc079e8
commit
abccf99974
15
.env.example
15
.env.example
@ -1,8 +1,9 @@
|
||||
MONGODB_URI="..."
|
||||
OPENAI_API_KEY="..."
|
||||
# Optional: Only needed if using AWS S3 storage
|
||||
AWS_ACCESS_KEY="..."
|
||||
AWS_SECRET_ACCESS_KEY="..."
|
||||
UNSTRUCTURED_API_KEY="..."
|
||||
ASSEMBLYAI_API_KEY="..."
|
||||
JWT_SECRET_KEY="..."
|
||||
MONGODB_URI="..."
|
||||
UNSTRUCTURED_API_KEY="..."
|
||||
|
||||
OPENAI_API_KEY="..." # Optional: Needed for OpenAI embeddings and completions
|
||||
ASSEMBLYAI_API_KEY="..." # Optional: Needed for combined parser
|
||||
ANTHROPIC_API_KEY="..." # Optional: Needed for contextual parser
|
||||
AWS_ACCESS_KEY="..." # Optional: Needed for AWS S3 storage
|
||||
AWS_SECRET_ACCESS_KEY="..." # Optional: Needed for AWS S3 storage
|
||||
|
@ -10,7 +10,7 @@ database = "mongodb"
|
||||
vector_store = "mongodb"
|
||||
embedding = "openai" # "ollama"
|
||||
completion = "openai" # "ollama"
|
||||
parser = "combined" # "unstructured"
|
||||
parser = "combined" # "unstructured", "contextual"
|
||||
|
||||
# Storage Configuration
|
||||
[storage.local]
|
||||
|
32
core/api.py
32
core/api.py
@ -27,6 +27,7 @@ from core.storage.s3_storage import S3Storage
|
||||
from core.storage.local_storage import LocalStorage
|
||||
from core.embedding.openai_embedding_model import OpenAIEmbeddingModel
|
||||
from core.completion.ollama_completion import OllamaCompletionModel
|
||||
from core.parser.contextual_parser import ContextualParser
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(title="DataBridge API")
|
||||
@ -78,6 +79,8 @@ match settings.STORAGE_PROVIDER:
|
||||
case "local":
|
||||
storage = LocalStorage(storage_path=settings.STORAGE_PATH)
|
||||
case "aws-s3":
|
||||
if not settings.AWS_ACCESS_KEY or not settings.AWS_SECRET_ACCESS_KEY:
|
||||
raise ValueError("AWS credentials are required for S3 storage")
|
||||
storage = S3Storage(
|
||||
aws_access_key=settings.AWS_ACCESS_KEY,
|
||||
aws_secret_key=settings.AWS_SECRET_ACCESS_KEY,
|
||||
@ -90,6 +93,8 @@ match settings.STORAGE_PROVIDER:
|
||||
# Initialize parser
|
||||
match settings.PARSER_PROVIDER:
|
||||
case "combined":
|
||||
if not settings.ASSEMBLYAI_API_KEY:
|
||||
raise ValueError("AssemblyAI API key is required for combined parser")
|
||||
parser = CombinedParser(
|
||||
unstructured_api_key=settings.UNSTRUCTURED_API_KEY,
|
||||
assemblyai_api_key=settings.ASSEMBLYAI_API_KEY,
|
||||
@ -103,21 +108,34 @@ match settings.PARSER_PROVIDER:
|
||||
chunk_size=settings.CHUNK_SIZE,
|
||||
chunk_overlap=settings.CHUNK_OVERLAP,
|
||||
)
|
||||
case "contextual":
|
||||
if not settings.ANTHROPIC_API_KEY:
|
||||
raise ValueError("Anthropic API key is required for contextual parser")
|
||||
parser = ContextualParser(
|
||||
unstructured_api_key=settings.UNSTRUCTURED_API_KEY,
|
||||
assemblyai_api_key=settings.ASSEMBLYAI_API_KEY,
|
||||
chunk_size=settings.CHUNK_SIZE,
|
||||
chunk_overlap=settings.CHUNK_OVERLAP,
|
||||
frame_sample_rate=settings.FRAME_SAMPLE_RATE,
|
||||
anthropic_api_key=settings.ANTHROPIC_API_KEY,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported parser provider: {settings.PARSER_PROVIDER}")
|
||||
|
||||
# Initialize embedding model
|
||||
match settings.EMBEDDING_PROVIDER:
|
||||
case "ollama":
|
||||
embedding_model = OllamaEmbeddingModel(
|
||||
base_url=settings.OLLAMA_BASE_URL,
|
||||
model_name=settings.EMBEDDING_MODEL,
|
||||
)
|
||||
case "openai":
|
||||
if not settings.OPENAI_API_KEY:
|
||||
raise ValueError("OpenAI API key is required for OpenAI embedding model")
|
||||
embedding_model = OpenAIEmbeddingModel(
|
||||
api_key=settings.OPENAI_API_KEY,
|
||||
model_name=settings.EMBEDDING_MODEL,
|
||||
)
|
||||
case "ollama":
|
||||
embedding_model = OllamaEmbeddingModel(
|
||||
model_name=settings.EMBEDDING_MODEL,
|
||||
base_url=settings.OLLAMA_BASE_URL,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported embedding provider: {settings.EMBEDDING_PROVIDER}")
|
||||
|
||||
@ -129,6 +147,8 @@ match settings.COMPLETION_PROVIDER:
|
||||
base_url=settings.OLLAMA_BASE_URL,
|
||||
)
|
||||
case "openai":
|
||||
if not settings.OPENAI_API_KEY:
|
||||
raise ValueError("OpenAI API key is required for OpenAI completion model")
|
||||
completion_model = OpenAICompletionModel(
|
||||
model_name=settings.COMPLETION_MODEL,
|
||||
)
|
||||
@ -184,7 +204,7 @@ async def ingest_text(
|
||||
operation_type="ingest_text",
|
||||
user_id=auth.entity_id,
|
||||
tokens_used=len(request.content.split()), # Approximate token count
|
||||
metadata=request.metadata.model_dump() if request.metadata else None,
|
||||
metadata=request.metadata if request.metadata else None,
|
||||
):
|
||||
return await document_service.ingest_text(request, auth)
|
||||
except PermissionError as e:
|
||||
|
@ -1,3 +1,4 @@
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
@ -9,13 +10,15 @@ class Settings(BaseSettings):
|
||||
"""DataBridge configuration settings."""
|
||||
|
||||
# Required environment variables (referenced in config.toml)
|
||||
MONGODB_URI: str = Field(..., env="MONGODB_URI")
|
||||
OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY")
|
||||
UNSTRUCTURED_API_KEY: str = Field(..., env="UNSTRUCTURED_API_KEY")
|
||||
ASSEMBLYAI_API_KEY: str = Field(..., env="ASSEMBLYAI_API_KEY")
|
||||
AWS_ACCESS_KEY: str = Field(None, env="AWS_ACCESS_KEY")
|
||||
AWS_SECRET_ACCESS_KEY: str = Field(None, env="AWS_SECRET_ACCESS_KEY")
|
||||
JWT_SECRET_KEY: str = Field(..., env="JWT_SECRET_KEY")
|
||||
MONGODB_URI: str = Field(..., env="MONGODB_URI")
|
||||
UNSTRUCTURED_API_KEY: str = Field(..., env="UNSTRUCTURED_API_KEY")
|
||||
|
||||
AWS_ACCESS_KEY: Optional[str] = Field(None, env="AWS_ACCESS_KEY")
|
||||
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(None, env="AWS_SECRET_ACCESS_KEY")
|
||||
ASSEMBLYAI_API_KEY: Optional[str] = Field(None, env="ASSEMBLYAI_API_KEY")
|
||||
OPENAI_API_KEY: Optional[str] = Field(None, env="OPENAI_API_KEY")
|
||||
ANTHROPIC_API_KEY: Optional[str] = Field(None, env="ANTHROPIC_API_KEY")
|
||||
|
||||
# Service settings
|
||||
HOST: str = "localhost"
|
||||
|
@ -79,7 +79,7 @@ class CombinedParser(BaseParser):
|
||||
async def parse_file(
|
||||
self, file: bytes, content_type: str
|
||||
) -> Tuple[Dict[str, Any], List[Chunk]]:
|
||||
"""Parse file content into text chunks"""
|
||||
"""Parse file content into text chunks. Returns document metadata and a list of chunks"""
|
||||
is_video = self._is_video_file(file_bytes=file)
|
||||
|
||||
if is_video:
|
||||
|
110
core/parser/contextual_parser.py
Normal file
110
core/parser/contextual_parser.py
Normal file
@ -0,0 +1,110 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import anthropic
|
||||
|
||||
from core.models.chunk import Chunk
|
||||
from core.parser.base_parser import BaseParser
|
||||
from core.parser.combined_parser import CombinedParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Template that wraps the full document text in <document> tags.
# ContextualParser.situate_context fills {doc_content} via str.format and sends
# the result as a system block tagged with cache_control "ephemeral".
DOCUMENT_CONTEXT_PROMPT = """
|
||||
<document>
|
||||
{doc_content}
|
||||
</document>
|
||||
"""
|
||||
|
||||
# Template for the user message that asks the model for a short context
# situating one chunk within the document.
# ContextualParser.situate_context fills {chunk_content} via str.format.
CHUNK_CONTEXT_PROMPT = """
|
||||
Here is the chunk we want to situate within the whole document
|
||||
<chunk>
|
||||
{chunk_content}
|
||||
</chunk>
|
||||
|
||||
Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
|
||||
Answer only with the succinct context and nothing else.
|
||||
"""
|
||||
|
||||
|
||||
class ContextualParser(BaseParser):
    """Parser that prefixes every chunk with LLM-generated document context.

    Parsing and chunking are delegated to a ``CombinedParser``. Each chunk it
    produces is then sent to the Anthropic Messages API together with the full
    document text, and the returned context is prepended to the chunk content.
    The document text travels in a system block tagged with
    ``cache_control: ephemeral`` so that consecutive requests for the same
    document can share that prompt prefix.
    """

    # Model used to generate the per-chunk context.
    CONTEXT_MODEL = "claude-3-haiku-20240307"
    # Upper bound on the length of the generated context.
    CONTEXT_MAX_TOKENS = 1024
    # Pause between consecutive requests to avoid rate limiting.
    RATE_LIMIT_DELAY_SECONDS = 1.25

    def __init__(
        self,
        unstructured_api_key: str,
        assemblyai_api_key: str,
        chunk_size: int,
        chunk_overlap: int,
        frame_sample_rate: int,
        anthropic_api_key: str,
    ):
        """Build the wrapped ``CombinedParser`` and the Anthropic client.

        Args:
            unstructured_api_key: Forwarded to ``CombinedParser``.
            assemblyai_api_key: Forwarded to ``CombinedParser``.
            chunk_size: Forwarded to ``CombinedParser``.
            chunk_overlap: Forwarded to ``CombinedParser``.
            frame_sample_rate: Forwarded to ``CombinedParser``.
            anthropic_api_key: API key for the Anthropic client that
                generates the chunk context.
        """
        self.combined_parser = CombinedParser(
            unstructured_api_key=unstructured_api_key,
            assemblyai_api_key=assemblyai_api_key,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            frame_sample_rate=frame_sample_rate,
        )
        self.llm = anthropic.Anthropic(api_key=anthropic_api_key)

    def situate_context(self, doc: str, chunk: str) -> str:
        """Ask the model for a short context locating ``chunk`` within ``doc``.

        Args:
            doc: Full document text; sent as the cacheable system block.
            chunk: Text of the chunk to situate.

        Returns:
            The text of the first content block of the model response.

        Raises:
            ValueError: If the response has no content blocks or the first
                block is not a text block.
        """
        response = self.llm.messages.create(
            model=self.CONTEXT_MODEL,
            max_tokens=self.CONTEXT_MAX_TOKENS,
            temperature=0.0,
            system=[
                {
                    "type": "text",
                    "text": "You are an AI assistant that situates a chunk within a document for the purposes of improving search retrieval of the chunk.",
                },
                {
                    "type": "text",
                    "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
                    # The document is identical for every chunk, so mark it cacheable.
                    "cache_control": {"type": "ephemeral"},
                },
            ],
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
                        }
                    ],
                }
            ],
        )

        # Guard against an empty content list as well as a non-text first
        # block; indexing blindly would surface as an opaque IndexError.
        blocks = response.content
        if blocks and blocks[0].type == "text":
            return blocks[0].text

        message = f"Anthropic client returned non-text response when situating context for chunk: {chunk} \n Response: {response}"
        logger.error(message)
        raise ValueError(message)

    def situate_all_chunks(self, text: str, chunks: List[Chunk]) -> List[Chunk]:
        """Return new chunks whose content is ``"<context>; <original content>"``.

        This method is synchronous and blocking: it performs one API request
        per chunk and sleeps between requests. Call it from a worker thread
        when running inside an event loop.

        Args:
            text: Full document text used as context for every chunk.
            chunks: Chunks to situate; they are not modified.

        Returns:
            One new ``Chunk`` per input chunk, in the same order, carrying the
            original metadata.
        """
        new_chunks = []
        last_index = len(chunks) - 1
        for index, chunk in enumerate(chunks):
            context = self.situate_context(text, chunk.content)
            new_chunk = Chunk(content=f"{context}; {chunk.content}", metadata=chunk.metadata)
            new_chunks.append(new_chunk)
            logger.info("Situating the %dth chunk:\n %s", index, new_chunk.content[:100])
            # The pause only spaces out consecutive requests, so it is
            # pointless after the final chunk.
            if index < last_index:
                logger.info("Sleeping to avoid rate limiting...")
                time.sleep(self.RATE_LIMIT_DELAY_SECONDS)
        return new_chunks

    async def _situate_in_worker_thread(self, text: str, chunks: List[Chunk]) -> List[Chunk]:
        """Run the blocking ``situate_all_chunks`` without stalling the event loop."""
        # Imported here so the class does not rely on a module-level import
        # that the file does not already have.
        import asyncio

        return await asyncio.to_thread(self.situate_all_chunks, text, chunks)

    async def parse_file(
        self, file: bytes, content_type: str
    ) -> Tuple[Dict[str, Any], List[Chunk]]:
        """Parse a file into chunks and prefix each chunk with document context.

        Args:
            file: Raw file bytes, forwarded to ``CombinedParser.parse_file``.
            content_type: Content type, forwarded to ``CombinedParser.parse_file``.

        Returns:
            The document metadata from the combined parser and the
            context-prefixed chunks.
        """
        document_metadata, chunks = await self.combined_parser.parse_file(file, content_type)
        # The document text is reconstructed from the chunks themselves.
        document_text = "\n".join(chunk.content for chunk in chunks)
        new_chunks = await self._situate_in_worker_thread(document_text, chunks)
        return document_metadata, new_chunks

    async def split_text(self, text: str) -> List[Chunk]:
        """Split ``text`` into chunks and prefix each chunk with document context.

        Args:
            text: Full text to split; also used as the context document.

        Returns:
            The context-prefixed chunks.
        """
        chunks = await self.combined_parser.split_text(text)
        return await self._situate_in_worker_thread(text, chunks)
|
@ -91,7 +91,7 @@ class DocumentService:
|
||||
query: str,
|
||||
auth: AuthContext,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
k: int = 4,
|
||||
k: int = 20, # from contextual embedding paper
|
||||
min_score: float = 0.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional, Union, BinaryIO
|
||||
from typing import Tuple, Optional
|
||||
|
||||
|
||||
class BaseStorage(ABC):
|
||||
|
Loading…
x
Reference in New Issue
Block a user