Mirror of https://github.com/james-m-jordan/morphik-core.git (synced 2025-05-09 19:32:38 +00:00)
Add completion sources and batch retrieval for docs and chunks (#51)
This commit is contained in:
parent 8c77da3708
commit 38683df0f3
core/api.py (34 changed lines)
@@ -3,7 +3,6 @@ from datetime import datetime, UTC, timedelta
from pathlib import Path
import sys
from typing import Any, Dict, List, Optional
from core.models.completion import CompletionResponse
from fastapi import FastAPI, Form, HTTPException, Depends, Header, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import jwt
@@ -12,6 +11,7 @@ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from core.completion.openai_completion import OpenAICompletionModel
from core.embedding.ollama_embedding_model import OllamaEmbeddingModel
from core.models.request import RetrieveRequest, CompletionQueryRequest, IngestTextRequest
from core.models.completion import ChunkSource, CompletionResponse
from core.models.documents import Document, DocumentResult, ChunkResult
from core.models.auth import AuthContext, EntityType
from core.parser.databridge_parser import DatabridgeParser
@@ -398,6 +398,38 @@ async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depen
        )
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))


@app.post("/batch/documents", response_model=List[Document])
async def batch_get_documents(document_ids: List[str], auth: AuthContext = Depends(verify_token)):
    """Retrieve multiple documents by their IDs in a single batch operation."""
    try:
        async with telemetry.track_operation(
            operation_type="batch_get_documents",
            user_id=auth.entity_id,
            metadata={
                "document_count": len(document_ids),
            },
        ):
            return await document_service.batch_retrieve_documents(document_ids, auth)
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))


@app.post("/batch/chunks", response_model=List[ChunkResult])
async def batch_get_chunks(chunk_ids: List[ChunkSource], auth: AuthContext = Depends(verify_token)):
    """Retrieve specific chunks by their document ID and chunk number in a single batch operation."""
    try:
        async with telemetry.track_operation(
            operation_type="batch_get_chunks",
            user_id=auth.entity_id,
            metadata={
                "chunk_count": len(chunk_ids),
            },
        ):
            return await document_service.batch_retrieve_chunks(chunk_ids, auth)
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))


@app.post("/query", response_model=CompletionResponse)
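For reference, a minimal client-side sketch of calling the two new endpoints over HTTP with httpx. The base URL and bearer token are placeholders, and the JSON bodies mirror the request models above: a plain list of IDs for /batch/documents, and a list of ChunkSource-shaped objects for /batch/chunks.

```python
# Hypothetical client sketch; server address and JWT are placeholders.
import asyncio

import httpx

BASE_URL = "http://localhost:8000"                # assumed server address
HEADERS = {"Authorization": "Bearer <jwt>"}       # token checked by verify_token


async def fetch_batches() -> None:
    async with httpx.AsyncClient(base_url=BASE_URL, headers=HEADERS) as client:
        # Batch document retrieval: body is a JSON list of document IDs.
        docs = await client.post("/batch/documents", json=["doc_123", "doc_456"])
        docs.raise_for_status()

        # Batch chunk retrieval: body is a list of {document_id, chunk_number} objects.
        chunks = await client.post(
            "/batch/chunks",
            json=[{"document_id": "doc_123", "chunk_number": 0}],
        )
        chunks.raise_for_status()
        print(len(docs.json()), "documents,", len(chunks.json()), "chunks")


asyncio.run(fetch_batches())
```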
@@ -1,8 +1,5 @@
from core.completion.base_completion import (
    BaseCompletionModel,
    CompletionRequest,
    CompletionResponse,
)
from core.completion.base_completion import BaseCompletionModel
from core.models.completion import CompletionRequest, CompletionResponse
from ollama import AsyncClient

BASE_64_PREFIX = "data:image/png;base64,"
@@ -1,4 +1,5 @@
from .base_completion import BaseCompletionModel, CompletionRequest, CompletionResponse
from .base_completion import BaseCompletionModel
from core.models.completion import CompletionRequest, CompletionResponse


class OpenAICompletionModel(BaseCompletionModel):
@@ -23,6 +23,21 @@ class BaseDatabase(ABC):
        Returns: Document if found and accessible, None otherwise
        """
        pass

    @abstractmethod
    async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext) -> List[Document]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.
        Only returns documents the user has access to.

        Args:
            document_ids: List of document IDs to retrieve
            auth: Authentication context

        Returns:
            List of Document objects that were found and user has access to
        """
        pass

    @abstractmethod
    async def get_documents(
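As a rough sketch of the contract this abstract method defines, a hypothetical in-memory store could satisfy it like this. The "readers" list is a stand-in for the real access-control filtering done by the concrete databases below.

```python
# Hypothetical, simplified in-memory store illustrating the get_documents_by_id
# contract; a real implementation also satisfies the other BaseDatabase methods.
from typing import Dict, List


class InMemoryDocs:
    def __init__(self, docs: Dict[str, dict]):
        # external_id -> document dict with an assumed "readers" list
        self.docs = docs

    async def get_documents_by_id(self, document_ids: List[str], entity_id: str) -> List[dict]:
        # Single pass: keep only requested IDs that exist and that the caller may read.
        return [
            doc
            for doc_id in document_ids
            if (doc := self.docs.get(doc_id)) is not None
            and entity_id in doc.get("readers", [])
        ]
```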
@@ -79,6 +79,49 @@ class MongoDatabase(BaseDatabase):
        except PyMongoError as e:
            logger.error(f"Error retrieving document metadata: {str(e)}")
            raise e

    async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext) -> List[Document]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.
        Only returns documents the user has access to.

        Args:
            document_ids: List of document IDs to retrieve
            auth: Authentication context

        Returns:
            List of Document objects that were found and user has access to
        """
        try:
            if not document_ids:
                return []

            # Build access filter
            access_filter = self._build_access_filter(auth)

            # Query documents with both document IDs and access check in a single query
            query = {
                "$and": [
                    {"external_id": {"$in": document_ids}},
                    access_filter
                ]
            }

            logger.info(f"Batch retrieving {len(document_ids)} documents with a single query")

            # Execute batch query
            cursor = self.collection.find(query)

            documents = []
            async for doc_dict in cursor:
                documents.append(Document(**doc_dict))

            logger.info(f"Found {len(documents)} documents in batch retrieval")
            return documents

        except PyMongoError as e:
            logger.error(f"Error batch retrieving documents: {str(e)}")
            return []

    async def get_documents(
        self,
@@ -149,10 +149,67 @@ class PostgresDatabase(BaseDatabase):
                }
                return Document(**doc_dict)
            return None

        except Exception as e:
            logger.error(f"Error retrieving document metadata: {str(e)}")
            return None

    async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext) -> List[Document]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.
        Only returns documents the user has access to.

        Args:
            document_ids: List of document IDs to retrieve
            auth: Authentication context

        Returns:
            List of Document objects that were found and user has access to
        """
        try:
            if not document_ids:
                return []

            async with self.async_session() as session:
                # Build access filter
                access_filter = self._build_access_filter(auth)

                # Query documents with both document IDs and access check in a single query
                query = (
                    select(DocumentModel)
                    .where(DocumentModel.external_id.in_(document_ids))
                    .where(text(f"({access_filter})"))
                )

                logger.info(f"Batch retrieving {len(document_ids)} documents with a single query")

                # Execute batch query
                result = await session.execute(query)
                doc_models = result.scalars().all()

                documents = []
                for doc_model in doc_models:
                    # Convert doc_metadata back to metadata
                    doc_dict = {
                        "external_id": doc_model.external_id,
                        "owner": doc_model.owner,
                        "content_type": doc_model.content_type,
                        "filename": doc_model.filename,
                        "metadata": doc_model.doc_metadata,
                        "storage_info": doc_model.storage_info,
                        "system_metadata": doc_model.system_metadata,
                        "additional_metadata": doc_model.additional_metadata,
                        "access_control": doc_model.access_control,
                        "chunk_ids": doc_model.chunk_ids,
                    }
                    documents.append(Document(**doc_dict))

                logger.info(f"Found {len(documents)} documents in batch retrieval")
                return documents

        except Exception as e:
            logger.error(f"Error batch retrieving documents: {str(e)}")
            return []

    async def get_documents(
        self,
@@ -2,12 +2,20 @@ from pydantic import BaseModel
from typing import Dict, List, Optional


class ChunkSource(BaseModel):
    """Source information for a chunk used in completion"""

    document_id: str
    chunk_number: int


class CompletionResponse(BaseModel):
    """Response from completion generation"""

    completion: str
    usage: Dict[str, int]
    finish_reason: Optional[str] = None
    sources: List[ChunkSource] = []


class CompletionRequest(BaseModel):
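A quick sketch of how these two models compose; the field values here are made up for illustration.

```python
# Illustrative only; values are invented.
from core.models.completion import ChunkSource, CompletionResponse

response = CompletionResponse(
    completion="The key finding is ...",
    usage={"prompt_tokens": 150, "completion_tokens": 20},
    sources=[
        ChunkSource(document_id="doc_123", chunk_number=0),
        ChunkSource(document_id="doc_456", chunk_number=2),
    ],
)

# Each source records which chunk backed the completion.
for src in response.sources:
    print(src.document_id, src.chunk_number)
```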
@@ -17,7 +17,7 @@ from core.vector_store.base_vector_store import BaseVectorStore
from core.embedding.base_embedding_model import BaseEmbeddingModel
from core.parser.base_parser import BaseParser
from core.completion.base_completion import BaseCompletionModel
from core.completion.base_completion import CompletionRequest, CompletionResponse
from core.models.completion import CompletionRequest, CompletionResponse, ChunkSource
import logging
from core.reranker.base_reranker import BaseReranker
from core.config import get_settings
@@ -148,6 +148,77 @@ class DocumentService:
        documents = list(results.values())
        logger.info(f"Returning {len(documents)} document results")
        return documents

    async def batch_retrieve_documents(
        self,
        document_ids: List[str],
        auth: AuthContext
    ) -> List[Document]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.

        Args:
            document_ids: List of document IDs to retrieve
            auth: Authentication context

        Returns:
            List of Document objects that user has access to
        """
        if not document_ids:
            return []

        # Use the database's batch retrieval method
        documents = await self.db.get_documents_by_id(document_ids, auth)
        logger.info(f"Batch retrieved {len(documents)} documents out of {len(document_ids)} requested")
        return documents

    async def batch_retrieve_chunks(
        self,
        chunk_ids: List[ChunkSource],
        auth: AuthContext
    ) -> List[ChunkResult]:
        """
        Retrieve specific chunks by their document ID and chunk number in a single batch operation.

        Args:
            chunk_ids: List of ChunkSource objects with document_id and chunk_number
            auth: Authentication context

        Returns:
            List of ChunkResult objects
        """
        if not chunk_ids:
            return []

        # Collect unique document IDs to check authorization in a single query
        doc_ids = list({source.document_id for source in chunk_ids})

        # Find authorized documents in a single query
        authorized_docs = await self.batch_retrieve_documents(doc_ids, auth)
        authorized_doc_ids = {doc.external_id for doc in authorized_docs}

        # Filter sources to only include authorized documents
        authorized_sources = [
            source for source in chunk_ids
            if source.document_id in authorized_doc_ids
        ]

        if not authorized_sources:
            return []

        # Create list of (document_id, chunk_number) tuples for vector store query
        chunk_identifiers = [
            (source.document_id, source.chunk_number)
            for source in authorized_sources
        ]

        # Retrieve the chunks from vector store in a single query
        chunks = await self.vector_store.get_chunks_by_id(chunk_identifiers)

        # Convert to chunk results
        results = await self._create_chunk_results(auth, chunks)
        logger.info(f"Batch retrieved {len(results)} chunks out of {len(chunk_ids)} requested")
        return results

    async def query(
        self,
@@ -169,6 +240,12 @@ class DocumentService:
        documents = await self._create_document_results(auth, chunks)

        chunk_contents = [chunk.augmented_content(documents[chunk.document_id]) for chunk in chunks]

        # Collect sources information
        sources = [
            ChunkSource(document_id=chunk.document_id, chunk_number=chunk.chunk_number)
            for chunk in chunks
        ]

        # Generate completion
        request = CompletionRequest(
@@ -179,6 +256,10 @@ class DocumentService:
        )

        response = await self.completion_model.complete(request)

        # Add sources information at the document service level
        response.sources = sources

        return response

    async def ingest_text(
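Taken together, these pieces let a caller run a query and then re-fetch the exact chunks and documents that backed the answer. A rough sketch, assuming a configured DocumentService instance and an AuthContext; the query() call is simplified and its real parameters may differ.

```python
# Hypothetical wiring; document_service and auth come from application setup,
# and the query() signature is simplified for illustration.
async def query_then_refetch(document_service, auth):
    response = await document_service.query("what are the key findings?", auth)

    # response.sources lists the (document_id, chunk_number) pairs used above.
    chunks = await document_service.batch_retrieve_chunks(response.sources, auth)
    docs = await document_service.batch_retrieve_documents(
        [src.document_id for src in response.sources], auth
    )
    return response.completion, chunks, docs
```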
@@ -18,3 +18,19 @@ class BaseVectorStore(ABC):
    ) -> List[DocumentChunk]:
        """Find similar chunks"""
        pass

    @abstractmethod
    async def get_chunks_by_id(
        self,
        chunk_identifiers: List[Tuple[str, int]],
    ) -> List[DocumentChunk]:
        """
        Retrieve specific chunks by document ID and chunk number.

        Args:
            chunk_identifiers: List of (document_id, chunk_number) tuples

        Returns:
            List of DocumentChunk objects
        """
        pass
@@ -132,3 +132,52 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
            logger.error(f"MongoDB error: {e._message}")
            logger.error(f"Error querying similar chunks: {str(e)}")
            raise e

    async def get_chunks_by_id(
        self,
        chunk_identifiers: List[Tuple[str, int]],
    ) -> List[DocumentChunk]:
        """
        Retrieve specific chunks by document ID and chunk number in a single database query.

        Args:
            chunk_identifiers: List of (document_id, chunk_number) tuples

        Returns:
            List of DocumentChunk objects
        """
        try:
            if not chunk_identifiers:
                return []

            # Create a query with $or to find multiple chunks in a single query
            query = {"$or": []}
            for doc_id, chunk_num in chunk_identifiers:
                query["$or"].append({
                    "document_id": doc_id,
                    "chunk_number": chunk_num
                })

            logger.info(f"Batch retrieving {len(chunk_identifiers)} chunks with a single query")

            # Find all matching chunks in a single database query
            cursor = self.collection.find(query)
            chunks = []

            async for result in cursor:
                chunk = DocumentChunk(
                    document_id=result["document_id"],
                    chunk_number=result["chunk_number"],
                    content=result["content"],
                    embedding=[],  # Don't send embeddings back
                    metadata=result.get("metadata", {}),
                    score=0.0,  # No relevance score for direct retrieval
                )
                chunks.append(chunk)

            logger.info(f"Found {len(chunks)} chunks in batch retrieval")
            return chunks

        except PyMongoError as e:
            logger.error(f"Error retrieving chunks by ID: {str(e)}")
            return []
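For two identifiers, the loop above produces a MongoDB filter of this shape (values are illustrative):

```python
# Filter built for chunk_identifiers = [("doc_123", 0), ("doc_456", 2)].
query = {
    "$or": [
        {"document_id": "doc_123", "chunk_number": 0},
        {"document_id": "doc_456", "chunk_number": 2},
    ]
}
```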
@@ -269,6 +269,62 @@ class MultiVectorStore(BaseVectorStore):
        #     raise e
        #     return []

    async def get_chunks_by_id(
        self,
        chunk_identifiers: List[Tuple[str, int]],
    ) -> List[DocumentChunk]:
        """
        Retrieve specific chunks by document ID and chunk number in a single database query.

        Args:
            chunk_identifiers: List of (document_id, chunk_number) tuples

        Returns:
            List of DocumentChunk objects
        """
        # try:
        if not chunk_identifiers:
            return []

        # Construct the WHERE clause with OR conditions
        conditions = []
        for doc_id, chunk_num in chunk_identifiers:
            conditions.append(f"(document_id = '{doc_id}' AND chunk_number = {chunk_num})")

        where_clause = " OR ".join(conditions)

        # Build and execute query
        query = f"""
            SELECT document_id, chunk_number, content, chunk_metadata
            FROM multi_vector_embeddings
            WHERE {where_clause}
        """

        logger.info(f"Batch retrieving {len(chunk_identifiers)} chunks from multi-vector store")

        result = self.conn.execute(query).fetchall()

        # Convert to DocumentChunks
        chunks = []
        for row in result:
            try:
                metadata = eval(row[3]) if row[3] else {}
            except (ValueError, SyntaxError):
                metadata = {}

            chunk = DocumentChunk(
                document_id=row[0],
                chunk_number=row[1],
                content=row[2],
                embedding=[],  # Don't send embeddings back
                metadata=metadata,
                score=0.0,  # No relevance score for direct retrieval
            )
            chunks.append(chunk)

        logger.info(f"Found {len(chunks)} chunks in batch retrieval from multi-vector store")
        return chunks

    def close(self):
        """Close the database connection."""
        if self.conn:
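The WHERE clause above interpolates identifiers directly into the SQL string. As an aside, the same batch filter can be expressed with bound parameters; this standalone sketch builds a (sql, params) pair, assuming a DB-API driver with "%s" placeholder syntax (placeholder style varies by driver).

```python
# Hypothetical helper: parameterized version of the same batch filter.
from typing import List, Sequence, Tuple


def build_chunk_query(chunk_identifiers: Sequence[Tuple[str, int]]) -> Tuple[str, List[object]]:
    # One "(document_id = %s AND chunk_number = %s)" clause per identifier,
    # joined with OR; values are passed separately as bind parameters.
    conditions = ["(document_id = %s AND chunk_number = %s)"] * len(chunk_identifiers)
    params: List[object] = []
    for doc_id, chunk_num in chunk_identifiers:
        params.extend([doc_id, chunk_num])
    sql = (
        "SELECT document_id, chunk_number, content, chunk_metadata "
        "FROM multi_vector_embeddings WHERE " + " OR ".join(conditions)
    )
    return sql, params


# Example: sql, params = build_chunk_query([("doc_123", 0), ("doc_456", 2)])
```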
@@ -178,3 +178,66 @@ class PGVectorStore(BaseVectorStore):
        except Exception as e:
            logger.error(f"Error querying similar chunks: {str(e)}")
            return []

    async def get_chunks_by_id(
        self,
        chunk_identifiers: List[Tuple[str, int]],
    ) -> List[DocumentChunk]:
        """
        Retrieve specific chunks by document ID and chunk number in a single database query.

        Args:
            chunk_identifiers: List of (document_id, chunk_number) tuples

        Returns:
            List of DocumentChunk objects
        """
        try:
            if not chunk_identifiers:
                return []

            async with self.async_session() as session:
                # Create a list of OR conditions for the query
                conditions = []
                for doc_id, chunk_num in chunk_identifiers:
                    conditions.append(
                        text(f"(document_id = '{doc_id}' AND chunk_number = {chunk_num})")
                    )

                # Join conditions with OR
                or_condition = text(" OR ".join(f"({condition.text})" for condition in conditions))

                # Build query to find all matching chunks in a single query
                query = select(VectorEmbedding).where(or_condition)

                logger.info(f"Batch retrieving {len(chunk_identifiers)} chunks with a single query")

                # Execute query
                result = await session.execute(query)
                chunk_models = result.scalars().all()

                # Convert to DocumentChunk objects
                chunks = []
                for chunk_model in chunk_models:
                    # Convert stored metadata string back to dict
                    try:
                        metadata = eval(chunk_model.chunk_metadata) if chunk_model.chunk_metadata else {}
                    except Exception:
                        metadata = {}

                    chunk = DocumentChunk(
                        document_id=chunk_model.document_id,
                        chunk_number=chunk_model.chunk_number,
                        content=chunk_model.content,
                        embedding=[],  # Don't send embeddings back
                        metadata=metadata,
                        score=0.0,  # No relevance score for direct retrieval
                    )
                    chunks.append(chunk)

                logger.info(f"Found {len(chunks)} chunks in batch retrieval")
                return chunks

        except Exception as e:
            logger.error(f"Error retrieving chunks by ID: {str(e)}")
            return []
@@ -14,6 +14,7 @@ from .models import (
    DocumentResult,
    CompletionResponse,
    IngestTextRequest,
    ChunkSource,
)
from .rules import Rule

@@ -452,6 +453,105 @@ class AsyncDataBridge:
        """
        response = await self._request("GET", f"documents/{document_id}")
        return Document(**response)

    async def batch_get_documents(self, document_ids: List[str]) -> List[Document]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.

        Args:
            document_ids: List of document IDs to retrieve

        Returns:
            List[Document]: List of document metadata for found documents

        Example:
            ```python
            docs = await db.batch_get_documents(["doc_123", "doc_456", "doc_789"])
            for doc in docs:
                print(f"Document {doc.external_id}: {doc.metadata.get('title')}")
            ```
        """
        response = await self._request("POST", "batch/documents", data=document_ids)
        return [Document(**doc) for doc in response]

    async def batch_get_chunks(self, sources: List[Union[ChunkSource, Dict[str, Any]]]) -> List[FinalChunkResult]:
        """
        Retrieve specific chunks by their document ID and chunk number in a single batch operation.

        Args:
            sources: List of ChunkSource objects or dictionaries with document_id and chunk_number

        Returns:
            List[FinalChunkResult]: List of chunk results

        Example:
            ```python
            # Using dictionaries
            sources = [
                {"document_id": "doc_123", "chunk_number": 0},
                {"document_id": "doc_456", "chunk_number": 2}
            ]

            # Or using ChunkSource objects
            from databridge.models import ChunkSource
            sources = [
                ChunkSource(document_id="doc_123", chunk_number=0),
                ChunkSource(document_id="doc_456", chunk_number=2)
            ]

            chunks = await db.batch_get_chunks(sources)
            for chunk in chunks:
                print(f"Chunk from {chunk.document_id}, number {chunk.chunk_number}: {chunk.content[:50]}...")
            ```
        """
        # Convert to list of dictionaries if needed
        source_dicts = []
        for source in sources:
            if isinstance(source, dict):
                source_dicts.append(source)
            else:
                source_dicts.append(source.model_dump())

        response = await self._request("POST", "batch/chunks", data=source_dicts)
        chunks = [ChunkResult(**r) for r in response]

        final_chunks = []
        for chunk in chunks:
            if chunk.metadata.get("is_image"):
                try:
                    # Handle data URI format "data:image/png;base64,..."
                    content = chunk.content
                    if content.startswith("data:"):
                        # Extract the base64 part after the comma
                        content = content.split(",", 1)[1]

                    # Now decode the base64 string
                    import base64
                    import io
                    from PIL import Image
                    image_bytes = base64.b64decode(content)
                    content = Image.open(io.BytesIO(image_bytes))
                except Exception as e:
                    print(f"Error processing image: {str(e)}")
                    # Fall back to using the content as text
                    content = chunk.content
            else:
                content = chunk.content

            final_chunks.append(
                FinalChunkResult(
                    content=content,
                    score=chunk.score,
                    document_id=chunk.document_id,
                    chunk_number=chunk.chunk_number,
                    metadata=chunk.metadata,
                    content_type=chunk.content_type,
                    filename=chunk.filename,
                    download_url=chunk.download_url,
                )
            )

        return final_chunks

    async def create_cache(
        self,
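Combined with the sources field on the SDK CompletionResponse (shown below), a caller can trace a completion back to the chunks it used. A rough sketch, assuming `db` is a configured AsyncDataBridge instance and that query() returns that CompletionResponse model:

```python
# Hypothetical usage; assumes a configured AsyncDataBridge `db` whose query()
# returns the SDK CompletionResponse with the new `sources` field.
async def trace_answer(db):
    response = await db.query("what are the key findings?")

    # Re-fetch the exact chunks that were fed into the completion.
    chunks = await db.batch_get_chunks(response.sources)
    for chunk in chunks:
        print(chunk.document_id, chunk.chunk_number, chunk.content[:50])
```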
@@ -59,11 +59,21 @@ class DocumentResult(BaseModel):
    content: DocumentContent = Field(..., description="Document content or URL")


class ChunkSource(BaseModel):
    """Source information for a chunk used in completion"""

    document_id: str = Field(..., description="ID of the source document")
    chunk_number: int = Field(..., description="Chunk number within the document")


class CompletionResponse(BaseModel):
    """Completion response model"""

    completion: str
    usage: Dict[str, int]
    sources: List[ChunkSource] = Field(
        default_factory=list, description="Sources of chunks used in the completion"
    )


class IngestTextRequest(BaseModel):
@@ -12,7 +12,7 @@ import jwt
from pydantic import BaseModel, Field
import requests

from .models import Document, ChunkResult, DocumentResult, CompletionResponse, IngestTextRequest
from .models import Document, ChunkResult, DocumentResult, CompletionResponse, IngestTextRequest, ChunkSource
from .rules import Rule

# Type alias for rules
@@ -487,6 +487,102 @@ class DataBridge:
        """
        response = self._request("GET", f"documents/{document_id}")
        return Document(**response)

    def batch_get_documents(self, document_ids: List[str]) -> List[Document]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.

        Args:
            document_ids: List of document IDs to retrieve

        Returns:
            List[Document]: List of document metadata for found documents

        Example:
            ```python
            docs = db.batch_get_documents(["doc_123", "doc_456", "doc_789"])
            for doc in docs:
                print(f"Document {doc.external_id}: {doc.metadata.get('title')}")
            ```
        """
        response = self._request("POST", "batch/documents", data=document_ids)
        return [Document(**doc) for doc in response]

    def batch_get_chunks(self, sources: List[Union[ChunkSource, Dict[str, Any]]]) -> List[FinalChunkResult]:
        """
        Retrieve specific chunks by their document ID and chunk number in a single batch operation.

        Args:
            sources: List of ChunkSource objects or dictionaries with document_id and chunk_number

        Returns:
            List[FinalChunkResult]: List of chunk results

        Example:
            ```python
            # Using dictionaries
            sources = [
                {"document_id": "doc_123", "chunk_number": 0},
                {"document_id": "doc_456", "chunk_number": 2}
            ]

            # Or using ChunkSource objects
            from databridge.models import ChunkSource
            sources = [
                ChunkSource(document_id="doc_123", chunk_number=0),
                ChunkSource(document_id="doc_456", chunk_number=2)
            ]

            chunks = db.batch_get_chunks(sources)
            for chunk in chunks:
                print(f"Chunk from {chunk.document_id}, number {chunk.chunk_number}: {chunk.content[:50]}...")
            ```
        """
        # Convert to list of dictionaries if needed
        source_dicts = []
        for source in sources:
            if isinstance(source, dict):
                source_dicts.append(source)
            else:
                source_dicts.append(source.model_dump())

        response = self._request("POST", "batch/chunks", data=source_dicts)
        chunks = [ChunkResult(**r) for r in response]

        final_chunks = []
        for chunk in chunks:
            if chunk.metadata.get("is_image"):
                try:
                    # Handle data URI format "data:image/png;base64,..."
                    content = chunk.content
                    if content.startswith("data:"):
                        # Extract the base64 part after the comma
                        content = content.split(",", 1)[1]

                    # Now decode the base64 string
                    image_bytes = base64.b64decode(content)
                    content = Image.open(io.BytesIO(image_bytes))
                except Exception as e:
                    print(f"Error processing image: {str(e)}")
                    # Fall back to using the content as text
                    content = chunk.content
            else:
                content = chunk.content

            final_chunks.append(
                FinalChunkResult(
                    content=content,
                    score=chunk.score,
                    document_id=chunk.document_id,
                    chunk_number=chunk.chunk_number,
                    metadata=chunk.metadata,
                    content_type=chunk.content_type,
                    filename=chunk.filename,
                    download_url=chunk.download_url,
                )
            )

        return final_chunks

    def create_cache(
        self,
shell.py (39 changed lines)
@@ -190,6 +190,38 @@ class DB:
        """Get document metadata by ID"""
        doc = self._client.get_document(document_id)
        return doc.model_dump()

    def batch_get_documents(self, document_ids: List[str]) -> List[dict]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.

        Args:
            document_ids: List of document IDs to retrieve

        Returns:
            List of document metadata
        """
        docs = self._client.batch_get_documents(document_ids)
        return [doc.model_dump() for doc in docs]

    def batch_get_chunks(self, sources: List[dict]) -> List[dict]:
        """
        Retrieve specific chunks by their document ID and chunk number in a single batch operation.

        Args:
            sources: List of dictionaries with document_id and chunk_number fields

        Returns:
            List of chunk results

        Example:
            sources = [
                {"document_id": "doc_123", "chunk_number": 0},
                {"document_id": "doc_456", "chunk_number": 2}
            ]
        """
        chunks = self._client.batch_get_chunks(sources)
        return [chunk.model_dump() for chunk in chunks]

    def create_cache(
        self,
@@ -265,7 +297,12 @@ if __name__ == "__main__":

    # Print welcome message
    print("\nDataBridge CLI ready to use. The 'db' object is available with all SDK methods.")
    print("Example: db.ingest_text('hello world')")
    print("Examples:")
    print("  db.ingest_text('hello world')")
    print("  db.query('what are the key findings?')")
    print("  db.batch_get_documents(['doc_id1', 'doc_id2'])")
    print("  db.batch_get_chunks([{'document_id': 'doc_123', 'chunk_number': 0}])")
    print("  result = db.query('how to use this API?'); print(result['sources'])")
    print("Type help(db) for documentation.")

    # Start the shell