From 03345dcc07414967dcb0ba12570e44e38b6a8b54 Mon Sep 17 00:00:00 2001 From: Adityavardhan Agrawal Date: Thu, 26 Dec 2024 08:52:25 -0500 Subject: [PATCH] Add completions API (#3) --- config.toml | 3 +- core/api.py | 60 ++++++- core/{planner => completion}/__init__.py | 0 core/completion/base_completion.py | 26 +++ core/completion/openai_completion.py | 37 +++++ core/config.py | 130 ++++++--------- core/models/request.py | 12 +- core/planner/base_planner.py | 9 -- core/planner/simple_planner.py | 17 -- core/services/document_service.py | 144 ++++++++++------- core/tests/integration/test_api.py | 193 +++++++++++++++-------- sdks/python/databridge/async_.py | 139 +++++++++++----- sdks/python/databridge/models.py | 6 + sdks/python/databridge/sync.py | 139 +++++++++++----- sdks/python/pyproject.toml | 1 + 15 files changed, 599 insertions(+), 317 deletions(-) rename core/{planner => completion}/__init__.py (100%) create mode 100644 core/completion/base_completion.py create mode 100644 core/completion/openai_completion.py delete mode 100644 core/planner/base_planner.py delete mode 100644 core/planner/simple_planner.py diff --git a/config.toml b/config.toml index 3719815..fbe7bb0 100644 --- a/config.toml +++ b/config.toml @@ -3,7 +3,7 @@ default_region = "us-east-2" default_bucket_name = "databridge-s3-storage" [mongodb] -database_name = "databridge" +database_name = "DataBridgeTest" documents_collection = "documents" chunks_collection = "document_chunks" @@ -13,6 +13,7 @@ index_name = "vector_index" [model] embedding_model = "text-embedding-3-small" +completion_model = "gpt-3.5-turbo" [document_processing] chunk_size = 1000 diff --git a/core/api.py b/core/api.py index 03f8139..707610d 100644 --- a/core/api.py +++ b/core/api.py @@ -12,13 +12,14 @@ from fastapi import ( from fastapi.middleware.cors import CORSMiddleware import jwt import logging -from core.models.request import IngestTextRequest, QueryRequest +from core.models.request import IngestTextRequest, 
RetrieveRequest, CompletionQueryRequest from core.models.documents import ( Document, DocumentResult, ChunkResult ) from core.models.auth import AuthContext, EntityType +from core.completion.base_completion import CompletionResponse from core.services.document_service import DocumentService from core.config import get_settings from core.database.mongo_database import MongoDatabase @@ -26,7 +27,7 @@ from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore from core.storage.s3_storage import S3Storage from core.parser.unstructured_parser import UnstructuredAPIParser from core.embedding_model.openai_embedding_model import OpenAIEmbeddingModel - +from core.completion.openai_completion import OpenAICompletionModel # Initialize FastAPI app app = FastAPI(title="DataBridge API") @@ -76,13 +77,18 @@ embedding_model = OpenAIEmbeddingModel( model_name=settings.EMBEDDING_MODEL ) +completion_model = OpenAICompletionModel( + model_name=settings.COMPLETION_MODEL +) + # Initialize document service document_service = DocumentService( database=database, vector_store=vector_store, storage=storage, parser=parser, - embedding_model=embedding_model + embedding_model=embedding_model, + completion_model=completion_model ) @@ -150,13 +156,51 @@ async def ingest_file( raise HTTPException(400, "Invalid metadata JSON") -@app.post("/query", response_model=Union[List[ChunkResult], List[DocumentResult]]) -async def query_documents( - request: QueryRequest, +@app.post("/retrieve/chunks", response_model=List[ChunkResult]) +async def retrieve_chunks( + request: RetrieveRequest, auth: AuthContext = Depends(verify_token) ): - """Query documents with specified return type.""" - return await document_service.query(request, auth) + """Retrieve relevant chunks.""" + return await document_service.retrieve_chunks( + request.query, + auth, + request.filters, + request.k, + request.min_score + ) + + +@app.post("/retrieve/docs", response_model=List[DocumentResult]) +async def 
retrieve_documents( + request: RetrieveRequest, + auth: AuthContext = Depends(verify_token) +): + """Retrieve relevant documents.""" + return await document_service.retrieve_docs( + request.query, + auth, + request.filters, + request.k, + request.min_score + ) + + +@app.post("/query", response_model=CompletionResponse) +async def query_completion( + request: CompletionQueryRequest, + auth: AuthContext = Depends(verify_token) +): + """Generate completion using relevant chunks as context.""" + return await document_service.query( + request.query, + auth, + request.filters, + request.k, + request.min_score, + request.max_tokens, + request.temperature + ) @app.get("/documents", response_model=List[Document]) diff --git a/core/planner/__init__.py b/core/completion/__init__.py similarity index 100% rename from core/planner/__init__.py rename to core/completion/__init__.py diff --git a/core/completion/base_completion.py b/core/completion/base_completion.py new file mode 100644 index 0000000..e576400 --- /dev/null +++ b/core/completion/base_completion.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Dict +from pydantic import BaseModel + + +class CompletionResponse(BaseModel): + """Response from completion generation""" + completion: str + usage: Dict[str, int] + + +class CompletionRequest(BaseModel): + """Request for completion generation""" + query: str + context_chunks: List[str] + max_tokens: Optional[int] = 1000 + temperature: Optional[float] = 0.7 + + +class BaseCompletionModel(ABC): + """Base class for completion models""" + + @abstractmethod + async def complete(self, request: CompletionRequest) -> CompletionResponse: + """Generate completion from query and context""" + pass diff --git a/core/completion/openai_completion.py b/core/completion/openai_completion.py new file mode 100644 index 0000000..dc7defc --- /dev/null +++ b/core/completion/openai_completion.py @@ -0,0 +1,37 @@ +from .base_completion import 
BaseCompletionModel, CompletionRequest, CompletionResponse + + +class OpenAICompletionModel(BaseCompletionModel): + """OpenAI completion model implementation""" + + def __init__(self, model_name: str): + self.model_name = model_name + # Import here to avoid dependency if not using OpenAI + from openai import AsyncOpenAI + self.client = AsyncOpenAI() + + async def complete(self, request: CompletionRequest) -> CompletionResponse: + """Generate completion using OpenAI API""" + # Construct prompt with context + context = "\n\n".join(request.context_chunks) + messages = [ + {"role": "system", "content": "You are a helpful assistant. Use the provided context to answer questions accurately."}, + {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {request.query}"} + ] + + # Call OpenAI API + response = await self.client.chat.completions.create( + model=self.model_name, + messages=messages, + max_tokens=request.max_tokens, + temperature=request.temperature, + ) + + return CompletionResponse( + completion=response.choices[0].message.content, + usage={ + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens + } + ) diff --git a/core/config.py b/core/config.py index bb93fc8..30ad520 100644 --- a/core/config.py +++ b/core/config.py @@ -1,104 +1,66 @@ -from typing import Dict, Any from pydantic import Field from pydantic_settings import BaseSettings from functools import lru_cache import tomli -from dotenv import load_dotenv, find_dotenv +from dotenv import load_dotenv -def load_toml_config() -> Dict[Any, Any]: - """Load configuration from config.toml file.""" - with open("config.toml", "rb") as f: - return tomli.load(f) class Settings(BaseSettings): """DataBridge configuration settings.""" - - # MongoDB settings + # Required environment variables MONGODB_URI: str = Field(..., env="MONGODB_URI") - DATABRIDGE_DB: str = Field(None) - - # Collection names - 
DOCUMENTS_COLLECTION: str = Field(None) - CHUNKS_COLLECTION: str = Field(None) - - # Vector search settings - VECTOR_INDEX_NAME: str = Field(None) - - # API Keys OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY") UNSTRUCTURED_API_KEY: str = Field(..., env="UNSTRUCTURED_API_KEY") - - # Optional API keys for alternative models - ANTHROPIC_API_KEY: str | None = Field(None, env="ANTHROPIC_API_KEY") - COHERE_API_KEY: str | None = Field(None, env="COHERE_API_KEY") - VOYAGE_API_KEY: str | None = Field(None, env="VOYAGE_API_KEY") - - # Model settings - EMBEDDING_MODEL: str = Field("text-embedding-3-small") - - # Document processing settings - CHUNK_SIZE: int = Field(1000) - CHUNK_OVERLAP: int = Field(200) - DEFAULT_K: int = Field(4) - - # Storage settings AWS_ACCESS_KEY: str = Field(..., env="AWS_ACCESS_KEY") AWS_SECRET_ACCESS_KEY: str = Field(..., env="AWS_SECRET_ACCESS_KEY") - AWS_REGION: str = Field(None) - S3_BUCKET: str = Field(None) - - # Auth settings JWT_SECRET_KEY: str = Field(..., env="JWT_SECRET_KEY") - JWT_ALGORITHM: str = Field("HS256") - - # Server settings - HOST: str = Field("localhost") - PORT: int = Field(8000) - RELOAD: bool = Field(False) - class Config: - env_file = ".env" - case_sensitive = True - extra = "allow" + # Values from config.toml with defaults + AWS_REGION: str = "us-east-2" + S3_BUCKET: str = "databridge-s3-storage" + DATABRIDGE_DB: str = "databridge" + DOCUMENTS_COLLECTION: str = "documents" + CHUNKS_COLLECTION: str = "document_chunks" + VECTOR_INDEX_NAME: str = "vector_index" + VECTOR_DIMENSIONS: int = 1536 + EMBEDDING_MODEL: str = "text-embedding-3-small" + COMPLETION_MODEL: str = "gpt-3.5-turbo" + CHUNK_SIZE: int = 1000 + CHUNK_OVERLAP: int = 200 + DEFAULT_K: int = 4 + HOST: str = "localhost" + PORT: int = 8000 + RELOAD: bool = False + JWT_ALGORITHM: str = "HS256" - def __init__(self, **kwargs): - # Force reload of environment variables - load_dotenv(find_dotenv(), override=True) - - config = load_toml_config() - - # Set values 
from config.toml - kwargs.update({ - # MongoDB settings - "DATABRIDGE_DB": config["mongodb"]["database_name"], - "DOCUMENTS_COLLECTION": config["mongodb"]["documents_collection"], - "CHUNKS_COLLECTION": config["mongodb"]["chunks_collection"], - "VECTOR_INDEX_NAME": config["mongodb"]["vector"]["index_name"], - - # AWS settings - "AWS_REGION": config["aws"]["default_region"], - "S3_BUCKET": config["aws"]["default_bucket_name"], - - # Model settings - "EMBEDDING_MODEL": config["model"]["embedding_model"], - - # Document processing settings - "CHUNK_SIZE": config["document_processing"]["chunk_size"], - "CHUNK_OVERLAP": config["document_processing"]["chunk_overlap"], - "DEFAULT_K": config["document_processing"]["default_k"], - - # Server settings - "HOST": config["server"]["host"], - "PORT": config["server"]["port"], - "RELOAD": config["server"]["reload"], - - # Auth settings - "JWT_ALGORITHM": config["auth"]["jwt_algorithm"], - }) - - super().__init__(**kwargs) @lru_cache() def get_settings() -> Settings: """Get cached settings instance.""" - return Settings() + load_dotenv() + + # Load config.toml + with open("config.toml", "rb") as f: + config = tomli.load(f) + + # Map config.toml values to settings + settings_dict = { + "AWS_REGION": config["aws"]["default_region"], + "S3_BUCKET": config["aws"]["default_bucket_name"], + "DATABRIDGE_DB": config["mongodb"]["database_name"], + "DOCUMENTS_COLLECTION": config["mongodb"]["documents_collection"], + "CHUNKS_COLLECTION": config["mongodb"]["chunks_collection"], + "VECTOR_INDEX_NAME": config["mongodb"]["vector"]["index_name"], + "VECTOR_DIMENSIONS": config["mongodb"]["vector"]["dimensions"], + "EMBEDDING_MODEL": config["model"]["embedding_model"], + "COMPLETION_MODEL": config["model"]["completion_model"], + "CHUNK_SIZE": config["document_processing"]["chunk_size"], + "CHUNK_OVERLAP": config["document_processing"]["chunk_overlap"], + "DEFAULT_K": config["document_processing"]["default_k"], + "HOST": config["server"]["host"], + 
"PORT": config["server"]["port"], + "RELOAD": config["server"]["reload"], + "JWT_ALGORITHM": config["auth"]["jwt_algorithm"], + } + + return Settings(**settings_dict) diff --git a/core/models/request.py b/core/models/request.py index ed50c87..34a3261 100644 --- a/core/models/request.py +++ b/core/models/request.py @@ -1,6 +1,5 @@ from typing import Dict, Any, Optional from pydantic import BaseModel, Field -from .documents import QueryReturnType class IngestTextRequest(BaseModel): @@ -9,10 +8,15 @@ class IngestTextRequest(BaseModel): metadata: Dict[str, Any] = Field(default_factory=dict) -class QueryRequest(BaseModel): - """Query request model with validation""" +class RetrieveRequest(BaseModel): + """Base retrieve request model""" query: str = Field(..., min_length=1) - return_type: QueryReturnType = QueryReturnType.CHUNKS filters: Optional[Dict[str, Any]] = None k: int = Field(default=4, gt=0) min_score: float = Field(default=0.0) + + +class CompletionQueryRequest(RetrieveRequest): + """Request model for completion generation""" + max_tokens: Optional[int] = None + temperature: Optional[float] = None diff --git a/core/planner/base_planner.py b/core/planner/base_planner.py deleted file mode 100644 index ed2d7ad..0000000 --- a/core/planner/base_planner.py +++ /dev/null @@ -1,9 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Dict, Any - - -class BasePlanner(ABC): - @abstractmethod - def plan_retrieval(self, query: str, **kwargs) -> Dict[str, Any]: - """Create execution plan for retrieval""" - pass diff --git a/core/planner/simple_planner.py b/core/planner/simple_planner.py deleted file mode 100644 index 733e063..0000000 --- a/core/planner/simple_planner.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Dict, Any -from .base_planner import BasePlanner - - -class SimpleRAGPlanner(BasePlanner): - def __init__(self, default_k: int = 3): - self.default_k = default_k - - def plan_retrieval(self, query: str, **kwargs) -> Dict[str, Any]: - """Create a 
simple retrieval plan.""" - return { - "strategy": "simple_retrieval", - "k": kwargs.get("k", self.default_k), - "query": query, - "filters": kwargs.get("filters", {}), - "min_score": kwargs.get("min_score", 0.0) - } diff --git a/core/services/document_service.py b/core/services/document_service.py index a53d737..f12dbd8 100644 --- a/core/services/document_service.py +++ b/core/services/document_service.py @@ -1,24 +1,18 @@ -from typing import Any, Dict, List, Union -import logging -from fastapi import UploadFile import base64 +from typing import Dict, Any, List, Optional +from fastapi import UploadFile +from core.models.request import IngestTextRequest +from ..models.documents import Document, DocumentChunk, ChunkResult, DocumentContent, DocumentResult +from ..models.auth import AuthContext from core.database.base_database import BaseDatabase -from core.embedding_model.base_embedding_model import BaseEmbeddingModel -from core.models.request import IngestTextRequest, QueryRequest -from core.parser.base_parser import BaseParser from core.storage.base_storage import BaseStorage from core.vector_store.base_vector_store import BaseVectorStore -from ..models.documents import ( - Document, - DocumentChunk, - ChunkResult, - DocumentContent, - DocumentResult, - QueryReturnType -) -from ..models.auth import AuthContext - +from core.embedding_model.base_embedding_model import BaseEmbeddingModel +from core.parser.base_parser import BaseParser +from core.completion.base_completion import BaseCompletionModel +from core.completion.base_completion import CompletionRequest, CompletionResponse +import logging logger = logging.getLogger(__name__) @@ -30,13 +24,91 @@ class DocumentService: vector_store: BaseVectorStore, storage: BaseStorage, parser: BaseParser, - embedding_model: BaseEmbeddingModel + embedding_model: BaseEmbeddingModel, + completion_model: BaseCompletionModel ): self.db = database self.vector_store = vector_store self.storage = storage self.parser = parser 
self.embedding_model = embedding_model + self.completion_model = completion_model + + async def retrieve_chunks( + self, + query: str, + auth: AuthContext, + filters: Optional[Dict[str, Any]] = None, + k: int = 4, + min_score: float = 0.0 + ) -> List[ChunkResult]: + """Retrieve relevant chunks.""" + # Get embedding for query + query_embedding = await self.embedding_model.embed_for_query(query) + logger.info("Generated query embedding") + + # Find authorized documents + doc_ids = await self.db.find_authorized_and_filtered_documents(auth, filters) + if not doc_ids: + logger.info("No authorized documents found") + return [] + logger.info(f"Found {len(doc_ids)} authorized documents") + + # Search chunks with vector similarity + chunks = await self.vector_store.query_similar( + query_embedding, + k=k, + doc_ids=doc_ids, + ) + logger.info(f"Found {len(chunks)} similar chunks") + + # Create and return chunk results + results = await self._create_chunk_results(auth, chunks) + logger.info(f"Returning {len(results)} chunk results") + return results + + async def retrieve_docs( + self, + query: str, + auth: AuthContext, + filters: Optional[Dict[str, Any]] = None, + k: int = 4, + min_score: float = 0.0 + ) -> List[DocumentResult]: + """Retrieve relevant documents.""" + # Get chunks first + chunks = await self.retrieve_chunks(query, auth, filters, k, min_score) + + # Convert to document results + results = await self._create_document_results(auth, chunks) + logger.info(f"Returning {len(results)} document results") + return results + + async def query( + self, + query: str, + auth: AuthContext, + filters: Optional[Dict[str, Any]] = None, + k: int = 4, + min_score: float = 0.0, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None + ) -> CompletionResponse: + """Generate completion using relevant chunks as context.""" + # Get relevant chunks + chunks = await self.retrieve_chunks(query, auth, filters, k, min_score) + chunk_contents = [chunk.content for chunk in 
chunks] + + # Generate completion + request = CompletionRequest( + query=query, + context_chunks=chunk_contents, + max_tokens=max_tokens, + temperature=temperature + ) + + response = await self.completion_model.complete(request) + return response async def ingest_text( self, @@ -156,44 +228,6 @@ class DocumentService: return doc - async def query( - self, - request: QueryRequest, - auth: AuthContext - ) -> Union[List[ChunkResult], List[DocumentResult]]: - """Query documents with specified return type.""" - # TODO: k does not make sense for Documents, it's about chunks. - # We should also look into document ordering. Figure these out. - - # 1. Get embedding for query - query_embedding = await self.embedding_model.embed_for_query(request.query) - logger.info("Generated query embedding") - - # 2. Find authorized documents - doc_ids = await self.db.find_authorized_and_filtered_documents(auth, request.filters) - if not doc_ids: - logger.info("No authorized documents found") - return [] - logger.info(f"Found {len(doc_ids)} authorized documents") - - # 3. Search chunks with vector similarity - chunks = await self.vector_store.query_similar( - query_embedding, - k=request.k, - doc_ids=doc_ids, - ) - logger.info(f"Found {len(chunks)} similar chunks") - - # 4. 
Return results in requested format - if request.return_type == QueryReturnType.CHUNKS: - results = await self._create_chunk_results(auth, chunks) - logger.info(f"Returning {len(results)} chunk results") - return results - else: - results = await self._create_document_results(auth, chunks) - logger.info(f"Returning {len(results)} document results") - return results - def _create_chunk_objects( self, doc_id: str, diff --git a/core/tests/integration/test_api.py b/core/tests/integration/test_api.py index 7006166..bf751ee 100644 --- a/core/tests/integration/test_api.py +++ b/core/tests/integration/test_api.py @@ -259,71 +259,6 @@ async def test_auth_insufficient_permissions(client: AsyncClient): assert response.status_code == 403 -@pytest.mark.asyncio -async def test_query_chunks(client: AsyncClient): - """Test querying document chunks""" - # First ingest a document to query - doc_id = await test_ingest_text_document( - client, - content="The quick brown fox jumps over the lazy dog" - ) - - headers = create_auth_header() - # Sleep to allow time for document to be indexed - await asyncio.sleep(1) - - response = await client.post( - "/query", - json={ - "query": "jumping fox", - "return_type": "chunks", - "k": 1 - }, - headers=headers - ) - logger.info(f"Query response: {response.json()}") - - assert response.status_code == 200 - results = list(response.json()) - logger.info(f"Query results: {results}") - assert len(results) == 1 - assert results[0]["score"] > 0.5 - assert results[0]["document_id"] == doc_id - - -@pytest.mark.asyncio -async def test_query_documents(client: AsyncClient): - """Test querying for full documents""" - # First ingest a document to query - content = ( - "Headaches can significantly impact daily life and wellbeing. " - "Common triggers include stress, dehydration, and poor sleep habits. " - "While over-the-counter pain relievers may provide temporary relief, " - "it's important to identify and address the root causes. 
" - "Maintaining good health through proper nutrition, regular exercise, " - "and stress management can help prevent chronic headaches." - ) - doc_id = await test_ingest_text_document(client, content=content) - - headers = create_auth_header() - response = await client.post( - "/query", - json={ - "query": "Headaches, dehydration", - "return_type": "documents", - "filters": {"test": True} - }, - headers=headers - ) - - assert response.status_code == 200 - results = list(response.json()) - assert len(results) > 0 - assert results[0]["document_id"] == doc_id - assert "score" in results[0] - assert "metadata" in results[0] - - @pytest.mark.asyncio async def test_list_documents(client: AsyncClient): """Test listing documents""" @@ -367,15 +302,139 @@ async def test_invalid_document_id(client: AsyncClient): @pytest.mark.asyncio -async def test_invalid_query_params(client: AsyncClient): - """Test error handling for invalid query parameters""" +async def test_retrieve_chunks(client: AsyncClient): + """Test retrieving document chunks""" + # First ingest a document to search + doc_id = await test_ingest_text_document( + client, + content="The quick brown fox jumps over the lazy dog" + ) + + headers = create_auth_header() + # Sleep to allow time for document to be indexed + await asyncio.sleep(1) + + response = await client.post( + "/retrieve/chunks", + json={ + "query": "jumping fox", + "k": 1, + "min_score": 0.0 + }, + headers=headers + ) + + assert response.status_code == 200 + results = list(response.json()) + assert len(results) == 1 + assert results[0]["score"] > 0.5 + assert results[0]["document_id"] == doc_id + + +@pytest.mark.asyncio +async def test_retrieve_docs(client: AsyncClient): + """Test retrieving full documents""" + # First ingest a document to search + content = ( + "Headaches can significantly impact daily life and wellbeing. " + "Common triggers include stress, dehydration, and poor sleep habits. 
" + "While over-the-counter pain relievers may provide temporary relief, " + "it's important to identify and address the root causes. " + "Maintaining good health through proper nutrition, regular exercise, " + "and stress management can help prevent chronic headaches." + ) + doc_id = await test_ingest_text_document(client, content=content) + + headers = create_auth_header() + response = await client.post( + "/retrieve/docs", + json={ + "query": "Headaches, dehydration", + "filters": {"test": True} + }, + headers=headers + ) + + assert response.status_code == 200 + results = list(response.json()) + assert len(results) > 0 + assert results[0]["document_id"] == doc_id + assert "score" in results[0] + assert "metadata" in results[0] + + +@pytest.mark.asyncio +async def test_query_completion(client: AsyncClient): + """Test generating completions from context""" + # First ingest a document to use as context + content = ( + "The benefits of exercise are numerous. Regular physical activity " + "can improve cardiovascular health, strengthen muscles, enhance mental " + "wellbeing, and help maintain a healthy weight. Studies show that " + "even moderate exercise like walking can significantly reduce the risk " + "of various health conditions." 
+ ) + await test_ingest_text_document(client, content=content) + headers = create_auth_header() response = await client.post( "/query", + json={ + "query": "What are the main benefits of exercise?", + "k": 2, + "temperature": 0.7, + "max_tokens": 100 + }, + headers=headers + ) + + assert response.status_code == 200 + result = response.json() + assert "completion" in result + assert "usage" in result + assert len(result["completion"]) > 0 + + +@pytest.mark.asyncio +async def test_invalid_retrieve_params(client: AsyncClient): + """Test error handling for invalid retrieve parameters""" + headers = create_auth_header() + + # Test empty query + response = await client.post( + "/retrieve/chunks", json={ "query": "", # Empty query + "k": 1 + }, + headers=headers + ) + assert response.status_code == 422 + + # Test invalid k + response = await client.post( + "/retrieve/docs", + json={ + "query": "test", "k": -1 # Invalid k }, headers=headers ) assert response.status_code == 422 + + +@pytest.mark.asyncio +async def test_invalid_completion_params(client: AsyncClient): + """Test error handling for invalid completion parameters""" + headers = create_auth_header() + + # Test empty query + response = await client.post( + "/query", + json={ + "query": "", # Empty query + "temperature": 2.0 # Invalid temperature + }, + headers=headers + ) + assert response.status_code == 422 diff --git a/sdks/python/databridge/async_.py b/sdks/python/databridge/async_.py index 09d941a..3f3f191 100644 --- a/sdks/python/databridge/async_.py +++ b/sdks/python/databridge/async_.py @@ -7,18 +7,18 @@ from urllib.parse import urlparse import httpx import jwt -from .models import Document, IngestTextRequest, ChunkResult, DocumentResult +from .models import Document, IngestTextRequest, ChunkResult, DocumentResult, CompletionResponse class AsyncDataBridge: """ DataBridge client for document operations. 
- + Args: uri (str): DataBridge URI in the format "databridge://:@" timeout (int, optional): Request timeout in seconds. Defaults to 30. is_local (bool, optional): Whether to connect to a local server. Defaults to False. - + Examples: ```python async with AsyncDataBridge("databridge://owner_id:token@api.databridge.ai") as db: @@ -27,7 +27,7 @@ class AsyncDataBridge: "Sample content", metadata={"category": "sample"} ) - + # Query documents results = await db.query("search query") ``` @@ -52,7 +52,7 @@ class AsyncDataBridge: # Split host and auth parts auth, host = parsed.netloc.split('@') self._owner_id, self._auth_token = auth.split(':') - + # Set base URL self._base_url = f"{'http' if self._is_local else 'https'}://{host}" @@ -68,7 +68,7 @@ class AsyncDataBridge: ) -> Dict[str, Any]: """Make authenticated HTTP request""" headers = {"Authorization": f"Bearer {self._auth_token}"} - + if not files: headers["Content-Type"] = "application/json" @@ -90,14 +90,14 @@ class AsyncDataBridge: ) -> Document: """ Ingest a text document into DataBridge. - + Args: content: Text content to ingest metadata: Optional metadata dictionary - + Returns: Document: Metadata of the ingested document - + Example: ```python doc = await db.ingest_text( @@ -130,16 +130,16 @@ class AsyncDataBridge: ) -> Document: """ Ingest a file document into DataBridge. 
- + Args: file: File to ingest (path string, bytes, file object, or Path) filename: Name of the file content_type: MIME type (optional, will be guessed if not provided) metadata: Optional metadata dictionary - + Returns: Document: Metadata of the ingested document - + Example: ```python # From file path @@ -149,7 +149,7 @@ class AsyncDataBridge: content_type="application/pdf", metadata={"department": "research"} ) - + # From file object with open("document.pdf", "rb") as f: doc = await db.ingest_file(f, "document.pdf") @@ -189,58 +189,125 @@ class AsyncDataBridge: if isinstance(file, (str, Path)): file_obj.close() - async def query( + async def retrieve_chunks( self, query: str, - return_type: str = "chunks", filters: Optional[Dict[str, Any]] = None, k: int = 4, min_score: float = 0.0 - ) -> Union[List[ChunkResult], List[DocumentResult]]: + ) -> List[ChunkResult]: """ - Query documents in DataBridge. + Search for relevant chunks. Args: query: Search query text - return_type: Type of results ("chunks" or "documents") filters: Optional metadata filters k: Number of results (default: 4) min_score: Minimum similarity threshold (default: 0.0) Returns: - List[ChunkResult] or List[DocumentResult] depending on return_type + List[ChunkResult] Example: ```python - # Query for chunks - chunks = await db.query( + chunks = await db.retrieve_chunks( "What are the key findings?", - return_type="chunks", filters={"department": "research"} ) + ``` + """ + request = { + "query": query, + "filters": filters, + "k": k, + "min_score": min_score + } - # Query for documents - docs = await db.query( + response = await self._request("POST", "retrieve/chunks", request) + return [ChunkResult(**r) for r in response] + + async def retrieve_docs( + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + k: int = 4, + min_score: float = 0.0 + ) -> List[DocumentResult]: + """ + Retrieve relevant documents. 
+ + Args: + query: Search query text + filters: Optional metadata filters + k: Number of results (default: 4) + min_score: Minimum similarity threshold (default: 0.0) + + Returns: + List[DocumentResult] + + Example: + ```python + docs = await db.retrieve_docs( "machine learning", - return_type="documents", k=5 ) ``` """ request = { "query": query, - "return_type": return_type, "filters": filters, "k": k, "min_score": min_score } - response = await self._request("POST", "query", request) - - if return_type == "chunks": - return [ChunkResult(**r) for r in response] + response = await self._request("POST", "retrieve/docs", request) return [DocumentResult(**r) for r in response] + async def query( + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + k: int = 4, + min_score: float = 0.0, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + ) -> CompletionResponse: + """ + Generate completion using relevant chunks as context. + + Args: + query: Query text + filters: Optional metadata filters + k: Number of chunks to use as context (default: 4) + min_score: Minimum similarity threshold (default: 0.0) + max_tokens: Maximum tokens in completion + temperature: Model temperature + + Returns: + CompletionResponse + + Example: + ```python + response = await db.query( + "What are the key findings about customer satisfaction?", + filters={"department": "research"}, + temperature=0.7 + ) + print(response.completion) + ``` + """ + request = { + "query": query, + "filters": filters, + "k": k, + "min_score": min_score, + "max_tokens": max_tokens, + "temperature": temperature, + } + + response = await self._request("POST", "query", request) + return CompletionResponse(**response) + async def list_documents( self, skip: int = 0, @@ -249,7 +316,7 @@ class AsyncDataBridge: ) -> List[Document]: """ List accessible documents. 
- + Args: skip: Number of documents to skip limit: Maximum number of documents to return @@ -257,12 +324,12 @@ class AsyncDataBridge: Returns: List[Document]: List of accessible documents - + Example: ```python # Get first page docs = await db.list_documents(limit=10) - + # Get next page next_page = await db.list_documents(skip=10, limit=10, filters={"department": "research"}) ``` @@ -276,13 +343,13 @@ class AsyncDataBridge: async def get_document(self, document_id: str) -> Document: """ Get document metadata by ID. - + Args: document_id: ID of the document - + Returns: Document: Document metadata - + Example: ```python doc = await db.get_document("doc_123") diff --git a/sdks/python/databridge/models.py b/sdks/python/databridge/models.py index a98f0e3..0eeb7e4 100644 --- a/sdks/python/databridge/models.py +++ b/sdks/python/databridge/models.py @@ -53,3 +53,9 @@ class DocumentResult(BaseModel): document_id: str = Field(..., description="Document ID") metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata") content: DocumentContent = Field(..., description="Document content or URL") + + +class CompletionResponse(BaseModel): + """Completion response model""" + completion: str + usage: Dict[str, int] diff --git a/sdks/python/databridge/sync.py b/sdks/python/databridge/sync.py index 82d495f..876d331 100644 --- a/sdks/python/databridge/sync.py +++ b/sdks/python/databridge/sync.py @@ -7,18 +7,18 @@ from urllib.parse import urlparse import jwt import requests -from .models import Document, IngestTextRequest, ChunkResult, DocumentResult +from .models import Document, IngestTextRequest, ChunkResult, DocumentResult, CompletionResponse class DataBridge: """ DataBridge client for document operations. - + Args: uri (str): DataBridge URI in the format "databridge://:@" timeout (int, optional): Request timeout in seconds. Defaults to 30. is_local (bool, optional): Whether connecting to local development server. Defaults to False. 
- + Examples: ```python with DataBridge("databridge://owner_id:token@api.databridge.ai") as db: @@ -27,7 +27,7 @@ class DataBridge: "Sample content", metadata={"category": "sample"} ) - + # Query documents results = db.query("search query") ``` @@ -50,7 +50,7 @@ class DataBridge: # Split host and auth parts auth, host = parsed.netloc.split('@') self._owner_id, self._auth_token = auth.split(':') - + # Set base URL self._base_url = f"{'http' if self._is_local else 'https'}://{host}" @@ -66,7 +66,7 @@ class DataBridge: ) -> Dict[str, Any]: """Make authenticated HTTP request""" headers = {"Authorization": f"Bearer {self._auth_token}"} - + if not files: headers["Content-Type"] = "application/json" @@ -89,14 +89,14 @@ class DataBridge: ) -> Document: """ Ingest a text document into DataBridge. - + Args: content: Text content to ingest metadata: Optional metadata dictionary - + Returns: Document: Metadata of the ingested document - + Example: ```python doc = db.ingest_text( @@ -129,16 +129,16 @@ class DataBridge: ) -> Document: """ Ingest a file document into DataBridge. - + Args: file: File to ingest (path string, bytes, file object, or Path) filename: Name of the file content_type: MIME type (optional, will be guessed if not provided) metadata: Optional metadata dictionary - + Returns: Document: Metadata of the ingested document - + Example: ```python # From file path @@ -148,7 +148,7 @@ class DataBridge: content_type="application/pdf", metadata={"department": "research"} ) - + # From file object with open("document.pdf", "rb") as f: doc = db.ingest_file(f, "document.pdf") @@ -188,58 +188,125 @@ class DataBridge: if isinstance(file, (str, Path)): file_obj.close() - def query( + def retrieve_chunks( self, query: str, - return_type: str = "chunks", filters: Optional[Dict[str, Any]] = None, k: int = 4, min_score: float = 0.0 - ) -> Union[List[ChunkResult], List[DocumentResult]]: + ) -> List[ChunkResult]: """ - Query documents in DataBridge. + Retrieve relevant chunks. 
Args: query: Search query text - return_type: Type of results ("chunks" or "documents") filters: Optional metadata filters k: Number of results (default: 4) min_score: Minimum similarity threshold (default: 0.0) Returns: - List[ChunkResult] or List[DocumentResult] depending on return_type + List[ChunkResult] Example: ```python - # Query for chunks - chunks = db.query( + chunks = db.retrieve_chunks( "What are the key findings?", - return_type="chunks", filters={"department": "research"} ) + ``` + """ + request = { + "query": query, + "filters": filters, + "k": k, + "min_score": min_score + } - # Query for documents - docs = db.query( + response = self._request("POST", "retrieve/chunks", request) + return [ChunkResult(**r) for r in response] + + def retrieve_docs( + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + k: int = 4, + min_score: float = 0.0 + ) -> List[DocumentResult]: + """ + Retrieve relevant documents. + + Args: + query: Search query text + filters: Optional metadata filters + k: Number of results (default: 4) + min_score: Minimum similarity threshold (default: 0.0) + + Returns: + List[DocumentResult] + + Example: + ```python + docs = db.retrieve_docs( "machine learning", - return_type="documents", k=5 ) ``` """ request = { "query": query, - "return_type": return_type, "filters": filters, "k": k, "min_score": min_score } - response = self._request("POST", "query", request) - - if return_type == "chunks": - return [ChunkResult(**r) for r in response] + response = self._request("POST", "retrieve/docs", request) return [DocumentResult(**r) for r in response] + def query( + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + k: int = 4, + min_score: float = 0.0, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + ) -> CompletionResponse: + """ + Generate completion using relevant chunks as context.
+ + Args: + query: Query text + filters: Optional metadata filters + k: Number of chunks to use as context (default: 4) + min_score: Minimum similarity threshold (default: 0.0) + max_tokens: Maximum tokens in completion + temperature: Model temperature + + Returns: + CompletionResponse + + Example: + ```python + response = db.query( + "What are the key findings about customer satisfaction?", + filters={"department": "research"}, + temperature=0.7 + ) + print(response.completion) + ``` + """ + request = { + "query": query, + "filters": filters, + "k": k, + "min_score": min_score, + "max_tokens": max_tokens, + "temperature": temperature, + } + + response = self._request("POST", "query", request) + return CompletionResponse(**response) + def list_documents( self, skip: int = 0, @@ -248,7 +315,7 @@ class DataBridge: ) -> List[Document]: """ List accessible documents. - + Args: skip: Number of documents to skip limit: Maximum number of documents to return @@ -256,12 +323,12 @@ class DataBridge: Returns: List[Document]: List of accessible documents - + Example: ```python # Get first page docs = db.list_documents(limit=10) - + # Get next page next_page = db.list_documents(skip=10, limit=10, filters={"department": "research"}) ``` @@ -275,13 +342,13 @@ class DataBridge: def get_document(self, document_id: str) -> Document: """ Get document metadata by ID. - + Args: document_id: ID of the document - + Returns: Document: Document metadata - + Example: ```python doc = db.get_document("doc_123") diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index 2c87dce..e84dd30 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "httpx>=0.24.0", "pyjwt>=2.0.0", "pydantic==2.10.3", + "requests>=2.32.3", ] [tool.hatch.build.targets.wheel]