Add completions API (#3)

This commit is contained in:
Adityavardhan Agrawal 2024-12-26 08:52:25 -05:00 committed by GitHub
parent a404154650
commit 03345dcc07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 599 additions and 317 deletions

View File

@ -3,7 +3,7 @@ default_region = "us-east-2"
default_bucket_name = "databridge-s3-storage"
[mongodb]
database_name = "databridge"
database_name = "DataBridgeTest"
documents_collection = "documents"
chunks_collection = "document_chunks"
@ -13,6 +13,7 @@ index_name = "vector_index"
[model]
embedding_model = "text-embedding-3-small"
completion_model = "gpt-3.5-turbo"
[document_processing]
chunk_size = 1000

View File

@ -12,13 +12,14 @@ from fastapi import (
from fastapi.middleware.cors import CORSMiddleware
import jwt
import logging
from core.models.request import IngestTextRequest, QueryRequest
from core.models.request import IngestTextRequest, RetrieveRequest, CompletionQueryRequest
from core.models.documents import (
Document,
DocumentResult,
ChunkResult
)
from core.models.auth import AuthContext, EntityType
from core.completion.base_completion import CompletionResponse
from core.services.document_service import DocumentService
from core.config import get_settings
from core.database.mongo_database import MongoDatabase
@ -26,7 +27,7 @@ from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from core.storage.s3_storage import S3Storage
from core.parser.unstructured_parser import UnstructuredAPIParser
from core.embedding_model.openai_embedding_model import OpenAIEmbeddingModel
from core.completion.openai_completion import OpenAICompletionModel
# Initialize FastAPI app
app = FastAPI(title="DataBridge API")
@ -76,13 +77,18 @@ embedding_model = OpenAIEmbeddingModel(
model_name=settings.EMBEDDING_MODEL
)
completion_model = OpenAICompletionModel(
model_name=settings.COMPLETION_MODEL
)
# Initialize document service
document_service = DocumentService(
database=database,
vector_store=vector_store,
storage=storage,
parser=parser,
embedding_model=embedding_model
embedding_model=embedding_model,
completion_model=completion_model
)
@ -150,13 +156,51 @@ async def ingest_file(
raise HTTPException(400, "Invalid metadata JSON")
@app.post("/query", response_model=Union[List[ChunkResult], List[DocumentResult]])
async def query_documents(
request: QueryRequest,
@app.post("/retrieve/chunks", response_model=List[ChunkResult])
async def retrieve_chunks(
request: RetrieveRequest,
auth: AuthContext = Depends(verify_token)
):
"""Query documents with specified return type."""
return await document_service.query(request, auth)
"""Retrieve relevant chunks."""
return await document_service.retrieve_chunks(
request.query,
auth,
request.filters,
request.k,
request.min_score
)
@app.post("/retrieve/docs", response_model=List[DocumentResult])
async def retrieve_documents(
    request: RetrieveRequest,
    auth: AuthContext = Depends(verify_token)
):
    """Return the documents that best match the request's query.

    Forwards the query text, the caller's auth context, the metadata
    filters, ``k`` and ``min_score`` from the request body to
    ``document_service.retrieve_docs`` and returns its result unchanged.
    """
    documents = await document_service.retrieve_docs(
        request.query,
        auth,
        request.filters,
        request.k,
        request.min_score,
    )
    return documents
@app.post("/query", response_model=CompletionResponse)
async def query_completion(
    request: CompletionQueryRequest,
    auth: AuthContext = Depends(verify_token)
):
    """Answer the request's query with a completion built from retrieved chunks.

    Passes the retrieval settings (query, filters, ``k``, ``min_score``) and
    the generation settings (``max_tokens``, ``temperature``) from the
    request body to ``document_service.query`` and returns its result.
    """
    completion = await document_service.query(
        request.query,
        auth,
        request.filters,
        request.k,
        request.min_score,
        request.max_tokens,
        request.temperature,
    )
    return completion
@app.get("/documents", response_model=List[Document])

View File

@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Dict
from pydantic import BaseModel
class CompletionResponse(BaseModel):
    """Response from completion generation"""

    # Text produced by the completion model.
    completion: str
    # Integer counters keyed by name; the OpenAI implementation in this
    # package fills in prompt_tokens, completion_tokens and total_tokens.
    usage: Dict[str, int]
class CompletionRequest(BaseModel):
    """Request for completion generation"""

    # The user's question.
    query: str
    # Text passages supplied to the model as context for the answer.
    context_chunks: List[str]
    # Generation limits; each falls back to the default below only when the
    # field is omitted (an explicit None is stored as None).
    max_tokens: Optional[int] = 1000
    temperature: Optional[float] = 0.7
class BaseCompletionModel(ABC):
    """Base class for completion models"""

    @abstractmethod
    async def complete(self, request: CompletionRequest) -> CompletionResponse:
        """Generate completion from query and context.

        Args:
            request: Query, context chunks and generation settings.

        Returns:
            The generated text together with usage counters.
        """
        return None

View File

@ -0,0 +1,37 @@
from .base_completion import BaseCompletionModel, CompletionRequest, CompletionResponse
class OpenAICompletionModel(BaseCompletionModel):
    """OpenAI completion model implementation"""

    def __init__(self, model_name: str):
        """Store the model name and create the async OpenAI client.

        Args:
            model_name: Chat model name sent with every completion request.
        """
        self.model_name = model_name
        # Import here to avoid dependency if not using OpenAI
        from openai import AsyncOpenAI

        self.client = AsyncOpenAI()

    async def complete(self, request: CompletionRequest) -> CompletionResponse:
        """Generate completion using OpenAI API.

        The context chunks are joined with blank lines and sent, together
        with the question, as a single user message after a fixed system
        prompt.

        Args:
            request: Query, context chunks, ``max_tokens`` and ``temperature``.

        Returns:
            CompletionResponse with the first choice's text and the token
            usage counters. Missing text becomes ``""`` and missing usage
            becomes zeros.
        """
        # Construct prompt with context
        context = "\n\n".join(request.context_chunks)
        messages = [
            {"role": "system", "content": "You are a helpful assistant. Use the provided context to answer questions accurately."},
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {request.query}"}
        ]

        # Call OpenAI API
        response = await self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
        )

        # Both message.content and response.usage are optional in the API
        # response; passing None through would fail CompletionResponse
        # validation (completion: str, usage: Dict[str, int]).
        content = response.choices[0].message.content or ""
        usage = response.usage

        return CompletionResponse(
            completion=content,
            usage={
                "prompt_tokens": usage.prompt_tokens if usage else 0,
                "completion_tokens": usage.completion_tokens if usage else 0,
                "total_tokens": usage.total_tokens if usage else 0,
            }
        )

View File

@ -1,104 +1,66 @@
from typing import Dict, Any
from pydantic import Field
from pydantic_settings import BaseSettings
from functools import lru_cache
import tomli
from dotenv import load_dotenv, find_dotenv
from dotenv import load_dotenv
def load_toml_config() -> Dict[Any, Any]:
    """Load configuration from config.toml file.

    Returns:
        The parsed contents of ``config.toml`` (opened relative to the
        current working directory) as returned by ``tomli.load``.
    """
    with open("config.toml", "rb") as config_file:
        parsed = tomli.load(config_file)
    return parsed
class Settings(BaseSettings):
"""DataBridge configuration settings."""
# MongoDB settings
# Required environment variables
MONGODB_URI: str = Field(..., env="MONGODB_URI")
DATABRIDGE_DB: str = Field(None)
# Collection names
DOCUMENTS_COLLECTION: str = Field(None)
CHUNKS_COLLECTION: str = Field(None)
# Vector search settings
VECTOR_INDEX_NAME: str = Field(None)
# API Keys
OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY")
UNSTRUCTURED_API_KEY: str = Field(..., env="UNSTRUCTURED_API_KEY")
# Optional API keys for alternative models
ANTHROPIC_API_KEY: str | None = Field(None, env="ANTHROPIC_API_KEY")
COHERE_API_KEY: str | None = Field(None, env="COHERE_API_KEY")
VOYAGE_API_KEY: str | None = Field(None, env="VOYAGE_API_KEY")
# Model settings
EMBEDDING_MODEL: str = Field("text-embedding-3-small")
# Document processing settings
CHUNK_SIZE: int = Field(1000)
CHUNK_OVERLAP: int = Field(200)
DEFAULT_K: int = Field(4)
# Storage settings
AWS_ACCESS_KEY: str = Field(..., env="AWS_ACCESS_KEY")
AWS_SECRET_ACCESS_KEY: str = Field(..., env="AWS_SECRET_ACCESS_KEY")
AWS_REGION: str = Field(None)
S3_BUCKET: str = Field(None)
# Auth settings
JWT_SECRET_KEY: str = Field(..., env="JWT_SECRET_KEY")
JWT_ALGORITHM: str = Field("HS256")
# Server settings
HOST: str = Field("localhost")
PORT: int = Field(8000)
RELOAD: bool = Field(False)
class Config:
env_file = ".env"
case_sensitive = True
extra = "allow"
# Values from config.toml with defaults
AWS_REGION: str = "us-east-2"
S3_BUCKET: str = "databridge-s3-storage"
DATABRIDGE_DB: str = "databridge"
DOCUMENTS_COLLECTION: str = "documents"
CHUNKS_COLLECTION: str = "document_chunks"
VECTOR_INDEX_NAME: str = "vector_index"
VECTOR_DIMENSIONS: int = 1536
EMBEDDING_MODEL: str = "text-embedding-3-small"
COMPLETION_MODEL: str = "gpt-3.5-turbo"
CHUNK_SIZE: int = 1000
CHUNK_OVERLAP: int = 200
DEFAULT_K: int = 4
HOST: str = "localhost"
PORT: int = 8000
RELOAD: bool = False
JWT_ALGORITHM: str = "HS256"
def __init__(self, **kwargs):
# Force reload of environment variables
load_dotenv(find_dotenv(), override=True)
config = load_toml_config()
# Set values from config.toml
kwargs.update({
# MongoDB settings
"DATABRIDGE_DB": config["mongodb"]["database_name"],
"DOCUMENTS_COLLECTION": config["mongodb"]["documents_collection"],
"CHUNKS_COLLECTION": config["mongodb"]["chunks_collection"],
"VECTOR_INDEX_NAME": config["mongodb"]["vector"]["index_name"],
# AWS settings
"AWS_REGION": config["aws"]["default_region"],
"S3_BUCKET": config["aws"]["default_bucket_name"],
# Model settings
"EMBEDDING_MODEL": config["model"]["embedding_model"],
# Document processing settings
"CHUNK_SIZE": config["document_processing"]["chunk_size"],
"CHUNK_OVERLAP": config["document_processing"]["chunk_overlap"],
"DEFAULT_K": config["document_processing"]["default_k"],
# Server settings
"HOST": config["server"]["host"],
"PORT": config["server"]["port"],
"RELOAD": config["server"]["reload"],
# Auth settings
"JWT_ALGORITHM": config["auth"]["jwt_algorithm"],
})
super().__init__(**kwargs)
@lru_cache()
def get_settings() -> Settings:
    """Get cached settings instance.

    Calls ``load_dotenv()``, reads ``config.toml`` from the current working
    directory and passes the mapped values to ``Settings``. The result is
    memoised by ``lru_cache``, so the file is read once per process.

    Raises:
        FileNotFoundError: If ``config.toml`` does not exist.
        KeyError: If a section or key read below is missing from the file.
    """
    load_dotenv()

    # Load config.toml
    with open("config.toml", "rb") as f:
        config = tomli.load(f)

    # Map config.toml values to settings
    settings_dict = {
        "AWS_REGION": config["aws"]["default_region"],
        "S3_BUCKET": config["aws"]["default_bucket_name"],
        "DATABRIDGE_DB": config["mongodb"]["database_name"],
        "DOCUMENTS_COLLECTION": config["mongodb"]["documents_collection"],
        "CHUNKS_COLLECTION": config["mongodb"]["chunks_collection"],
        "VECTOR_INDEX_NAME": config["mongodb"]["vector"]["index_name"],
        "VECTOR_DIMENSIONS": config["mongodb"]["vector"]["dimensions"],
        "EMBEDDING_MODEL": config["model"]["embedding_model"],
        "COMPLETION_MODEL": config["model"]["completion_model"],
        "CHUNK_SIZE": config["document_processing"]["chunk_size"],
        "CHUNK_OVERLAP": config["document_processing"]["chunk_overlap"],
        "DEFAULT_K": config["document_processing"]["default_k"],
        "HOST": config["server"]["host"],
        "PORT": config["server"]["port"],
        "RELOAD": config["server"]["reload"],
        "JWT_ALGORITHM": config["auth"]["jwt_algorithm"],
    }

    return Settings(**settings_dict)

View File

@ -1,6 +1,5 @@
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from .documents import QueryReturnType
class IngestTextRequest(BaseModel):
@ -9,10 +8,15 @@ class IngestTextRequest(BaseModel):
metadata: Dict[str, Any] = Field(default_factory=dict)
class QueryRequest(BaseModel):
"""Query request model with validation"""
class RetrieveRequest(BaseModel):
"""Base retrieve request model"""
query: str = Field(..., min_length=1)
return_type: QueryReturnType = QueryReturnType.CHUNKS
filters: Optional[Dict[str, Any]] = None
k: int = Field(default=4, gt=0)
min_score: float = Field(default=0.0)
class CompletionQueryRequest(RetrieveRequest):
    """Request model for completion generation"""

    # Both fields stay optional (None is forwarded as-is). When a value is
    # supplied it is range-checked here so a bad request is rejected with a
    # validation error instead of failing later at the completion provider.
    max_tokens: Optional[int] = Field(default=None, gt=0)
    # NOTE(review): 0–2 is the range documented for OpenAI chat completions,
    # the only completion backend in this package — revisit if others are added.
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)

View File

@ -1,9 +0,0 @@
from abc import ABC, abstractmethod
from typing import Dict, Any
class BasePlanner(ABC):
    """Interface for objects that turn a query into a retrieval plan."""

    @abstractmethod
    def plan_retrieval(self, query: str, **kwargs) -> Dict[str, Any]:
        """Create execution plan for retrieval.

        Args:
            query: The query text to plan for.
            **kwargs: Planner-specific options.

        Returns:
            A dictionary describing the plan.
        """
        return None

View File

@ -1,17 +0,0 @@
from typing import Dict, Any
from .base_planner import BasePlanner
class SimpleRAGPlanner(BasePlanner):
    """Planner that always produces a single ``simple_retrieval`` plan."""

    def __init__(self, default_k: int = 3):
        """Remember the ``k`` used when a call does not supply one."""
        self.default_k = default_k

    def plan_retrieval(self, query: str, **kwargs) -> Dict[str, Any]:
        """Create a simple retrieval plan.

        Args:
            query: The query text, copied into the plan.
            **kwargs: Optional ``k``, ``filters`` and ``min_score``; missing
                ones fall back to ``self.default_k``, ``{}`` and ``0.0``.

        Returns:
            Dict with keys ``strategy``, ``k``, ``query``, ``filters`` and
            ``min_score``.
        """
        plan: Dict[str, Any] = {"strategy": "simple_retrieval"}
        plan["k"] = kwargs.get("k", self.default_k)
        plan["query"] = query
        plan["filters"] = kwargs.get("filters", {})
        plan["min_score"] = kwargs.get("min_score", 0.0)
        return plan

View File

@ -1,24 +1,18 @@
from typing import Any, Dict, List, Union
import logging
from fastapi import UploadFile
import base64
from typing import Dict, Any, List, Optional
from fastapi import UploadFile
from core.models.request import IngestTextRequest
from ..models.documents import Document, DocumentChunk, ChunkResult, DocumentContent, DocumentResult
from ..models.auth import AuthContext
from core.database.base_database import BaseDatabase
from core.embedding_model.base_embedding_model import BaseEmbeddingModel
from core.models.request import IngestTextRequest, QueryRequest
from core.parser.base_parser import BaseParser
from core.storage.base_storage import BaseStorage
from core.vector_store.base_vector_store import BaseVectorStore
from ..models.documents import (
Document,
DocumentChunk,
ChunkResult,
DocumentContent,
DocumentResult,
QueryReturnType
)
from ..models.auth import AuthContext
from core.embedding_model.base_embedding_model import BaseEmbeddingModel
from core.parser.base_parser import BaseParser
from core.completion.base_completion import BaseCompletionModel
from core.completion.base_completion import CompletionRequest, CompletionResponse
import logging
logger = logging.getLogger(__name__)
@ -30,13 +24,91 @@ class DocumentService:
vector_store: BaseVectorStore,
storage: BaseStorage,
parser: BaseParser,
embedding_model: BaseEmbeddingModel
embedding_model: BaseEmbeddingModel,
completion_model: BaseCompletionModel
):
self.db = database
self.vector_store = vector_store
self.storage = storage
self.parser = parser
self.embedding_model = embedding_model
self.completion_model = completion_model
async def retrieve_chunks(
    self,
    query: str,
    auth: AuthContext,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0
) -> List[ChunkResult]:
    """Retrieve relevant chunks.

    Args:
        query: Query text to embed and search with.
        auth: Caller's auth context, used to restrict the searched documents.
        filters: Optional metadata filters applied when selecting documents.
        k: Number of chunks requested from the vector store.
        min_score: Results scoring below this value are dropped.

    Returns:
        Chunk results for the caller's authorized documents, or an empty
        list when no document is authorized or matches the filters.
    """
    # Get embedding for query
    query_embedding = await self.embedding_model.embed_for_query(query)
    logger.info("Generated query embedding")

    # Find authorized documents
    doc_ids = await self.db.find_authorized_and_filtered_documents(auth, filters)
    if not doc_ids:
        logger.info("No authorized documents found")
        return []
    logger.info(f"Found {len(doc_ids)} authorized documents")

    # Search chunks with vector similarity
    chunks = await self.vector_store.query_similar(
        query_embedding,
        k=k,
        doc_ids=doc_ids,
    )
    logger.info(f"Found {len(chunks)} similar chunks")

    # Create chunk results
    results = await self._create_chunk_results(auth, chunks)

    # Apply the similarity threshold; previously min_score was accepted but
    # never used. Assumes ChunkResult exposes a numeric `score` (the API
    # responses carry one) — TODO confirm against the model definition.
    results = [result for result in results if result.score >= min_score]
    logger.info(f"Returning {len(results)} chunk results")
    return results
async def retrieve_docs(
    self,
    query: str,
    auth: AuthContext,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0
) -> List[DocumentResult]:
    """Retrieve relevant documents.

    Runs ``retrieve_chunks`` with the same arguments and converts the
    resulting chunks into document results via
    ``_create_document_results``.
    """
    matching_chunks = await self.retrieve_chunks(query, auth, filters, k, min_score)
    documents = await self._create_document_results(auth, matching_chunks)
    logger.info(f"Returning {len(documents)} document results")
    return documents
async def query(
    self,
    query: str,
    auth: AuthContext,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0,
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None
) -> CompletionResponse:
    """Generate completion using relevant chunks as context.

    Retrieves chunks with ``retrieve_chunks``, passes their ``content`` as
    context to the completion model and returns the model's response.
    """
    relevant = await self.retrieve_chunks(query, auth, filters, k, min_score)

    # NOTE(review): max_tokens/temperature are forwarded even when None, so
    # the defaults declared on CompletionRequest are not applied here.
    completion_request = CompletionRequest(
        query=query,
        context_chunks=[item.content for item in relevant],
        max_tokens=max_tokens,
        temperature=temperature,
    )
    return await self.completion_model.complete(completion_request)
async def ingest_text(
self,
@ -156,44 +228,6 @@ class DocumentService:
return doc
async def query(
self,
request: QueryRequest,
auth: AuthContext
) -> Union[List[ChunkResult], List[DocumentResult]]:
"""Query documents with specified return type."""
# TODO: k does not make sense for Documents, it's about chunks.
# We should also look into document ordering. Figure these out.
# 1. Get embedding for query
query_embedding = await self.embedding_model.embed_for_query(request.query)
logger.info("Generated query embedding")
# 2. Find authorized documents
doc_ids = await self.db.find_authorized_and_filtered_documents(auth, request.filters)
if not doc_ids:
logger.info("No authorized documents found")
return []
logger.info(f"Found {len(doc_ids)} authorized documents")
# 3. Search chunks with vector similarity
chunks = await self.vector_store.query_similar(
query_embedding,
k=request.k,
doc_ids=doc_ids,
)
logger.info(f"Found {len(chunks)} similar chunks")
# 4. Return results in requested format
if request.return_type == QueryReturnType.CHUNKS:
results = await self._create_chunk_results(auth, chunks)
logger.info(f"Returning {len(results)} chunk results")
return results
else:
results = await self._create_document_results(auth, chunks)
logger.info(f"Returning {len(results)} document results")
return results
def _create_chunk_objects(
self,
doc_id: str,

View File

@ -259,71 +259,6 @@ async def test_auth_insufficient_permissions(client: AsyncClient):
assert response.status_code == 403
@pytest.mark.asyncio
async def test_query_chunks(client: AsyncClient):
"""Test querying document chunks"""
# First ingest a document to query
doc_id = await test_ingest_text_document(
client,
content="The quick brown fox jumps over the lazy dog"
)
headers = create_auth_header()
# Sleep to allow time for document to be indexed
await asyncio.sleep(1)
response = await client.post(
"/query",
json={
"query": "jumping fox",
"return_type": "chunks",
"k": 1
},
headers=headers
)
logger.info(f"Query response: {response.json()}")
assert response.status_code == 200
results = list(response.json())
logger.info(f"Query results: {results}")
assert len(results) == 1
assert results[0]["score"] > 0.5
assert results[0]["document_id"] == doc_id
@pytest.mark.asyncio
async def test_query_documents(client: AsyncClient):
"""Test querying for full documents"""
# First ingest a document to query
content = (
"Headaches can significantly impact daily life and wellbeing. "
"Common triggers include stress, dehydration, and poor sleep habits. "
"While over-the-counter pain relievers may provide temporary relief, "
"it's important to identify and address the root causes. "
"Maintaining good health through proper nutrition, regular exercise, "
"and stress management can help prevent chronic headaches."
)
doc_id = await test_ingest_text_document(client, content=content)
headers = create_auth_header()
response = await client.post(
"/query",
json={
"query": "Headaches, dehydration",
"return_type": "documents",
"filters": {"test": True}
},
headers=headers
)
assert response.status_code == 200
results = list(response.json())
assert len(results) > 0
assert results[0]["document_id"] == doc_id
assert "score" in results[0]
assert "metadata" in results[0]
@pytest.mark.asyncio
async def test_list_documents(client: AsyncClient):
"""Test listing documents"""
@ -367,15 +302,139 @@ async def test_invalid_document_id(client: AsyncClient):
@pytest.mark.asyncio
async def test_invalid_query_params(client: AsyncClient):
"""Test error handling for invalid query parameters"""
async def test_retrieve_chunks(client: AsyncClient):
    """Ingest one sentence and expect /retrieve/chunks to return it for a related query."""
    # Ingest the document the search should find
    ingested_id = await test_ingest_text_document(
        client,
        content="The quick brown fox jumps over the lazy dog"
    )
    auth_headers = create_auth_header()

    # Give the index a moment to pick up the new document
    await asyncio.sleep(1)

    payload = {
        "query": "jumping fox",
        "k": 1,
        "min_score": 0.0
    }
    response = await client.post("/retrieve/chunks", json=payload, headers=auth_headers)
    assert response.status_code == 200

    hits = list(response.json())
    assert len(hits) == 1
    top_hit = hits[0]
    assert top_hit["score"] > 0.5
    assert top_hit["document_id"] == ingested_id
@pytest.mark.asyncio
async def test_retrieve_docs(client: AsyncClient):
    """Ingest a paragraph and expect /retrieve/docs to return it first for a related query."""
    sentences = [
        "Headaches can significantly impact daily life and wellbeing. ",
        "Common triggers include stress, dehydration, and poor sleep habits. ",
        "While over-the-counter pain relievers may provide temporary relief, ",
        "it's important to identify and address the root causes. ",
        "Maintaining good health through proper nutrition, regular exercise, ",
        "and stress management can help prevent chronic headaches.",
    ]
    ingested_id = await test_ingest_text_document(client, content="".join(sentences))
    auth_headers = create_auth_header()

    payload = {
        "query": "Headaches, dehydration",
        "filters": {"test": True}
    }
    response = await client.post("/retrieve/docs", json=payload, headers=auth_headers)
    assert response.status_code == 200

    hits = list(response.json())
    assert len(hits) > 0
    first = hits[0]
    assert first["document_id"] == ingested_id
    assert "score" in first
    assert "metadata" in first
@pytest.mark.asyncio
async def test_query_completion(client: AsyncClient):
    """Ingest a paragraph and expect /query to answer with a non-empty completion and usage."""
    sentences = [
        "The benefits of exercise are numerous. Regular physical activity ",
        "can improve cardiovascular health, strengthen muscles, enhance mental ",
        "wellbeing, and help maintain a healthy weight. Studies show that ",
        "even moderate exercise like walking can significantly reduce the risk ",
        "of various health conditions.",
    ]
    # The ingested text serves as context for the completion
    await test_ingest_text_document(client, content="".join(sentences))
    auth_headers = create_auth_header()

    payload = {
        "query": "What are the main benefits of exercise?",
        "k": 2,
        "temperature": 0.7,
        "max_tokens": 100
    }
    response = await client.post("/query", json=payload, headers=auth_headers)
    assert response.status_code == 200

    body = response.json()
    assert "completion" in body
    assert "usage" in body
    assert len(body["completion"]) > 0
@pytest.mark.asyncio
async def test_invalid_retrieve_params(client: AsyncClient):
    """Both retrieve endpoints must answer 422 when the request body fails validation."""
    auth_headers = create_auth_header()

    # An empty query string is rejected by /retrieve/chunks
    empty_query = {"query": "", "k": 1}
    response = await client.post("/retrieve/chunks", json=empty_query, headers=auth_headers)
    assert response.status_code == 422

    # A negative k is rejected by /retrieve/docs
    negative_k = {"query": "test", "k": -1}
    response = await client.post("/retrieve/docs", json=negative_k, headers=auth_headers)
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_invalid_completion_params(client: AsyncClient):
    """/query must answer 422 when the request body fails validation."""
    auth_headers = create_auth_header()

    # The query is empty; a temperature of 2.0 is sent alongside it
    payload = {
        "query": "",
        "temperature": 2.0
    }
    response = await client.post("/query", json=payload, headers=auth_headers)
    assert response.status_code == 422

View File

@ -7,18 +7,18 @@ from urllib.parse import urlparse
import httpx
import jwt
from .models import Document, IngestTextRequest, ChunkResult, DocumentResult
from .models import Document, IngestTextRequest, ChunkResult, DocumentResult, CompletionResponse
class AsyncDataBridge:
"""
DataBridge client for document operations.
Args:
uri (str): DataBridge URI in the format "databridge://<owner_id>:<token>@<host>"
timeout (int, optional): Request timeout in seconds. Defaults to 30.
is_local (bool, optional): Whether to connect to a local server. Defaults to False.
Examples:
```python
async with AsyncDataBridge("databridge://owner_id:token@api.databridge.ai") as db:
@ -27,7 +27,7 @@ class AsyncDataBridge:
"Sample content",
metadata={"category": "sample"}
)
# Query documents
results = await db.query("search query")
```
@ -52,7 +52,7 @@ class AsyncDataBridge:
# Split host and auth parts
auth, host = parsed.netloc.split('@')
self._owner_id, self._auth_token = auth.split(':')
# Set base URL
self._base_url = f"{'http' if self._is_local else 'https'}://{host}"
@ -68,7 +68,7 @@ class AsyncDataBridge:
) -> Dict[str, Any]:
"""Make authenticated HTTP request"""
headers = {"Authorization": f"Bearer {self._auth_token}"}
if not files:
headers["Content-Type"] = "application/json"
@ -90,14 +90,14 @@ class AsyncDataBridge:
) -> Document:
"""
Ingest a text document into DataBridge.
Args:
content: Text content to ingest
metadata: Optional metadata dictionary
Returns:
Document: Metadata of the ingested document
Example:
```python
doc = await db.ingest_text(
@ -130,16 +130,16 @@ class AsyncDataBridge:
) -> Document:
"""
Ingest a file document into DataBridge.
Args:
file: File to ingest (path string, bytes, file object, or Path)
filename: Name of the file
content_type: MIME type (optional, will be guessed if not provided)
metadata: Optional metadata dictionary
Returns:
Document: Metadata of the ingested document
Example:
```python
# From file path
@ -149,7 +149,7 @@ class AsyncDataBridge:
content_type="application/pdf",
metadata={"department": "research"}
)
# From file object
with open("document.pdf", "rb") as f:
doc = await db.ingest_file(f, "document.pdf")
@ -189,58 +189,125 @@ class AsyncDataBridge:
if isinstance(file, (str, Path)):
file_obj.close()
async def query(
async def retrieve_chunks(
    self,
    query: str,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0
) -> List[ChunkResult]:
    """
    Search for relevant chunks.

    Args:
        query: Search query text
        filters: Optional metadata filters
        k: Number of results (default: 4)
        min_score: Minimum similarity threshold (default: 0.0)

    Returns:
        List[ChunkResult]

    Example:
        ```python
        chunks = await db.retrieve_chunks(
            "What are the key findings?",
            filters={"department": "research"}
        )
        ```
    """
    request = {
        "query": query,
        "filters": filters,
        "k": k,
        "min_score": min_score
    }

    # POSTs to the server's /retrieve/chunks route
    response = await self._request("POST", "retrieve/chunks", request)
    return [ChunkResult(**r) for r in response]
async def retrieve_docs(
    self,
    query: str,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0
) -> List[DocumentResult]:
    """
    Retrieve relevant documents.

    Args:
        query: Search query text
        filters: Optional metadata filters
        k: Number of results (default: 4)
        min_score: Minimum similarity threshold (default: 0.0)

    Returns:
        List[DocumentResult]

    Example:
        ```python
        docs = await db.retrieve_docs(
            "machine learning",
            k=5
        )
        ```
    """
    # Only the fields this method accepts are sent; there is no
    # return_type parameter on this method.
    request = {
        "query": query,
        "filters": filters,
        "k": k,
        "min_score": min_score
    }

    response = await self._request("POST", "retrieve/docs", request)
    return [DocumentResult(**r) for r in response]
async def query(
    self,
    query: str,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0,
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
) -> CompletionResponse:
    """
    Generate completion using relevant chunks as context.

    Args:
        query: Query text
        filters: Optional metadata filters
        k: Number of chunks to use as context (default: 4)
        min_score: Minimum similarity threshold (default: 0.0)
        max_tokens: Maximum tokens in completion
        temperature: Model temperature

    Returns:
        CompletionResponse

    Example:
        ```python
        response = await db.query(
            "What are the key findings about customer satisfaction?",
            filters={"department": "research"},
            temperature=0.7
        )
        print(response.completion)
        ```
    """
    # Retrieval settings first, then generation settings
    payload = dict(query=query, filters=filters, k=k, min_score=min_score)
    payload["max_tokens"] = max_tokens
    payload["temperature"] = temperature

    raw = await self._request("POST", "query", payload)
    return CompletionResponse(**raw)
async def list_documents(
self,
skip: int = 0,
@ -249,7 +316,7 @@ class AsyncDataBridge:
) -> List[Document]:
"""
List accessible documents.
Args:
skip: Number of documents to skip
limit: Maximum number of documents to return
@ -257,12 +324,12 @@ class AsyncDataBridge:
Returns:
List[Document]: List of accessible documents
Example:
```python
# Get first page
docs = await db.list_documents(limit=10)
# Get next page
next_page = await db.list_documents(skip=10, limit=10, filters={"department": "research"})
```
@ -276,13 +343,13 @@ class AsyncDataBridge:
async def get_document(self, document_id: str) -> Document:
"""
Get document metadata by ID.
Args:
document_id: ID of the document
Returns:
Document: Document metadata
Example:
```python
doc = await db.get_document("doc_123")

View File

@ -53,3 +53,9 @@ class DocumentResult(BaseModel):
document_id: str = Field(..., description="Document ID")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata")
content: DocumentContent = Field(..., description="Document content or URL")
class CompletionResponse(BaseModel):
    """Completion response model"""

    # Generated completion text.
    completion: str
    # Integer usage counters keyed by name.
    usage: Dict[str, int]

View File

@ -7,18 +7,18 @@ from urllib.parse import urlparse
import jwt
import requests
from .models import Document, IngestTextRequest, ChunkResult, DocumentResult
from .models import Document, IngestTextRequest, ChunkResult, DocumentResult, CompletionResponse
class DataBridge:
"""
DataBridge client for document operations.
Args:
uri (str): DataBridge URI in the format "databridge://<owner_id>:<token>@<host>"
timeout (int, optional): Request timeout in seconds. Defaults to 30.
is_local (bool, optional): Whether connecting to local development server. Defaults to False.
Examples:
```python
with DataBridge("databridge://owner_id:token@api.databridge.ai") as db:
@ -27,7 +27,7 @@ class DataBridge:
"Sample content",
metadata={"category": "sample"}
)
# Query documents
results = db.query("search query")
```
@ -50,7 +50,7 @@ class DataBridge:
# Split host and auth parts
auth, host = parsed.netloc.split('@')
self._owner_id, self._auth_token = auth.split(':')
# Set base URL
self._base_url = f"{'http' if self._is_local else 'https'}://{host}"
@ -66,7 +66,7 @@ class DataBridge:
) -> Dict[str, Any]:
"""Make authenticated HTTP request"""
headers = {"Authorization": f"Bearer {self._auth_token}"}
if not files:
headers["Content-Type"] = "application/json"
@ -89,14 +89,14 @@ class DataBridge:
) -> Document:
"""
Ingest a text document into DataBridge.
Args:
content: Text content to ingest
metadata: Optional metadata dictionary
Returns:
Document: Metadata of the ingested document
Example:
```python
doc = db.ingest_text(
@ -129,16 +129,16 @@ class DataBridge:
) -> Document:
"""
Ingest a file document into DataBridge.
Args:
file: File to ingest (path string, bytes, file object, or Path)
filename: Name of the file
content_type: MIME type (optional, will be guessed if not provided)
metadata: Optional metadata dictionary
Returns:
Document: Metadata of the ingested document
Example:
```python
# From file path
@ -148,7 +148,7 @@ class DataBridge:
content_type="application/pdf",
metadata={"department": "research"}
)
# From file object
with open("document.pdf", "rb") as f:
doc = db.ingest_file(f, "document.pdf")
@ -188,58 +188,125 @@ class DataBridge:
if isinstance(file, (str, Path)):
file_obj.close()
def query(
def retrieve_chunks(
    self,
    query: str,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0
) -> List[ChunkResult]:
    """
    Retrieve relevant chunks.

    Args:
        query: Search query text
        filters: Optional metadata filters
        k: Number of results (default: 4)
        min_score: Minimum similarity threshold (default: 0.0)

    Returns:
        List[ChunkResult]

    Example:
        ```python
        chunks = db.retrieve_chunks(
            "What are the key findings?",
            filters={"department": "research"}
        )
        ```
    """
    request = {
        "query": query,
        "filters": filters,
        "k": k,
        "min_score": min_score
    }

    # The server exposes this operation at /retrieve/chunks (the async
    # client uses the same path); "search/chunks" is not a route.
    response = self._request("POST", "retrieve/chunks", request)
    return [ChunkResult(**r) for r in response]
def retrieve_docs(
    self,
    query: str,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0
) -> List[DocumentResult]:
    """
    Retrieve relevant documents.

    Args:
        query: Search query text
        filters: Optional metadata filters
        k: Number of results (default: 4)
        min_score: Minimum similarity threshold (default: 0.0)

    Returns:
        List[DocumentResult]

    Example:
        ```python
        docs = db.retrieve_docs(
            "machine learning",
            k=5
        )
        ```
    """
    # Only the fields this method accepts are sent; there is no
    # return_type parameter on this method.
    request = {
        "query": query,
        "filters": filters,
        "k": k,
        "min_score": min_score
    }

    response = self._request("POST", "retrieve/docs", request)
    return [DocumentResult(**r) for r in response]
def query(
    self,
    query: str,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0,
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
) -> CompletionResponse:
    """
    Generate completion using relevant chunks as context.

    Args:
        query: Query text
        filters: Optional metadata filters
        k: Number of chunks to use as context (default: 4)
        min_score: Minimum similarity threshold (default: 0.0)
        max_tokens: Maximum tokens in completion
        temperature: Model temperature

    Returns:
        CompletionResponse

    Example:
        ```python
        response = db.query(
            "What are the key findings about customer satisfaction?",
            filters={"department": "research"},
            temperature=0.7
        )
        print(response.completion)
        ```
    """
    # Retrieval settings first, then generation settings
    payload = dict(query=query, filters=filters, k=k, min_score=min_score)
    payload["max_tokens"] = max_tokens
    payload["temperature"] = temperature

    raw = self._request("POST", "query", payload)
    return CompletionResponse(**raw)
def list_documents(
self,
skip: int = 0,
@ -248,7 +315,7 @@ class DataBridge:
) -> List[Document]:
"""
List accessible documents.
Args:
skip: Number of documents to skip
limit: Maximum number of documents to return
@ -256,12 +323,12 @@ class DataBridge:
Returns:
List[Document]: List of accessible documents
Example:
```python
# Get first page
docs = db.list_documents(limit=10)
# Get next page
next_page = db.list_documents(skip=10, limit=10, filters={"department": "research"})
```
@ -275,13 +342,13 @@ class DataBridge:
def get_document(self, document_id: str) -> Document:
"""
Get document metadata by ID.
Args:
document_id: ID of the document
Returns:
Document: Document metadata
Example:
```python
doc = db.get_document("doc_123")

View File

@ -15,6 +15,7 @@ dependencies = [
"httpx>=0.24.0",
"pyjwt>=2.0.0",
"pydantic==2.10.3",
"requests>=2.32.3",
]
[tool.hatch.build.targets.wheel]