basic usage for string content works

This commit is contained in:
Adityavardhan Agrawal 2024-11-16 14:37:01 -05:00
parent 1a926c7be0
commit 3251236fbe
13 changed files with 39 additions and 332 deletions

View File

@ -70,8 +70,8 @@ class ServiceConfig:
try:
self.vector_store = MongoDBAtlasVectorStore(
connection_string=os.getenv("MONGODB_URI"),
database_name=os.getenv("DB_NAME", "databridge"),
collection_name=os.getenv("COLLECTION_NAME", "embeddings")
database_name=os.getenv("DB_NAME", "DataBridgeTest"),
collection_name=os.getenv("COLLECTION_NAME", "test")
)
self.embedding_model = OpenAIEmbeddingModel(
@ -181,10 +181,10 @@ async def ingest_document(
# Parse into chunks
chunk_texts = service.parser.parse(request.content, request.metadata)
embeddings = await service.embedding_model.embed_for_ingestion(chunk_texts)
# Create embeddings and chunks
chunks = []
for chunk_text in chunk_texts:
embedding = await service.embedding_model.embed(chunk_text)
for embedding, chunk_text in zip(embeddings, chunk_texts):
chunk = DocumentChunk(chunk_text, embedding, doc.id)
chunk.metadata = {
'owner_id': owner_id,
@ -212,12 +212,11 @@ async def query_documents(
All configuration and credentials are handled server-side.
"""
logger.info(f"Processing query for owner {owner_id}")
print("ADILOG ")
# Create plan
plan = service.planner.plan_retrieval(request.query, k=request.k)
# Get query embedding
query_embedding = await service.embedding_model.embed(request.query)
query_embedding = await service.embedding_model.embed_for_query(request.query)
# Query vector store
chunks = service.vector_store.query_similar(

View File

@ -1,35 +0,0 @@
from fastapi import Request, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import jwt
security = HTTPBearer()
class DataBridgeAuth:
def __init__(self, secret_key: str):
self.secret_key = secret_key
async def __call__(self, request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
try:
token = credentials.credentials
payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
# Validate owner_id from token matches header
owner_id = request.headers.get("X-Owner-ID")
if owner_id != payload.get("owner_id"):
raise HTTPException(
status_code=401,
detail="Owner ID mismatch"
)
return owner_id
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=401,
detail="Token has expired"
)
except jwt.InvalidTokenError:
raise HTTPException(
status_code=401,
detail="Invalid token"
)

View File

@ -1,106 +0,0 @@
from typing import Dict, Any, List
from .databridge_uri import DataBridgeURI
from .document import Document, DocumentChunk
from .vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from .embedding_model.openai_embedding_model import OpenAIEmbeddingModel
from .parser.unstructured_parser import UnstructuredAPIParser
from .planner.simple_planner import SimpleRAGPlanner
class DataBridge:
"""
DataBridge with owner authentication and authorization.
Configured via URI containing owner credentials.
"""
def __init__(self, uri: str):
# Parse URI and initialize configuration
self.config = DataBridgeURI(uri)
# Initialize components
self._init_components()
def _init_components(self):
"""Initialize all required components using the URI configuration"""
self.embedding_model = OpenAIEmbeddingModel(
api_key=self.config.openai_api_key,
model_name=self.config.embedding_model
)
self.parser = UnstructuredAPIParser(
api_key=self.config.unstructured_api_key,
chunk_size=1000,
chunk_overlap=200
)
self.vector_store = MongoDBAtlasVectorStore(
connection_string=self.config.mongo_uri,
database_name=self.config.db_name,
collection_name=self.config.collection_name
)
self.planner = SimpleRAGPlanner(default_k=4)
async def ingest_document(
self,
content: str,
metadata: Dict[str, Any]
) -> Document:
"""
Ingest a document using the owner ID from the URI configuration.
"""
# Add owner_id to metadata
metadata['owner_id'] = self.config.owner_id
# Create document
doc = Document(content, metadata, self.config.owner_id)
# Parse into chunks
chunk_texts = self.parser.parse(content, metadata)
# Create embeddings and chunks
for chunk_text in chunk_texts:
embedding = await self.embedding_model.embed(chunk_text)
chunk = DocumentChunk(chunk_text, embedding, doc.id)
chunk.metadata = {'owner_id': self.config.owner_id}
doc.chunks.append(chunk)
# Store in vector store
success = self.vector_store.store_embeddings(doc.chunks)
if not success:
raise Exception("Failed to store embeddings")
return doc
async def query(
self,
query: str,
**kwargs
) -> List[Dict[str, Any]]:
"""
Query the document store using the owner ID from the URI configuration.
"""
# Create plan
plan = self.planner.plan_retrieval(query, **kwargs)
# Get query embedding
query_embedding = await self.embedding_model.embed(query)
# Execute plan
chunks = self.vector_store.query_similar(
query_embedding,
k=plan["k"],
owner_id=self.config.owner_id
)
# Format results
results = []
for chunk in chunks:
results.append({
"content": chunk.content,
"doc_id": chunk.doc_id,
"chunk_id": chunk.id,
"score": chunk.score if hasattr(chunk, "score") else None
})
return results

View File

@ -1,61 +0,0 @@
from urllib.parse import urlparse, parse_qs
from typing import Optional, Dict, Any
import os
import jwt
from datetime import datetime, timedelta
class DataBridgeURI:
"""
Handles parsing and validation of DataBridge URIs with owner authentication
Format: databridge://<owner_id>:<auth_token>@host/path?params
"""
def __init__(self, uri: str):
self.uri = uri
self._parse_uri()
def _parse_uri(self):
parsed = urlparse(self.uri)
query_params = parse_qs(parsed.query)
# Parse authentication info from netloc
auth_parts = parsed.netloc.split('@')[0].split(':')
if len(auth_parts) != 2:
raise ValueError("URI must include owner_id and auth_token")
self.owner_id = auth_parts[0]
self.auth_token = auth_parts[1]
# Validate and decode auth token
try:
self._validate_auth_token()
except Exception as e:
raise ValueError(f"Invalid auth token: {str(e)}")
# Get the original MongoDB URI from environment - use it as is
self.mongo_uri = os.getenv("MONGODB_URI")
if not self.mongo_uri:
raise ValueError("MONGODB_URI environment variable not set")
# Get configuration from query parameters
self.openai_api_key = query_params.get('openai_key', [os.getenv('OPENAI_API_KEY', '')])[0]
self.unstructured_api_key = query_params.get('unstructured_key', [os.getenv('UNSTRUCTURED_API_KEY', '')])[0]
self.db_name = query_params.get('db', ['brandsyncaidb'])[0]
self.collection_name = query_params.get('collection', ['kb_chunked_embeddings'])[0]
self.embedding_model = query_params.get('embedding_model', ['text-embedding-3-small'])[0]
# Validate required fields
if not all([self.mongo_uri, self.openai_api_key, self.unstructured_api_key]):
raise ValueError("Missing required configuration in DataBridge URI")
def _validate_auth_token(self):
"""Validate the auth token and extract any additional claims"""
try:
decoded = jwt.decode(self.auth_token, 'your-secret-key', algorithms=['HS256'])
if decoded.get('owner_id') != self.owner_id:
raise ValueError("Token owner_id mismatch")
self.auth_claims = decoded
except jwt.ExpiredSignatureError:
raise ValueError("Auth token has expired")
except jwt.InvalidTokenError:
raise ValueError("Invalid auth token")

View File

@ -4,6 +4,11 @@ from typing import List, Union
class BaseEmbeddingModel(ABC):
@abstractmethod
async def embed(self, text: Union[str, List[str]]) -> List[float]:
async def embed_for_ingestion(self, text: Union[str, List[str]]) -> List[float]:
"""Generate embeddings for input text"""
pass
@abstractmethod
async def embed_for_query(self, text: str) -> List[float]:
"""Generate embeddings for input text"""
pass

View File

@ -1,23 +1,25 @@
from typing import List, Union
import openai
from openai import OpenAI
from .base_embedding_model import BaseEmbeddingModel
class OpenAIEmbeddingModel(BaseEmbeddingModel):
def __init__(self, api_key: str, model_name: str = "text-embedding-3-small"):
self.client = openai.Client(api_key=api_key)
self.client = OpenAI(api_key=api_key)
self.model_name = model_name
async def embed(self, text: Union[str, List[str]]) -> List[float]:
if isinstance(text, str):
text = [text]
async def embed_for_ingestion(self, text: Union[str, List[str]]) -> List[List[float]]:
response = self.client.embeddings.create(
model=self.model_name,
input=text
)
if len(text) == 1:
return response.data[0].embedding
return [item.embedding for item in response.data]
async def embed_for_query(self, text: str) -> List[float]:
response = self.client.embeddings.create(
model=self.model_name,
input=text
)
return response.data[0].embedding

View File

@ -1,27 +0,0 @@
from fastapi import FastAPI, Depends
from .api import app as api_app
from .auth import DataBridgeAuth
import os
app = FastAPI()
auth = DataBridgeAuth(secret_key=os.getenv("JWT_SECRET_KEY", "your-secret-key"))
# Mount the API with authentication
app.mount("/api/v1", api_app)
# Add authentication middleware to all routes
@app.middleware("http")
async def authenticate_requests(request: Request, call_next):
if request.url.path.startswith("/api/v1"):
try:
await auth(request)
except HTTPException as e:
return JSONResponse(
status_code=e.status_code,
content={"detail": e.detail}
)
return await call_next(request)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@ -34,7 +34,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
name=self.index_name,
vectorSearchOptions={
"dimensions": 1536, # For OpenAI embeddings
"similarity": "cosine"
"similarity": "dotProduct"
}
)
except Exception as e:
@ -52,14 +52,12 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
"owner_id": chunk.metadata.get("owner_id"),
"metadata": chunk.metadata
}
print("BHAU")
print(doc)
documents.append(doc)
if documents:
# Use ordered=False to continue even if some inserts fail
result = self.collection.insert_many(documents, ordered=False)
print(result)
return len(result.inserted_ids) > 0
return True
@ -77,7 +75,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
"""Find similar chunks using MongoDB Atlas Vector Search."""
base_filter = {"owner_id": owner_id}
if filters:
base_filter.update(filters)
filters.update(base_filter)
try:
pipeline = [
@ -88,15 +86,21 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
"queryVector": query_embedding,
"numCandidates": k * 10,
"limit": k,
"filter": base_filter
"filter": filters if filters else base_filter
}
},
{
"$project": {
"score": {"$meta": "vectorSearchScore"},
"text": 1,
"embedding": 1,
"doc_id": 1,
"metadata": 1
}
}
]
# print("ADILOG: " + str(pipeline))
results = list(self.collection.aggregate(pipeline))
print("ADILOG")
print(results)
chunks = []
for result in results:

View File

View File

@ -1,76 +0,0 @@
import sys; sys.path.append('.')
from datetime import datetime, timedelta, UTC
import base64
from core.databridge import DataBridge
import jwt
import os
from dotenv import load_dotenv
def create_databridge_uri() -> str:
"""Create DataBridge URI from environment variables"""
load_dotenv()
# Get credentials from environment
mongo_uri = os.getenv("MONGODB_URI")
openai_key = os.getenv("OPENAI_API_KEY")
unstructured_key = os.getenv("UNSTRUCTURED_API_KEY")
owner_id = os.getenv("DATABRIDGE_OWNER", "admin")
# Validate required credentials
if not all([mongo_uri, openai_key, unstructured_key]):
raise ValueError("Missing required environment variables")
# Generate auth token
auth_token = jwt.encode(
{
'owner_id': owner_id,
'exp': datetime.now(UTC) + timedelta(days=30)
},
'your-secret-key', # In production, use proper secret
algorithm='HS256'
)
# For DataBridge URI, use any host identifier (it won't affect MongoDB connection)
uri = (
f"databridge://{owner_id}:{auth_token}@databridge.local"
f"?openai_key={openai_key}"
f"&unstructured_key={unstructured_key}"
f"&db=brandsyncaidb"
f"&collection=kb_chunked_embeddings"
)
return uri
async def main():
# Initialize DataBridge
bridge = DataBridge(create_databridge_uri())
# Example: Ingest a PDF document
with open("examples/sample.pdf", "rb") as f:
pdf_content = base64.b64encode(f.read()).decode()
await bridge.ingest_document(
content=pdf_content,
metadata={
"content_type": "application/pdf",
"is_base64": True,
"title": "Sample PDF"
}
)
# Example: Query documents
results = await bridge.query(
query="What is machine learning?",
k=4
)
for result in results:
print(f"Content: {result['content'][:200]}...")
print(f"Score: {result['score']}\n")
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@ -83,7 +83,7 @@ async def example_pdf():
if not uri:
raise ValueError("Please set DATABRIDGE_URI environment variable")
# Path to a sample PDF in the examples directory
pdf_path = Path(__file__).parent / "sample.pdf"
if not pdf_path.exists():
print("× sample.pdf not found in examples directory")

View File

@ -2,6 +2,7 @@ import uvicorn
import os
from dotenv import load_dotenv
def main():
# Load environment variables from .env file
load_dotenv()
@ -15,6 +16,7 @@ def main():
]
missing = [var for var in required_vars if not os.getenv(var)]
if missing:
raise ValueError(f"Missing required environment variables: {', '.join(missing)}")
@ -27,4 +29,4 @@ def main():
)
if __name__ == "__main__":
main()
main()