Add completion sources and batch retrieval for docs and chunks (#51)

Adityavardhan Agrawal 2025-03-09 18:42:04 -04:00 committed by GitHub
parent 8c77da3708
commit 38683df0f3
16 changed files with 672 additions and 11 deletions

View File

@ -3,7 +3,6 @@ from datetime import datetime, UTC, timedelta
from pathlib import Path
import sys
from typing import Any, Dict, List, Optional
from core.models.completion import CompletionResponse
from fastapi import FastAPI, Form, HTTPException, Depends, Header, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import jwt
@ -12,6 +11,7 @@ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from core.completion.openai_completion import OpenAICompletionModel
from core.embedding.ollama_embedding_model import OllamaEmbeddingModel
from core.models.request import RetrieveRequest, CompletionQueryRequest, IngestTextRequest
from core.models.completion import ChunkSource, CompletionResponse
from core.models.documents import Document, DocumentResult, ChunkResult
from core.models.auth import AuthContext, EntityType
from core.parser.databridge_parser import DatabridgeParser
@ -398,6 +398,38 @@ async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depen
            )
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))


@app.post("/batch/documents", response_model=List[Document])
async def batch_get_documents(document_ids: List[str], auth: AuthContext = Depends(verify_token)):
    """Retrieve multiple documents by their IDs in a single batch operation."""
    try:
        async with telemetry.track_operation(
            operation_type="batch_get_documents",
            user_id=auth.entity_id,
            metadata={
                "document_count": len(document_ids),
            },
        ):
            return await document_service.batch_retrieve_documents(document_ids, auth)
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))


@app.post("/batch/chunks", response_model=List[ChunkResult])
async def batch_get_chunks(chunk_ids: List[ChunkSource], auth: AuthContext = Depends(verify_token)):
    """Retrieve specific chunks by their document ID and chunk number in a single batch operation."""
    try:
        async with telemetry.track_operation(
            operation_type="batch_get_chunks",
            user_id=auth.entity_id,
            metadata={
                "chunk_count": len(chunk_ids),
            },
        ):
            return await document_service.batch_retrieve_chunks(chunk_ids, auth)
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))


@app.post("/query", response_model=CompletionResponse)
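For orientation, here is a minimal sketch of how a client could call the two new endpoints over plain HTTP, assuming a locally running server and a valid bearer token (the base URL, token, and IDs below are placeholders):

```python
import requests

BASE_URL = "http://localhost:8000"  # placeholder: wherever the API is served
HEADERS = {"Authorization": "Bearer <jwt>"}  # placeholder token

# POST /batch/documents takes a JSON list of document IDs
docs = requests.post(
    f"{BASE_URL}/batch/documents",
    json=["doc_123", "doc_456"],
    headers=HEADERS,
).json()

# POST /batch/chunks takes a JSON list of {document_id, chunk_number} objects
chunks = requests.post(
    f"{BASE_URL}/batch/chunks",
    json=[{"document_id": "doc_123", "chunk_number": 0}],
    headers=HEADERS,
).json()
```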

View File

@ -1,8 +1,5 @@
from core.completion.base_completion import (
    BaseCompletionModel,
    CompletionRequest,
    CompletionResponse,
)
from core.completion.base_completion import BaseCompletionModel
from core.models.completion import CompletionRequest, CompletionResponse
from ollama import AsyncClient
BASE_64_PREFIX = "data:image/png;base64,"

View File

@ -1,4 +1,5 @@
from .base_completion import BaseCompletionModel, CompletionRequest, CompletionResponse
from .base_completion import BaseCompletionModel
from core.models.completion import CompletionRequest, CompletionResponse
class OpenAICompletionModel(BaseCompletionModel):

View File

@ -23,6 +23,21 @@ class BaseDatabase(ABC):
        Returns: Document if found and accessible, None otherwise
        """
        pass

    @abstractmethod
    async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext) -> List[Document]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.
        Only returns documents the user has access to.

        Args:
            document_ids: List of document IDs to retrieve
            auth: Authentication context

        Returns:
            List of Document objects that were found and user has access to
        """
        pass

    @abstractmethod
    async def get_documents(

View File

@ -79,6 +79,49 @@ class MongoDatabase(BaseDatabase):
        except PyMongoError as e:
            logger.error(f"Error retrieving document metadata: {str(e)}")
            raise e

    async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext) -> List[Document]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.
        Only returns documents the user has access to.

        Args:
            document_ids: List of document IDs to retrieve
            auth: Authentication context

        Returns:
            List of Document objects that were found and user has access to
        """
        try:
            if not document_ids:
                return []

            # Build access filter
            access_filter = self._build_access_filter(auth)

            # Query documents with both document IDs and access check in a single query
            query = {
                "$and": [
                    {"external_id": {"$in": document_ids}},
                    access_filter
                ]
            }

            logger.info(f"Batch retrieving {len(document_ids)} documents with a single query")

            # Execute batch query
            cursor = self.collection.find(query)

            documents = []
            async for doc_dict in cursor:
                documents.append(Document(**doc_dict))

            logger.info(f"Found {len(documents)} documents in batch retrieval")
            return documents
        except PyMongoError as e:
            logger.error(f"Error batch retrieving documents: {str(e)}")
            return []

    async def get_documents(
        self,

View File

@ -149,10 +149,67 @@ class PostgresDatabase(BaseDatabase):
                    }
                    return Document(**doc_dict)
                return None
        except Exception as e:
            logger.error(f"Error retrieving document metadata: {str(e)}")
            return None

    async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext) -> List[Document]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.
        Only returns documents the user has access to.

        Args:
            document_ids: List of document IDs to retrieve
            auth: Authentication context

        Returns:
            List of Document objects that were found and user has access to
        """
        try:
            if not document_ids:
                return []

            async with self.async_session() as session:
                # Build access filter
                access_filter = self._build_access_filter(auth)

                # Query documents with both document IDs and access check in a single query
                query = (
                    select(DocumentModel)
                    .where(DocumentModel.external_id.in_(document_ids))
                    .where(text(f"({access_filter})"))
                )

                logger.info(f"Batch retrieving {len(document_ids)} documents with a single query")

                # Execute batch query
                result = await session.execute(query)
                doc_models = result.scalars().all()

                documents = []
                for doc_model in doc_models:
                    # Convert doc_metadata back to metadata
                    doc_dict = {
                        "external_id": doc_model.external_id,
                        "owner": doc_model.owner,
                        "content_type": doc_model.content_type,
                        "filename": doc_model.filename,
                        "metadata": doc_model.doc_metadata,
                        "storage_info": doc_model.storage_info,
                        "system_metadata": doc_model.system_metadata,
                        "additional_metadata": doc_model.additional_metadata,
                        "access_control": doc_model.access_control,
                        "chunk_ids": doc_model.chunk_ids,
                    }
                    documents.append(Document(**doc_dict))

                logger.info(f"Found {len(documents)} documents in batch retrieval")
                return documents
        except Exception as e:
            logger.error(f"Error batch retrieving documents: {str(e)}")
            return []

    async def get_documents(
        self,

View File

@ -2,12 +2,20 @@ from pydantic import BaseModel
from typing import Dict, List, Optional


class ChunkSource(BaseModel):
    """Source information for a chunk used in completion"""

    document_id: str
    chunk_number: int


class CompletionResponse(BaseModel):
    """Response from completion generation"""

    completion: str
    usage: Dict[str, int]
    finish_reason: Optional[str] = None
    sources: List[ChunkSource] = []


class CompletionRequest(BaseModel):
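To make the new sources field concrete, a small illustrative construction (the values here are made up, not produced by this commit):

```python
from core.models.completion import ChunkSource, CompletionResponse

# Illustrative values only; a real response is produced by the completion model
response = CompletionResponse(
    completion="The report highlights Q3 revenue growth.",
    usage={"prompt_tokens": 150, "completion_tokens": 18},
    sources=[
        ChunkSource(document_id="doc_123", chunk_number=0),
        ChunkSource(document_id="doc_123", chunk_number=4),
    ],
)
print([(s.document_id, s.chunk_number) for s in response.sources])
```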

View File

@ -17,7 +17,7 @@ from core.vector_store.base_vector_store import BaseVectorStore
from core.embedding.base_embedding_model import BaseEmbeddingModel
from core.parser.base_parser import BaseParser
from core.completion.base_completion import BaseCompletionModel
from core.completion.base_completion import CompletionRequest, CompletionResponse
from core.models.completion import CompletionRequest, CompletionResponse, ChunkSource
import logging
from core.reranker.base_reranker import BaseReranker
from core.config import get_settings
@ -148,6 +148,77 @@ class DocumentService:
        documents = list(results.values())
        logger.info(f"Returning {len(documents)} document results")
        return documents

    async def batch_retrieve_documents(
        self,
        document_ids: List[str],
        auth: AuthContext
    ) -> List[Document]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.

        Args:
            document_ids: List of document IDs to retrieve
            auth: Authentication context

        Returns:
            List of Document objects that user has access to
        """
        if not document_ids:
            return []

        # Use the database's batch retrieval method
        documents = await self.db.get_documents_by_id(document_ids, auth)
        logger.info(f"Batch retrieved {len(documents)} documents out of {len(document_ids)} requested")
        return documents

    async def batch_retrieve_chunks(
        self,
        chunk_ids: List[ChunkSource],
        auth: AuthContext
    ) -> List[ChunkResult]:
        """
        Retrieve specific chunks by their document ID and chunk number in a single batch operation.

        Args:
            chunk_ids: List of ChunkSource objects with document_id and chunk_number
            auth: Authentication context

        Returns:
            List of ChunkResult objects
        """
        if not chunk_ids:
            return []

        # Collect unique document IDs to check authorization in a single query
        doc_ids = list({source.document_id for source in chunk_ids})

        # Find authorized documents in a single query
        authorized_docs = await self.batch_retrieve_documents(doc_ids, auth)
        authorized_doc_ids = {doc.external_id for doc in authorized_docs}

        # Filter sources to only include authorized documents
        authorized_sources = [
            source for source in chunk_ids
            if source.document_id in authorized_doc_ids
        ]

        if not authorized_sources:
            return []

        # Create list of (document_id, chunk_number) tuples for vector store query
        chunk_identifiers = [
            (source.document_id, source.chunk_number)
            for source in authorized_sources
        ]

        # Retrieve the chunks from vector store in a single query
        chunks = await self.vector_store.get_chunks_by_id(chunk_identifiers)

        # Convert to chunk results
        results = await self._create_chunk_results(auth, chunks)
        logger.info(f"Batch retrieved {len(results)} chunks out of {len(chunk_ids)} requested")
        return results

    async def query(
        self,
@ -169,6 +240,12 @@ class DocumentService:
        documents = await self._create_document_results(auth, chunks)
        chunk_contents = [chunk.augmented_content(documents[chunk.document_id]) for chunk in chunks]

        # Collect sources information
        sources = [
            ChunkSource(document_id=chunk.document_id, chunk_number=chunk.chunk_number)
            for chunk in chunks
        ]

        # Generate completion
        request = CompletionRequest(
@ -179,6 +256,10 @@ class DocumentService:
        )

        response = await self.completion_model.complete(request)

        # Add sources information at the document service level
        response.sources = sources

        return response

    async def ingest_text(
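Taken together, these changes let a caller trace a completion back to the chunks it used and re-fetch them in one batch. A rough sketch of that flow at the service layer, assuming an initialized DocumentService and AuthContext (the query() keyword arguments shown are illustrative, not the exact signature):

```python
# Sketch only: `document_service` and `auth` are assumed to already exist,
# and the query() parameters shown here are illustrative.
response = await document_service.query(query="What are the key findings?", auth=auth)

# response.sources is a List[ChunkSource] recording document_id + chunk_number
chunks = await document_service.batch_retrieve_chunks(response.sources, auth)

for chunk in chunks:
    print(chunk.document_id, chunk.chunk_number, chunk.content[:80])
```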

View File

@ -18,3 +18,19 @@ class BaseVectorStore(ABC):
    ) -> List[DocumentChunk]:
        """Find similar chunks"""
        pass

    @abstractmethod
    async def get_chunks_by_id(
        self,
        chunk_identifiers: List[Tuple[str, int]],
    ) -> List[DocumentChunk]:
        """
        Retrieve specific chunks by document ID and chunk number.

        Args:
            chunk_identifiers: List of (document_id, chunk_number) tuples

        Returns:
            List of DocumentChunk objects
        """
        pass
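As an illustration of the contract this abstract method imposes, here is a minimal, hypothetical in-memory implementation (the class and import paths are assumptions for the sketch, not part of this commit, and the store's other abstract methods are omitted):

```python
from typing import Dict, List, Tuple

# Hypothetical illustration; import paths assumed from this repository's layout
from core.models.chunk import DocumentChunk
from core.vector_store.base_vector_store import BaseVectorStore


class InMemoryVectorStore(BaseVectorStore):
    """Toy store keyed by (document_id, chunk_number); other methods omitted."""

    def __init__(self) -> None:
        self._chunks: Dict[Tuple[str, int], DocumentChunk] = {}

    async def get_chunks_by_id(
        self,
        chunk_identifiers: List[Tuple[str, int]],
    ) -> List[DocumentChunk]:
        # Return only the requested chunks that exist, preserving request order
        return [self._chunks[key] for key in chunk_identifiers if key in self._chunks]
```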

View File

@ -132,3 +132,52 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
            logger.error(f"MongoDB error: {e._message}")
            logger.error(f"Error querying similar chunks: {str(e)}")
            raise e

    async def get_chunks_by_id(
        self,
        chunk_identifiers: List[Tuple[str, int]],
    ) -> List[DocumentChunk]:
        """
        Retrieve specific chunks by document ID and chunk number in a single database query.

        Args:
            chunk_identifiers: List of (document_id, chunk_number) tuples

        Returns:
            List of DocumentChunk objects
        """
        try:
            if not chunk_identifiers:
                return []

            # Create a query with $or to find multiple chunks in a single query
            query = {"$or": []}
            for doc_id, chunk_num in chunk_identifiers:
                query["$or"].append({
                    "document_id": doc_id,
                    "chunk_number": chunk_num
                })

            logger.info(f"Batch retrieving {len(chunk_identifiers)} chunks with a single query")

            # Find all matching chunks in a single database query
            cursor = self.collection.find(query)

            chunks = []
            async for result in cursor:
                chunk = DocumentChunk(
                    document_id=result["document_id"],
                    chunk_number=result["chunk_number"],
                    content=result["content"],
                    embedding=[],  # Don't send embeddings back
                    metadata=result.get("metadata", {}),
                    score=0.0,  # No relevance score for direct retrieval
                )
                chunks.append(chunk)

            logger.info(f"Found {len(chunks)} chunks in batch retrieval")
            return chunks
        except PyMongoError as e:
            logger.error(f"Error retrieving chunks by ID: {str(e)}")
            return []

View File

@ -269,6 +269,62 @@ class MultiVectorStore(BaseVectorStore):
        # raise e
        # return []

    async def get_chunks_by_id(
        self,
        chunk_identifiers: List[Tuple[str, int]],
    ) -> List[DocumentChunk]:
        """
        Retrieve specific chunks by document ID and chunk number in a single database query.

        Args:
            chunk_identifiers: List of (document_id, chunk_number) tuples

        Returns:
            List of DocumentChunk objects
        """
        # try:
        if not chunk_identifiers:
            return []

        # Construct the WHERE clause with OR conditions
        conditions = []
        for doc_id, chunk_num in chunk_identifiers:
            conditions.append(f"(document_id = '{doc_id}' AND chunk_number = {chunk_num})")
        where_clause = " OR ".join(conditions)

        # Build and execute query
        query = f"""
            SELECT document_id, chunk_number, content, chunk_metadata
            FROM multi_vector_embeddings
            WHERE {where_clause}
        """

        logger.info(f"Batch retrieving {len(chunk_identifiers)} chunks from multi-vector store")
        result = self.conn.execute(query).fetchall()

        # Convert to DocumentChunks
        chunks = []
        for row in result:
            try:
                metadata = eval(row[3]) if row[3] else {}
            except (ValueError, SyntaxError):
                metadata = {}

            chunk = DocumentChunk(
                document_id=row[0],
                chunk_number=row[1],
                content=row[2],
                embedding=[],  # Don't send embeddings back
                metadata=metadata,
                score=0.0,  # No relevance score for direct retrieval
            )
            chunks.append(chunk)

        logger.info(f"Found {len(chunks)} chunks in batch retrieval from multi-vector store")
        return chunks

    def close(self):
        """Close the database connection."""
        if self.conn:

View File

@ -178,3 +178,66 @@ class PGVectorStore(BaseVectorStore):
        except Exception as e:
            logger.error(f"Error querying similar chunks: {str(e)}")
            return []

    async def get_chunks_by_id(
        self,
        chunk_identifiers: List[Tuple[str, int]],
    ) -> List[DocumentChunk]:
        """
        Retrieve specific chunks by document ID and chunk number in a single database query.

        Args:
            chunk_identifiers: List of (document_id, chunk_number) tuples

        Returns:
            List of DocumentChunk objects
        """
        try:
            if not chunk_identifiers:
                return []

            async with self.async_session() as session:
                # Create a list of OR conditions for the query
                conditions = []
                for doc_id, chunk_num in chunk_identifiers:
                    conditions.append(
                        text(f"(document_id = '{doc_id}' AND chunk_number = {chunk_num})")
                    )

                # Join conditions with OR
                or_condition = text(" OR ".join(f"({condition.text})" for condition in conditions))

                # Build query to find all matching chunks in a single query
                query = select(VectorEmbedding).where(or_condition)

                logger.info(f"Batch retrieving {len(chunk_identifiers)} chunks with a single query")

                # Execute query
                result = await session.execute(query)
                chunk_models = result.scalars().all()

                # Convert to DocumentChunk objects
                chunks = []
                for chunk_model in chunk_models:
                    # Convert stored metadata string back to dict
                    try:
                        metadata = eval(chunk_model.chunk_metadata) if chunk_model.chunk_metadata else {}
                    except Exception:
                        metadata = {}

                    chunk = DocumentChunk(
                        document_id=chunk_model.document_id,
                        chunk_number=chunk_model.chunk_number,
                        content=chunk_model.content,
                        embedding=[],  # Don't send embeddings back
                        metadata=metadata,
                        score=0.0,  # No relevance score for direct retrieval
                    )
                    chunks.append(chunk)

                logger.info(f"Found {len(chunks)} chunks in batch retrieval")
                return chunks
        except Exception as e:
            logger.error(f"Error retrieving chunks by ID: {str(e)}")
            return []

View File

@ -14,6 +14,7 @@ from .models import (
    DocumentResult,
    CompletionResponse,
    IngestTextRequest,
    ChunkSource,
)
from .rules import Rule
@ -452,6 +453,105 @@ class AsyncDataBridge:
        """
        response = await self._request("GET", f"documents/{document_id}")
        return Document(**response)

    async def batch_get_documents(self, document_ids: List[str]) -> List[Document]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.

        Args:
            document_ids: List of document IDs to retrieve

        Returns:
            List[Document]: List of document metadata for found documents

        Example:
            ```python
            docs = await db.batch_get_documents(["doc_123", "doc_456", "doc_789"])
            for doc in docs:
                print(f"Document {doc.external_id}: {doc.metadata.get('title')}")
            ```
        """
        response = await self._request("POST", "batch/documents", data=document_ids)
        return [Document(**doc) for doc in response]

    async def batch_get_chunks(self, sources: List[Union[ChunkSource, Dict[str, Any]]]) -> List[FinalChunkResult]:
        """
        Retrieve specific chunks by their document ID and chunk number in a single batch operation.

        Args:
            sources: List of ChunkSource objects or dictionaries with document_id and chunk_number

        Returns:
            List[FinalChunkResult]: List of chunk results

        Example:
            ```python
            # Using dictionaries
            sources = [
                {"document_id": "doc_123", "chunk_number": 0},
                {"document_id": "doc_456", "chunk_number": 2}
            ]

            # Or using ChunkSource objects
            from databridge.models import ChunkSource
            sources = [
                ChunkSource(document_id="doc_123", chunk_number=0),
                ChunkSource(document_id="doc_456", chunk_number=2)
            ]

            chunks = await db.batch_get_chunks(sources)
            for chunk in chunks:
                print(f"Chunk from {chunk.document_id}, number {chunk.chunk_number}: {chunk.content[:50]}...")
            ```
        """
        # Convert to list of dictionaries if needed
        source_dicts = []
        for source in sources:
            if isinstance(source, dict):
                source_dicts.append(source)
            else:
                source_dicts.append(source.model_dump())

        response = await self._request("POST", "batch/chunks", data=source_dicts)
        chunks = [ChunkResult(**r) for r in response]

        final_chunks = []
        for chunk in chunks:
            if chunk.metadata.get("is_image"):
                try:
                    # Handle data URI format "data:image/png;base64,..."
                    content = chunk.content
                    if content.startswith("data:"):
                        # Extract the base64 part after the comma
                        content = content.split(",", 1)[1]

                    # Now decode the base64 string
                    import base64
                    import io
                    from PIL import Image

                    image_bytes = base64.b64decode(content)
                    content = Image.open(io.BytesIO(image_bytes))
                except Exception as e:
                    print(f"Error processing image: {str(e)}")
                    # Fall back to using the content as text
                    content = chunk.content
            else:
                content = chunk.content

            final_chunks.append(
                FinalChunkResult(
                    content=content,
                    score=chunk.score,
                    document_id=chunk.document_id,
                    chunk_number=chunk.chunk_number,
                    metadata=chunk.metadata,
                    content_type=chunk.content_type,
                    filename=chunk.filename,
                    download_url=chunk.download_url,
                )
            )
        return final_chunks

    async def create_cache(
        self,

View File

@ -59,11 +59,21 @@ class DocumentResult(BaseModel):
    content: DocumentContent = Field(..., description="Document content or URL")


class ChunkSource(BaseModel):
    """Source information for a chunk used in completion"""

    document_id: str = Field(..., description="ID of the source document")
    chunk_number: int = Field(..., description="Chunk number within the document")


class CompletionResponse(BaseModel):
    """Completion response model"""

    completion: str
    usage: Dict[str, int]
    sources: List[ChunkSource] = Field(
        default_factory=list, description="Sources of chunks used in the completion"
    )


class IngestTextRequest(BaseModel):

View File

@ -12,7 +12,7 @@ import jwt
from pydantic import BaseModel, Field
import requests
from .models import Document, ChunkResult, DocumentResult, CompletionResponse, IngestTextRequest
from .models import Document, ChunkResult, DocumentResult, CompletionResponse, IngestTextRequest, ChunkSource
from .rules import Rule
# Type alias for rules
@ -487,6 +487,102 @@ class DataBridge:
        """
        response = self._request("GET", f"documents/{document_id}")
        return Document(**response)

    def batch_get_documents(self, document_ids: List[str]) -> List[Document]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.

        Args:
            document_ids: List of document IDs to retrieve

        Returns:
            List[Document]: List of document metadata for found documents

        Example:
            ```python
            docs = db.batch_get_documents(["doc_123", "doc_456", "doc_789"])
            for doc in docs:
                print(f"Document {doc.external_id}: {doc.metadata.get('title')}")
            ```
        """
        response = self._request("POST", "batch/documents", data=document_ids)
        return [Document(**doc) for doc in response]

    def batch_get_chunks(self, sources: List[Union[ChunkSource, Dict[str, Any]]]) -> List[FinalChunkResult]:
        """
        Retrieve specific chunks by their document ID and chunk number in a single batch operation.

        Args:
            sources: List of ChunkSource objects or dictionaries with document_id and chunk_number

        Returns:
            List[FinalChunkResult]: List of chunk results

        Example:
            ```python
            # Using dictionaries
            sources = [
                {"document_id": "doc_123", "chunk_number": 0},
                {"document_id": "doc_456", "chunk_number": 2}
            ]

            # Or using ChunkSource objects
            from databridge.models import ChunkSource
            sources = [
                ChunkSource(document_id="doc_123", chunk_number=0),
                ChunkSource(document_id="doc_456", chunk_number=2)
            ]

            chunks = db.batch_get_chunks(sources)
            for chunk in chunks:
                print(f"Chunk from {chunk.document_id}, number {chunk.chunk_number}: {chunk.content[:50]}...")
            ```
        """
        # Convert to list of dictionaries if needed
        source_dicts = []
        for source in sources:
            if isinstance(source, dict):
                source_dicts.append(source)
            else:
                source_dicts.append(source.model_dump())

        response = self._request("POST", "batch/chunks", data=source_dicts)
        chunks = [ChunkResult(**r) for r in response]

        final_chunks = []
        for chunk in chunks:
            if chunk.metadata.get("is_image"):
                try:
                    # Handle data URI format "data:image/png;base64,..."
                    content = chunk.content
                    if content.startswith("data:"):
                        # Extract the base64 part after the comma
                        content = content.split(",", 1)[1]

                    # Now decode the base64 string
                    image_bytes = base64.b64decode(content)
                    content = Image.open(io.BytesIO(image_bytes))
                except Exception as e:
                    print(f"Error processing image: {str(e)}")
                    # Fall back to using the content as text
                    content = chunk.content
            else:
                content = chunk.content

            final_chunks.append(
                FinalChunkResult(
                    content=content,
                    score=chunk.score,
                    document_id=chunk.document_id,
                    chunk_number=chunk.chunk_number,
                    metadata=chunk.metadata,
                    content_type=chunk.content_type,
                    filename=chunk.filename,
                    download_url=chunk.download_url,
                )
            )
        return final_chunks

    def create_cache(
        self,

View File

@ -190,6 +190,38 @@ class DB:
        """Get document metadata by ID"""
        doc = self._client.get_document(document_id)
        return doc.model_dump()

    def batch_get_documents(self, document_ids: List[str]) -> List[dict]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.

        Args:
            document_ids: List of document IDs to retrieve

        Returns:
            List of document metadata
        """
        docs = self._client.batch_get_documents(document_ids)
        return [doc.model_dump() for doc in docs]

    def batch_get_chunks(self, sources: List[dict]) -> List[dict]:
        """
        Retrieve specific chunks by their document ID and chunk number in a single batch operation.

        Args:
            sources: List of dictionaries with document_id and chunk_number fields

        Returns:
            List of chunk results

        Example:
            sources = [
                {"document_id": "doc_123", "chunk_number": 0},
                {"document_id": "doc_456", "chunk_number": 2}
            ]
        """
        chunks = self._client.batch_get_chunks(sources)
        return [chunk.model_dump() for chunk in chunks]

    def create_cache(
        self,
@ -265,7 +297,12 @@ if __name__ == "__main__":
    # Print welcome message
    print("\nDataBridge CLI ready to use. The 'db' object is available with all SDK methods.")
    print("Example: db.ingest_text('hello world')")
    print("Examples:")
    print("  db.ingest_text('hello world')")
    print("  db.query('what are the key findings?')")
    print("  db.batch_get_documents(['doc_id1', 'doc_id2'])")
    print("  db.batch_get_chunks([{'document_id': 'doc_123', 'chunk_number': 0}])")
    print("  result = db.query('how to use this API?'); print(result['sources'])")
    print("Type help(db) for documentation.")

    # Start the shell