mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
Add completions API (#3)
This commit is contained in:
parent
a404154650
commit
03345dcc07
@ -3,7 +3,7 @@ default_region = "us-east-2"
|
||||
default_bucket_name = "databridge-s3-storage"
|
||||
|
||||
[mongodb]
|
||||
database_name = "databridge"
|
||||
database_name = "DataBridgeTest"
|
||||
documents_collection = "documents"
|
||||
chunks_collection = "document_chunks"
|
||||
|
||||
@ -13,6 +13,7 @@ index_name = "vector_index"
|
||||
|
||||
[model]
|
||||
embedding_model = "text-embedding-3-small"
|
||||
completion_model = "gpt-3.5-turbo"
|
||||
|
||||
[document_processing]
|
||||
chunk_size = 1000
|
||||
|
60
core/api.py
60
core/api.py
@ -12,13 +12,14 @@ from fastapi import (
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import jwt
|
||||
import logging
|
||||
from core.models.request import IngestTextRequest, QueryRequest
|
||||
from core.models.request import IngestTextRequest, RetrieveRequest, CompletionQueryRequest
|
||||
from core.models.documents import (
|
||||
Document,
|
||||
DocumentResult,
|
||||
ChunkResult
|
||||
)
|
||||
from core.models.auth import AuthContext, EntityType
|
||||
from core.completion.base_completion import CompletionResponse
|
||||
from core.services.document_service import DocumentService
|
||||
from core.config import get_settings
|
||||
from core.database.mongo_database import MongoDatabase
|
||||
@ -26,7 +27,7 @@ from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
|
||||
from core.storage.s3_storage import S3Storage
|
||||
from core.parser.unstructured_parser import UnstructuredAPIParser
|
||||
from core.embedding_model.openai_embedding_model import OpenAIEmbeddingModel
|
||||
|
||||
from core.completion.openai_completion import OpenAICompletionModel
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(title="DataBridge API")
|
||||
@ -76,13 +77,18 @@ embedding_model = OpenAIEmbeddingModel(
|
||||
model_name=settings.EMBEDDING_MODEL
|
||||
)
|
||||
|
||||
completion_model = OpenAICompletionModel(
|
||||
model_name=settings.COMPLETION_MODEL
|
||||
)
|
||||
|
||||
# Initialize document service
|
||||
document_service = DocumentService(
|
||||
database=database,
|
||||
vector_store=vector_store,
|
||||
storage=storage,
|
||||
parser=parser,
|
||||
embedding_model=embedding_model
|
||||
embedding_model=embedding_model,
|
||||
completion_model=completion_model
|
||||
)
|
||||
|
||||
|
||||
@ -150,13 +156,51 @@ async def ingest_file(
|
||||
raise HTTPException(400, "Invalid metadata JSON")
|
||||
|
||||
|
||||
@app.post("/query", response_model=Union[List[ChunkResult], List[DocumentResult]])
|
||||
async def query_documents(
|
||||
request: QueryRequest,
|
||||
@app.post("/retrieve/chunks", response_model=List[ChunkResult])
|
||||
async def retrieve_chunks(
|
||||
request: RetrieveRequest,
|
||||
auth: AuthContext = Depends(verify_token)
|
||||
):
|
||||
"""Query documents with specified return type."""
|
||||
return await document_service.query(request, auth)
|
||||
"""Retrieve relevant chunks."""
|
||||
return await document_service.retrieve_chunks(
|
||||
request.query,
|
||||
auth,
|
||||
request.filters,
|
||||
request.k,
|
||||
request.min_score
|
||||
)
|
||||
|
||||
|
||||
@app.post("/retrieve/docs", response_model=List[DocumentResult])
async def retrieve_documents(
    request: RetrieveRequest,
    auth: AuthContext = Depends(verify_token),
):
    """Return the documents that best match the request's query.

    The validated request fields are forwarded to
    ``document_service.retrieve_docs`` together with the auth context
    resolved through the ``verify_token`` dependency.
    """
    matches = await document_service.retrieve_docs(
        query=request.query,
        auth=auth,
        filters=request.filters,
        k=request.k,
        min_score=request.min_score,
    )
    return matches
|
||||
|
||||
|
||||
@app.post("/query", response_model=CompletionResponse)
async def query_completion(
    request: CompletionQueryRequest,
    auth: AuthContext = Depends(verify_token),
):
    """Answer the request's query with a completion built from retrieved chunks.

    Retrieval settings (filters, k, min_score) and generation settings
    (max_tokens, temperature) are passed through unchanged to
    ``document_service.query``.
    """
    completion = await document_service.query(
        query=request.query,
        auth=auth,
        filters=request.filters,
        k=request.k,
        min_score=request.min_score,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
    )
    return completion
|
||||
|
||||
|
||||
@app.get("/documents", response_model=List[Document])
|
||||
|
26
core/completion/base_completion.py
Normal file
26
core/completion/base_completion.py
Normal file
@ -0,0 +1,26 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Dict
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
    """Outcome of a completion call: the generated text and token usage."""

    # Text produced by the completion model.
    completion: str
    # Integer counters keyed by name (the OpenAI implementation fills in
    # prompt_tokens, completion_tokens and total_tokens).
    usage: Dict[str, int]
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
    """Input handed to a completion model: the question plus its context."""

    # The user's question.
    query: str
    # Text passages the model should use as context when answering.
    context_chunks: List[str]
    # Cap on generated tokens; 1000 unless the caller supplies a value
    # (an explicit None is accepted and replaces the default).
    max_tokens: Optional[int] = 1000
    # Sampling temperature; 0.7 unless the caller supplies a value
    # (an explicit None is accepted and replaces the default).
    temperature: Optional[float] = 0.7
|
||||
|
||||
|
||||
class BaseCompletionModel(ABC):
    """Interface that every completion backend implements."""

    @abstractmethod
    async def complete(self, request: CompletionRequest) -> CompletionResponse:
        """Produce a completion for the request's query and context chunks.

        Subclasses must override this coroutine; the base class cannot be
        instantiated while it remains abstract.
        """
        ...
|
37
core/completion/openai_completion.py
Normal file
37
core/completion/openai_completion.py
Normal file
@ -0,0 +1,37 @@
|
||||
from .base_completion import BaseCompletionModel, CompletionRequest, CompletionResponse
|
||||
|
||||
|
||||
class OpenAICompletionModel(BaseCompletionModel):
    """Completion backend that answers questions through the OpenAI chat API."""

    def __init__(self, model_name: str):
        """Remember the model name and create the async OpenAI client.

        Args:
            model_name: Model identifier sent as ``model`` on every request.
        """
        self.model_name = model_name
        # Import here to avoid dependency if not using OpenAI
        from openai import AsyncOpenAI

        self.client = AsyncOpenAI()

    async def complete(self, request: CompletionRequest) -> CompletionResponse:
        """Generate a completion for ``request.query`` using its context chunks.

        The context chunks are joined with blank lines and sent, together
        with the question, as a single user message after a fixed system
        prompt.

        Args:
            request: Query, context chunks and generation settings.

        Returns:
            CompletionResponse with the first choice's text and the token
            usage counters. Missing text becomes ``""`` and missing usage
            is reported as zeros.
        """
        # Construct prompt with context
        context = "\n\n".join(request.context_chunks)
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant. Use the provided context to answer questions accurately.",
            },
            {
                "role": "user",
                "content": f"Context:\n{context}\n\nQuestion: {request.query}",
            },
        ]

        # Call OpenAI API
        response = await self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
        )

        # The SDK types both message.content and response.usage as optional.
        # CompletionResponse requires a str and int counters, so normalise
        # the None cases instead of failing model validation.
        text = response.choices[0].message.content or ""
        usage = response.usage
        return CompletionResponse(
            completion=text,
            usage={
                "prompt_tokens": usage.prompt_tokens if usage else 0,
                "completion_tokens": usage.completion_tokens if usage else 0,
                "total_tokens": usage.total_tokens if usage else 0,
            },
        )
|
130
core/config.py
130
core/config.py
@ -1,104 +1,66 @@
|
||||
from typing import Dict, Any
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
import tomli
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
from dotenv import load_dotenv
|
||||
|
||||
def load_toml_config() -> Dict[Any, Any]:
|
||||
"""Load configuration from config.toml file."""
|
||||
with open("config.toml", "rb") as f:
|
||||
return tomli.load(f)
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""DataBridge configuration settings."""
|
||||
|
||||
# MongoDB settings
|
||||
# Required environment variables
|
||||
MONGODB_URI: str = Field(..., env="MONGODB_URI")
|
||||
DATABRIDGE_DB: str = Field(None)
|
||||
|
||||
# Collection names
|
||||
DOCUMENTS_COLLECTION: str = Field(None)
|
||||
CHUNKS_COLLECTION: str = Field(None)
|
||||
|
||||
# Vector search settings
|
||||
VECTOR_INDEX_NAME: str = Field(None)
|
||||
|
||||
# API Keys
|
||||
OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY")
|
||||
UNSTRUCTURED_API_KEY: str = Field(..., env="UNSTRUCTURED_API_KEY")
|
||||
|
||||
# Optional API keys for alternative models
|
||||
ANTHROPIC_API_KEY: str | None = Field(None, env="ANTHROPIC_API_KEY")
|
||||
COHERE_API_KEY: str | None = Field(None, env="COHERE_API_KEY")
|
||||
VOYAGE_API_KEY: str | None = Field(None, env="VOYAGE_API_KEY")
|
||||
|
||||
# Model settings
|
||||
EMBEDDING_MODEL: str = Field("text-embedding-3-small")
|
||||
|
||||
# Document processing settings
|
||||
CHUNK_SIZE: int = Field(1000)
|
||||
CHUNK_OVERLAP: int = Field(200)
|
||||
DEFAULT_K: int = Field(4)
|
||||
|
||||
# Storage settings
|
||||
AWS_ACCESS_KEY: str = Field(..., env="AWS_ACCESS_KEY")
|
||||
AWS_SECRET_ACCESS_KEY: str = Field(..., env="AWS_SECRET_ACCESS_KEY")
|
||||
AWS_REGION: str = Field(None)
|
||||
S3_BUCKET: str = Field(None)
|
||||
|
||||
# Auth settings
|
||||
JWT_SECRET_KEY: str = Field(..., env="JWT_SECRET_KEY")
|
||||
JWT_ALGORITHM: str = Field("HS256")
|
||||
|
||||
# Server settings
|
||||
HOST: str = Field("localhost")
|
||||
PORT: int = Field(8000)
|
||||
RELOAD: bool = Field(False)
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
extra = "allow"
|
||||
# Values from config.toml with defaults
|
||||
AWS_REGION: str = "us-east-2"
|
||||
S3_BUCKET: str = "databridge-s3-storage"
|
||||
DATABRIDGE_DB: str = "databridge"
|
||||
DOCUMENTS_COLLECTION: str = "documents"
|
||||
CHUNKS_COLLECTION: str = "document_chunks"
|
||||
VECTOR_INDEX_NAME: str = "vector_index"
|
||||
VECTOR_DIMENSIONS: int = 1536
|
||||
EMBEDDING_MODEL: str = "text-embedding-3-small"
|
||||
COMPLETION_MODEL: str = "gpt-3.5-turbo"
|
||||
CHUNK_SIZE: int = 1000
|
||||
CHUNK_OVERLAP: int = 200
|
||||
DEFAULT_K: int = 4
|
||||
HOST: str = "localhost"
|
||||
PORT: int = 8000
|
||||
RELOAD: bool = False
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Force reload of environment variables
|
||||
load_dotenv(find_dotenv(), override=True)
|
||||
|
||||
config = load_toml_config()
|
||||
|
||||
# Set values from config.toml
|
||||
kwargs.update({
|
||||
# MongoDB settings
|
||||
"DATABRIDGE_DB": config["mongodb"]["database_name"],
|
||||
"DOCUMENTS_COLLECTION": config["mongodb"]["documents_collection"],
|
||||
"CHUNKS_COLLECTION": config["mongodb"]["chunks_collection"],
|
||||
"VECTOR_INDEX_NAME": config["mongodb"]["vector"]["index_name"],
|
||||
|
||||
# AWS settings
|
||||
"AWS_REGION": config["aws"]["default_region"],
|
||||
"S3_BUCKET": config["aws"]["default_bucket_name"],
|
||||
|
||||
# Model settings
|
||||
"EMBEDDING_MODEL": config["model"]["embedding_model"],
|
||||
|
||||
# Document processing settings
|
||||
"CHUNK_SIZE": config["document_processing"]["chunk_size"],
|
||||
"CHUNK_OVERLAP": config["document_processing"]["chunk_overlap"],
|
||||
"DEFAULT_K": config["document_processing"]["default_k"],
|
||||
|
||||
# Server settings
|
||||
"HOST": config["server"]["host"],
|
||||
"PORT": config["server"]["port"],
|
||||
"RELOAD": config["server"]["reload"],
|
||||
|
||||
# Auth settings
|
||||
"JWT_ALGORITHM": config["auth"]["jwt_algorithm"],
|
||||
})
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""Get cached settings instance."""
|
||||
return Settings()
|
||||
load_dotenv()
|
||||
|
||||
# Load config.toml
|
||||
with open("config.toml", "rb") as f:
|
||||
config = tomli.load(f)
|
||||
|
||||
# Map config.toml values to settings
|
||||
settings_dict = {
|
||||
"AWS_REGION": config["aws"]["default_region"],
|
||||
"S3_BUCKET": config["aws"]["default_bucket_name"],
|
||||
"DATABRIDGE_DB": config["mongodb"]["database_name"],
|
||||
"DOCUMENTS_COLLECTION": config["mongodb"]["documents_collection"],
|
||||
"CHUNKS_COLLECTION": config["mongodb"]["chunks_collection"],
|
||||
"VECTOR_INDEX_NAME": config["mongodb"]["vector"]["index_name"],
|
||||
"VECTOR_DIMENSIONS": config["mongodb"]["vector"]["dimensions"],
|
||||
"EMBEDDING_MODEL": config["model"]["embedding_model"],
|
||||
"COMPLETION_MODEL": config["model"]["completion_model"],
|
||||
"CHUNK_SIZE": config["document_processing"]["chunk_size"],
|
||||
"CHUNK_OVERLAP": config["document_processing"]["chunk_overlap"],
|
||||
"DEFAULT_K": config["document_processing"]["default_k"],
|
||||
"HOST": config["server"]["host"],
|
||||
"PORT": config["server"]["port"],
|
||||
"RELOAD": config["server"]["reload"],
|
||||
"JWT_ALGORITHM": config["auth"]["jwt_algorithm"],
|
||||
}
|
||||
|
||||
return Settings(**settings_dict)
|
||||
|
@ -1,6 +1,5 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from .documents import QueryReturnType
|
||||
|
||||
|
||||
class IngestTextRequest(BaseModel):
|
||||
@ -9,10 +8,15 @@ class IngestTextRequest(BaseModel):
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
"""Query request model with validation"""
|
||||
class RetrieveRequest(BaseModel):
|
||||
"""Base retrieve request model"""
|
||||
query: str = Field(..., min_length=1)
|
||||
return_type: QueryReturnType = QueryReturnType.CHUNKS
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
k: int = Field(default=4, gt=0)
|
||||
min_score: float = Field(default=0.0)
|
||||
|
||||
|
||||
class CompletionQueryRequest(RetrieveRequest):
    """Request model for completion generation.

    Extends ``RetrieveRequest`` with optional generation settings. Both
    default to None; values that are supplied are range-checked so that bad
    input is rejected at request validation instead of being forwarded to
    the completion model.
    """

    # Maximum number of tokens to generate; must be positive when given.
    max_tokens: Optional[int] = Field(default=None, gt=0)
    # Sampling temperature; the OpenAI chat API accepts values from 0 to 2.
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
|
||||
|
@ -1,9 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class BasePlanner(ABC):
|
||||
@abstractmethod
|
||||
def plan_retrieval(self, query: str, **kwargs) -> Dict[str, Any]:
|
||||
"""Create execution plan for retrieval"""
|
||||
pass
|
@ -1,17 +0,0 @@
|
||||
from typing import Dict, Any
|
||||
from .base_planner import BasePlanner
|
||||
|
||||
|
||||
class SimpleRAGPlanner(BasePlanner):
|
||||
def __init__(self, default_k: int = 3):
|
||||
self.default_k = default_k
|
||||
|
||||
def plan_retrieval(self, query: str, **kwargs) -> Dict[str, Any]:
|
||||
"""Create a simple retrieval plan."""
|
||||
return {
|
||||
"strategy": "simple_retrieval",
|
||||
"k": kwargs.get("k", self.default_k),
|
||||
"query": query,
|
||||
"filters": kwargs.get("filters", {}),
|
||||
"min_score": kwargs.get("min_score", 0.0)
|
||||
}
|
@ -1,24 +1,18 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
import logging
|
||||
from fastapi import UploadFile
|
||||
import base64
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import UploadFile
|
||||
|
||||
from core.models.request import IngestTextRequest
|
||||
from ..models.documents import Document, DocumentChunk, ChunkResult, DocumentContent, DocumentResult
|
||||
from ..models.auth import AuthContext
|
||||
from core.database.base_database import BaseDatabase
|
||||
from core.embedding_model.base_embedding_model import BaseEmbeddingModel
|
||||
from core.models.request import IngestTextRequest, QueryRequest
|
||||
from core.parser.base_parser import BaseParser
|
||||
from core.storage.base_storage import BaseStorage
|
||||
from core.vector_store.base_vector_store import BaseVectorStore
|
||||
from ..models.documents import (
|
||||
Document,
|
||||
DocumentChunk,
|
||||
ChunkResult,
|
||||
DocumentContent,
|
||||
DocumentResult,
|
||||
QueryReturnType
|
||||
)
|
||||
from ..models.auth import AuthContext
|
||||
|
||||
from core.embedding_model.base_embedding_model import BaseEmbeddingModel
|
||||
from core.parser.base_parser import BaseParser
|
||||
from core.completion.base_completion import BaseCompletionModel
|
||||
from core.completion.base_completion import CompletionRequest, CompletionResponse
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -30,13 +24,91 @@ class DocumentService:
|
||||
vector_store: BaseVectorStore,
|
||||
storage: BaseStorage,
|
||||
parser: BaseParser,
|
||||
embedding_model: BaseEmbeddingModel
|
||||
embedding_model: BaseEmbeddingModel,
|
||||
completion_model: BaseCompletionModel
|
||||
):
|
||||
self.db = database
|
||||
self.vector_store = vector_store
|
||||
self.storage = storage
|
||||
self.parser = parser
|
||||
self.embedding_model = embedding_model
|
||||
self.completion_model = completion_model
|
||||
|
||||
async def retrieve_chunks(
    self,
    query: str,
    auth: AuthContext,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0
) -> List[ChunkResult]:
    """Retrieve the chunks most similar to ``query`` that the caller may access.

    Args:
        query: Search text; embedded with ``self.embedding_model``.
        auth: Caller's auth context, used to restrict the document set.
        filters: Optional metadata filters passed to the database lookup.
        k: Number of similar chunks requested from the vector store.
        min_score: Minimum similarity score a result must reach to be returned.

    Returns:
        Chunk results scoring at least ``min_score``; an empty list when no
        authorized document matches the filters.
    """
    # Get embedding for query
    query_embedding = await self.embedding_model.embed_for_query(query)
    logger.info("Generated query embedding")

    # Find authorized documents
    doc_ids = await self.db.find_authorized_and_filtered_documents(auth, filters)
    if not doc_ids:
        logger.info("No authorized documents found")
        return []
    logger.info("Found %d authorized documents", len(doc_ids))

    # Search chunks with vector similarity
    chunks = await self.vector_store.query_similar(
        query_embedding,
        k=k,
        doc_ids=doc_ids,
    )
    logger.info("Found %d similar chunks", len(chunks))

    # Create chunk results
    results = await self._create_chunk_results(auth, chunks)

    # The vector store call is not given min_score, so the threshold is
    # enforced here; previously the parameter was accepted but ignored.
    # NOTE(review): assumes ChunkResult.score is always a number - confirm.
    results = [result for result in results if result.score >= min_score]

    logger.info("Returning %d chunk results", len(results))
    return results
|
||||
|
||||
async def retrieve_docs(
    self,
    query: str,
    auth: AuthContext,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0
) -> List[DocumentResult]:
    """Retrieve documents relevant to ``query``.

    Runs the chunk retrieval first and then converts the matching chunks
    into document-level results via ``_create_document_results``.
    """
    # Chunk search does the embedding, authorization and similarity lookup.
    matching_chunks = await self.retrieve_chunks(
        query=query,
        auth=auth,
        filters=filters,
        k=k,
        min_score=min_score,
    )

    # Turn the chunk hits into document results.
    doc_results = await self._create_document_results(auth, matching_chunks)
    logger.info(f"Returning {len(doc_results)} document results")
    return doc_results
|
||||
|
||||
async def query(
    self,
    query: str,
    auth: AuthContext,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0,
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None
) -> CompletionResponse:
    """Answer ``query`` with a completion grounded in the retrieved chunks.

    The content of each retrieved chunk becomes one context entry of the
    ``CompletionRequest`` given to ``self.completion_model``.
    """
    # Fetch the context passages.
    relevant = await self.retrieve_chunks(query, auth, filters, k, min_score)

    # NOTE(review): max_tokens/temperature are always passed explicitly, so a
    # None here replaces CompletionRequest's declared defaults (1000 / 0.7)
    # rather than falling back to them - confirm that is intended.
    completion_request = CompletionRequest(
        query=query,
        context_chunks=[chunk.content for chunk in relevant],
        max_tokens=max_tokens,
        temperature=temperature,
    )

    return await self.completion_model.complete(completion_request)
|
||||
|
||||
async def ingest_text(
|
||||
self,
|
||||
@ -156,44 +228,6 @@ class DocumentService:
|
||||
|
||||
return doc
|
||||
|
||||
async def query(
|
||||
self,
|
||||
request: QueryRequest,
|
||||
auth: AuthContext
|
||||
) -> Union[List[ChunkResult], List[DocumentResult]]:
|
||||
"""Query documents with specified return type."""
|
||||
# TODO: k does not make sense for Documents, it's about chunks.
|
||||
# We should also look into document ordering. Figure these out.
|
||||
|
||||
# 1. Get embedding for query
|
||||
query_embedding = await self.embedding_model.embed_for_query(request.query)
|
||||
logger.info("Generated query embedding")
|
||||
|
||||
# 2. Find authorized documents
|
||||
doc_ids = await self.db.find_authorized_and_filtered_documents(auth, request.filters)
|
||||
if not doc_ids:
|
||||
logger.info("No authorized documents found")
|
||||
return []
|
||||
logger.info(f"Found {len(doc_ids)} authorized documents")
|
||||
|
||||
# 3. Search chunks with vector similarity
|
||||
chunks = await self.vector_store.query_similar(
|
||||
query_embedding,
|
||||
k=request.k,
|
||||
doc_ids=doc_ids,
|
||||
)
|
||||
logger.info(f"Found {len(chunks)} similar chunks")
|
||||
|
||||
# 4. Return results in requested format
|
||||
if request.return_type == QueryReturnType.CHUNKS:
|
||||
results = await self._create_chunk_results(auth, chunks)
|
||||
logger.info(f"Returning {len(results)} chunk results")
|
||||
return results
|
||||
else:
|
||||
results = await self._create_document_results(auth, chunks)
|
||||
logger.info(f"Returning {len(results)} document results")
|
||||
return results
|
||||
|
||||
def _create_chunk_objects(
|
||||
self,
|
||||
doc_id: str,
|
||||
|
@ -259,71 +259,6 @@ async def test_auth_insufficient_permissions(client: AsyncClient):
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks(client: AsyncClient):
|
||||
"""Test querying document chunks"""
|
||||
# First ingest a document to query
|
||||
doc_id = await test_ingest_text_document(
|
||||
client,
|
||||
content="The quick brown fox jumps over the lazy dog"
|
||||
)
|
||||
|
||||
headers = create_auth_header()
|
||||
# Sleep to allow time for document to be indexed
|
||||
await asyncio.sleep(1)
|
||||
|
||||
response = await client.post(
|
||||
"/query",
|
||||
json={
|
||||
"query": "jumping fox",
|
||||
"return_type": "chunks",
|
||||
"k": 1
|
||||
},
|
||||
headers=headers
|
||||
)
|
||||
logger.info(f"Query response: {response.json()}")
|
||||
|
||||
assert response.status_code == 200
|
||||
results = list(response.json())
|
||||
logger.info(f"Query results: {results}")
|
||||
assert len(results) == 1
|
||||
assert results[0]["score"] > 0.5
|
||||
assert results[0]["document_id"] == doc_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(client: AsyncClient):
|
||||
"""Test querying for full documents"""
|
||||
# First ingest a document to query
|
||||
content = (
|
||||
"Headaches can significantly impact daily life and wellbeing. "
|
||||
"Common triggers include stress, dehydration, and poor sleep habits. "
|
||||
"While over-the-counter pain relievers may provide temporary relief, "
|
||||
"it's important to identify and address the root causes. "
|
||||
"Maintaining good health through proper nutrition, regular exercise, "
|
||||
"and stress management can help prevent chronic headaches."
|
||||
)
|
||||
doc_id = await test_ingest_text_document(client, content=content)
|
||||
|
||||
headers = create_auth_header()
|
||||
response = await client.post(
|
||||
"/query",
|
||||
json={
|
||||
"query": "Headaches, dehydration",
|
||||
"return_type": "documents",
|
||||
"filters": {"test": True}
|
||||
},
|
||||
headers=headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
results = list(response.json())
|
||||
assert len(results) > 0
|
||||
assert results[0]["document_id"] == doc_id
|
||||
assert "score" in results[0]
|
||||
assert "metadata" in results[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_documents(client: AsyncClient):
|
||||
"""Test listing documents"""
|
||||
@ -367,15 +302,139 @@ async def test_invalid_document_id(client: AsyncClient):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_query_params(client: AsyncClient):
|
||||
"""Test error handling for invalid query parameters"""
|
||||
async def test_retrieve_chunks(client: AsyncClient):
|
||||
"""Test retrieving document chunks"""
|
||||
# First ingest a document to search
|
||||
doc_id = await test_ingest_text_document(
|
||||
client,
|
||||
content="The quick brown fox jumps over the lazy dog"
|
||||
)
|
||||
|
||||
headers = create_auth_header()
|
||||
# Sleep to allow time for document to be indexed
|
||||
await asyncio.sleep(1)
|
||||
|
||||
response = await client.post(
|
||||
"/retrieve/chunks",
|
||||
json={
|
||||
"query": "jumping fox",
|
||||
"k": 1,
|
||||
"min_score": 0.0
|
||||
},
|
||||
headers=headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
results = list(response.json())
|
||||
assert len(results) == 1
|
||||
assert results[0]["score"] > 0.5
|
||||
assert results[0]["document_id"] == doc_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_retrieve_docs(client: AsyncClient):
    """Ingest one document and check /retrieve/docs returns it as the top hit."""
    # Document to search for; joined with single spaces into one paragraph.
    content = " ".join([
        "Headaches can significantly impact daily life and wellbeing.",
        "Common triggers include stress, dehydration, and poor sleep habits.",
        "While over-the-counter pain relievers may provide temporary relief,",
        "it's important to identify and address the root causes.",
        "Maintaining good health through proper nutrition, regular exercise,",
        "and stress management can help prevent chronic headaches.",
    ])
    doc_id = await test_ingest_text_document(client, content=content)

    payload = {
        "query": "Headaches, dehydration",
        "filters": {"test": True},
    }
    response = await client.post(
        "/retrieve/docs",
        json=payload,
        headers=create_auth_header(),
    )

    assert response.status_code == 200
    results = list(response.json())
    assert len(results) > 0

    # The ingested document must come first and expose score and metadata.
    top = results[0]
    assert top["document_id"] == doc_id
    assert "score" in top
    assert "metadata" in top
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_completion(client: AsyncClient):
    """Ingest a document and check /query answers with a non-empty completion."""
    # Context document; joined with single spaces into one paragraph.
    content = " ".join([
        "The benefits of exercise are numerous. Regular physical activity",
        "can improve cardiovascular health, strengthen muscles, enhance mental",
        "wellbeing, and help maintain a healthy weight. Studies show that",
        "even moderate exercise like walking can significantly reduce the risk",
        "of various health conditions.",
    ])
    await test_ingest_text_document(client, content=content)

    payload = {
        "query": "What are the main benefits of exercise?",
        "k": 2,
        "temperature": 0.7,
        "max_tokens": 100,
    }
    response = await client.post(
        "/query",
        json=payload,
        headers=create_auth_header(),
    )

    assert response.status_code == 200
    result = response.json()
    # Both response fields must be present and the text must not be empty.
    for field in ("completion", "usage"):
        assert field in result
    assert len(result["completion"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalid_retrieve_params(client: AsyncClient):
    """Both retrieve endpoints reject invalid request bodies with HTTP 422."""
    headers = create_auth_header()

    invalid_requests = [
        # Empty query
        ("/retrieve/chunks", {"query": "", "k": 1}),
        # Invalid k
        ("/retrieve/docs", {"query": "test", "k": -1}),
    ]

    for endpoint, payload in invalid_requests:
        response = await client.post(endpoint, json=payload, headers=headers)
        assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalid_completion_params(client: AsyncClient):
    """The /query endpoint rejects an invalid completion request with HTTP 422."""
    headers = create_auth_header()

    # The body combines an empty query with temperature 2.0. The empty query
    # alone violates the min_length=1 constraint on the request model, so
    # this case does not isolate temperature validation.
    payload = {
        "query": "",
        "temperature": 2.0,
    }
    response = await client.post("/query", json=payload, headers=headers)

    assert response.status_code == 422
|
||||
|
@ -7,18 +7,18 @@ from urllib.parse import urlparse
|
||||
import httpx
|
||||
import jwt
|
||||
|
||||
from .models import Document, IngestTextRequest, ChunkResult, DocumentResult
|
||||
from .models import Document, IngestTextRequest, ChunkResult, DocumentResult, CompletionResponse
|
||||
|
||||
|
||||
class AsyncDataBridge:
|
||||
"""
|
||||
DataBridge client for document operations.
|
||||
|
||||
|
||||
Args:
|
||||
uri (str): DataBridge URI in the format "databridge://<owner_id>:<token>@<host>"
|
||||
timeout (int, optional): Request timeout in seconds. Defaults to 30.
|
||||
is_local (bool, optional): Whether to connect to a local server. Defaults to False.
|
||||
|
||||
|
||||
Examples:
|
||||
```python
|
||||
async with AsyncDataBridge("databridge://owner_id:token@api.databridge.ai") as db:
|
||||
@ -27,7 +27,7 @@ class AsyncDataBridge:
|
||||
"Sample content",
|
||||
metadata={"category": "sample"}
|
||||
)
|
||||
|
||||
|
||||
# Query documents
|
||||
results = await db.query("search query")
|
||||
```
|
||||
@ -52,7 +52,7 @@ class AsyncDataBridge:
|
||||
# Split host and auth parts
|
||||
auth, host = parsed.netloc.split('@')
|
||||
self._owner_id, self._auth_token = auth.split(':')
|
||||
|
||||
|
||||
# Set base URL
|
||||
self._base_url = f"{'http' if self._is_local else 'https'}://{host}"
|
||||
|
||||
@ -68,7 +68,7 @@ class AsyncDataBridge:
|
||||
) -> Dict[str, Any]:
|
||||
"""Make authenticated HTTP request"""
|
||||
headers = {"Authorization": f"Bearer {self._auth_token}"}
|
||||
|
||||
|
||||
if not files:
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
@ -90,14 +90,14 @@ class AsyncDataBridge:
|
||||
) -> Document:
|
||||
"""
|
||||
Ingest a text document into DataBridge.
|
||||
|
||||
|
||||
Args:
|
||||
content: Text content to ingest
|
||||
metadata: Optional metadata dictionary
|
||||
|
||||
|
||||
Returns:
|
||||
Document: Metadata of the ingested document
|
||||
|
||||
|
||||
Example:
|
||||
```python
|
||||
doc = await db.ingest_text(
|
||||
@ -130,16 +130,16 @@ class AsyncDataBridge:
|
||||
) -> Document:
|
||||
"""
|
||||
Ingest a file document into DataBridge.
|
||||
|
||||
|
||||
Args:
|
||||
file: File to ingest (path string, bytes, file object, or Path)
|
||||
filename: Name of the file
|
||||
content_type: MIME type (optional, will be guessed if not provided)
|
||||
metadata: Optional metadata dictionary
|
||||
|
||||
|
||||
Returns:
|
||||
Document: Metadata of the ingested document
|
||||
|
||||
|
||||
Example:
|
||||
```python
|
||||
# From file path
|
||||
@ -149,7 +149,7 @@ class AsyncDataBridge:
|
||||
content_type="application/pdf",
|
||||
metadata={"department": "research"}
|
||||
)
|
||||
|
||||
|
||||
# From file object
|
||||
with open("document.pdf", "rb") as f:
|
||||
doc = await db.ingest_file(f, "document.pdf")
|
||||
@ -189,58 +189,125 @@ class AsyncDataBridge:
|
||||
if isinstance(file, (str, Path)):
|
||||
file_obj.close()
|
||||
|
||||
async def query(
|
||||
async def retrieve_chunks(
|
||||
self,
|
||||
query: str,
|
||||
return_type: str = "chunks",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
k: int = 4,
|
||||
min_score: float = 0.0
|
||||
) -> Union[List[ChunkResult], List[DocumentResult]]:
|
||||
) -> List[ChunkResult]:
|
||||
"""
|
||||
Query documents in DataBridge.
|
||||
Search for relevant chunks.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
return_type: Type of results ("chunks" or "documents")
|
||||
filters: Optional metadata filters
|
||||
k: Number of results (default: 4)
|
||||
min_score: Minimum similarity threshold (default: 0.0)
|
||||
|
||||
Returns:
|
||||
List[ChunkResult] or List[DocumentResult] depending on return_type
|
||||
List[ChunkResult]
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Query for chunks
|
||||
chunks = await db.query(
|
||||
chunks = await db.retrieve_chunks(
|
||||
"What are the key findings?",
|
||||
return_type="chunks",
|
||||
filters={"department": "research"}
|
||||
)
|
||||
```
|
||||
"""
|
||||
request = {
|
||||
"query": query,
|
||||
"filters": filters,
|
||||
"k": k,
|
||||
"min_score": min_score
|
||||
}
|
||||
|
||||
# Query for documents
|
||||
docs = await db.query(
|
||||
response = await self._request("POST", "retrieve/chunks", request)
|
||||
return [ChunkResult(**r) for r in response]
|
||||
|
||||
async def retrieve_docs(
    self,
    query: str,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0
) -> List[DocumentResult]:
    """
    Retrieve relevant documents.

    Posts the search parameters to the ``retrieve/docs`` endpoint and
    wraps every item of the response in a ``DocumentResult``.

    Args:
        query: Search query text
        filters: Optional metadata filters
        k: Number of results (default: 4)
        min_score: Minimum similarity threshold (default: 0.0)

    Returns:
        List[DocumentResult]

    Example:
        ```python
        docs = await db.retrieve_docs(
            "machine learning",
            k=5
        )
        ```
    """
    request = {
        "query": query,
        "filters": filters,
        "k": k,
        "min_score": min_score
    }

    response = await self._request("POST", "retrieve/docs", request)
    return [DocumentResult(**r) for r in response]
|
||||
|
||||
async def query(
    self,
    query: str,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0,
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
) -> CompletionResponse:
    """
    Generate a completion that uses relevant chunks as context.

    All arguments are forwarded unchanged (including ``None`` values) in a
    single POST to the ``query`` endpoint; the response mapping is unpacked
    into a ``CompletionResponse``.

    Args:
        query: Query text
        filters: Optional metadata filters
        k: Number of chunks to use as context (default: 4)
        min_score: Minimum similarity threshold (default: 0.0)
        max_tokens: Maximum tokens in completion
        temperature: Model temperature

    Returns:
        CompletionResponse

    Example:
        ```python
        response = await db.query(
            "What are the key findings about customer satisfaction?",
            filters={"department": "research"},
            temperature=0.7
        )
        print(response.completion)
        ```
    """
    payload = dict(
        query=query,
        filters=filters,
        k=k,
        min_score=min_score,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    completion_data = await self._request("POST", "query", payload)
    return CompletionResponse(**completion_data)
|
||||
|
||||
async def list_documents(
|
||||
self,
|
||||
skip: int = 0,
|
||||
@ -249,7 +316,7 @@ class AsyncDataBridge:
|
||||
) -> List[Document]:
|
||||
"""
|
||||
List accessible documents.
|
||||
|
||||
|
||||
Args:
|
||||
skip: Number of documents to skip
|
||||
limit: Maximum number of documents to return
|
||||
@ -257,12 +324,12 @@ class AsyncDataBridge:
|
||||
|
||||
Returns:
|
||||
List[Document]: List of accessible documents
|
||||
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Get first page
|
||||
docs = await db.list_documents(limit=10)
|
||||
|
||||
|
||||
# Get next page
|
||||
next_page = await db.list_documents(skip=10, limit=10, filters={"department": "research"})
|
||||
```
|
||||
@ -276,13 +343,13 @@ class AsyncDataBridge:
|
||||
async def get_document(self, document_id: str) -> Document:
|
||||
"""
|
||||
Get document metadata by ID.
|
||||
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
|
||||
|
||||
Returns:
|
||||
Document: Document metadata
|
||||
|
||||
|
||||
Example:
|
||||
```python
|
||||
doc = await db.get_document("doc_123")
|
||||
|
@ -53,3 +53,9 @@ class DocumentResult(BaseModel):
|
||||
document_id: str = Field(..., description="Document ID")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata")
|
||||
content: DocumentContent = Field(..., description="Document content or URL")
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
    """Completion response model"""
    # Fields are declared with Field(..., description=...) to match the
    # sibling DocumentResult model; both stay required, as before.
    completion: str = Field(..., description="Generated completion text")
    usage: Dict[str, int] = Field(..., description="Usage counters returned with the completion")
|
||||
|
@ -7,18 +7,18 @@ from urllib.parse import urlparse
|
||||
import jwt
|
||||
import requests
|
||||
|
||||
from .models import Document, IngestTextRequest, ChunkResult, DocumentResult
|
||||
from .models import Document, IngestTextRequest, ChunkResult, DocumentResult, CompletionResponse
|
||||
|
||||
|
||||
class DataBridge:
|
||||
"""
|
||||
DataBridge client for document operations.
|
||||
|
||||
|
||||
Args:
|
||||
uri (str): DataBridge URI in the format "databridge://<owner_id>:<token>@<host>"
|
||||
timeout (int, optional): Request timeout in seconds. Defaults to 30.
|
||||
is_local (bool, optional): Whether connecting to local development server. Defaults to False.
|
||||
|
||||
|
||||
Examples:
|
||||
```python
|
||||
with DataBridge("databridge://owner_id:token@api.databridge.ai") as db:
|
||||
@ -27,7 +27,7 @@ class DataBridge:
|
||||
"Sample content",
|
||||
metadata={"category": "sample"}
|
||||
)
|
||||
|
||||
|
||||
# Query documents
|
||||
results = db.query("search query")
|
||||
```
|
||||
@ -50,7 +50,7 @@ class DataBridge:
|
||||
# Split host and auth parts
|
||||
auth, host = parsed.netloc.split('@')
|
||||
self._owner_id, self._auth_token = auth.split(':')
|
||||
|
||||
|
||||
# Set base URL
|
||||
self._base_url = f"{'http' if self._is_local else 'https'}://{host}"
|
||||
|
||||
@ -66,7 +66,7 @@ class DataBridge:
|
||||
) -> Dict[str, Any]:
|
||||
"""Make authenticated HTTP request"""
|
||||
headers = {"Authorization": f"Bearer {self._auth_token}"}
|
||||
|
||||
|
||||
if not files:
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
@ -89,14 +89,14 @@ class DataBridge:
|
||||
) -> Document:
|
||||
"""
|
||||
Ingest a text document into DataBridge.
|
||||
|
||||
|
||||
Args:
|
||||
content: Text content to ingest
|
||||
metadata: Optional metadata dictionary
|
||||
|
||||
|
||||
Returns:
|
||||
Document: Metadata of the ingested document
|
||||
|
||||
|
||||
Example:
|
||||
```python
|
||||
doc = db.ingest_text(
|
||||
@ -129,16 +129,16 @@ class DataBridge:
|
||||
) -> Document:
|
||||
"""
|
||||
Ingest a file document into DataBridge.
|
||||
|
||||
|
||||
Args:
|
||||
file: File to ingest (path string, bytes, file object, or Path)
|
||||
filename: Name of the file
|
||||
content_type: MIME type (optional, will be guessed if not provided)
|
||||
metadata: Optional metadata dictionary
|
||||
|
||||
|
||||
Returns:
|
||||
Document: Metadata of the ingested document
|
||||
|
||||
|
||||
Example:
|
||||
```python
|
||||
# From file path
|
||||
@ -148,7 +148,7 @@ class DataBridge:
|
||||
content_type="application/pdf",
|
||||
metadata={"department": "research"}
|
||||
)
|
||||
|
||||
|
||||
# From file object
|
||||
with open("document.pdf", "rb") as f:
|
||||
doc = db.ingest_file(f, "document.pdf")
|
||||
@ -188,58 +188,125 @@ class DataBridge:
|
||||
if isinstance(file, (str, Path)):
|
||||
file_obj.close()
|
||||
|
||||
def retrieve_chunks(
    self,
    query: str,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0
) -> List[ChunkResult]:
    """
    Retrieve relevant chunks.

    Posts the search parameters to the ``retrieve/chunks`` endpoint and
    wraps every item of the response in a ``ChunkResult``.

    Args:
        query: Search query text
        filters: Optional metadata filters
        k: Number of results (default: 4)
        min_score: Minimum similarity threshold (default: 0.0)

    Returns:
        List[ChunkResult]

    Example:
        ```python
        chunks = db.retrieve_chunks(
            "What are the key findings?",
            filters={"department": "research"}
        )
        ```
    """
    request = {
        "query": query,
        "filters": filters,
        "k": k,
        "min_score": min_score
    }

    # Use the "retrieve/..." route family, matching retrieve_docs and the
    # async client's retrieve_chunks (previously "search/chunks").
    response = self._request("POST", "retrieve/chunks", request)
    return [ChunkResult(**r) for r in response]
|
||||
|
||||
def retrieve_docs(
    self,
    query: str,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0
) -> List[DocumentResult]:
    """
    Retrieve relevant documents.

    Posts the search parameters to the ``retrieve/docs`` endpoint and
    wraps every item of the response in a ``DocumentResult``.

    Args:
        query: Search query text
        filters: Optional metadata filters
        k: Number of results (default: 4)
        min_score: Minimum similarity threshold (default: 0.0)

    Returns:
        List[DocumentResult]

    Example:
        ```python
        docs = db.retrieve_docs(
            "machine learning",
            k=5
        )
        ```
    """
    request = {
        "query": query,
        "filters": filters,
        "k": k,
        "min_score": min_score
    }

    response = self._request("POST", "retrieve/docs", request)
    return [DocumentResult(**r) for r in response]
|
||||
|
||||
def query(
    self,
    query: str,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 4,
    min_score: float = 0.0,
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
) -> CompletionResponse:
    """
    Generate a completion that uses relevant chunks as context.

    All arguments are forwarded unchanged (including ``None`` values) in a
    single POST to the ``query`` endpoint; the response mapping is unpacked
    into a ``CompletionResponse``.

    Args:
        query: Query text
        filters: Optional metadata filters
        k: Number of chunks to use as context (default: 4)
        min_score: Minimum similarity threshold (default: 0.0)
        max_tokens: Maximum tokens in completion
        temperature: Model temperature

    Returns:
        CompletionResponse

    Example:
        ```python
        response = db.query(
            "What are the key findings about customer satisfaction?",
            filters={"department": "research"},
            temperature=0.7
        )
        print(response.completion)
        ```
    """
    payload = dict(
        query=query,
        filters=filters,
        k=k,
        min_score=min_score,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    completion_data = self._request("POST", "query", payload)
    return CompletionResponse(**completion_data)
|
||||
|
||||
def list_documents(
|
||||
self,
|
||||
skip: int = 0,
|
||||
@ -248,7 +315,7 @@ class DataBridge:
|
||||
) -> List[Document]:
|
||||
"""
|
||||
List accessible documents.
|
||||
|
||||
|
||||
Args:
|
||||
skip: Number of documents to skip
|
||||
limit: Maximum number of documents to return
|
||||
@ -256,12 +323,12 @@ class DataBridge:
|
||||
|
||||
Returns:
|
||||
List[Document]: List of accessible documents
|
||||
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Get first page
|
||||
docs = db.list_documents(limit=10)
|
||||
|
||||
|
||||
# Get next page
|
||||
next_page = db.list_documents(skip=10, limit=10, filters={"department": "research"})
|
||||
```
|
||||
@ -275,13 +342,13 @@ class DataBridge:
|
||||
def get_document(self, document_id: str) -> Document:
|
||||
"""
|
||||
Get document metadata by ID.
|
||||
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
|
||||
|
||||
Returns:
|
||||
Document: Document metadata
|
||||
|
||||
|
||||
Example:
|
||||
```python
|
||||
doc = db.get_document("doc_123")
|
||||
|
@ -15,6 +15,7 @@ dependencies = [
|
||||
"httpx>=0.24.0",
|
||||
"pyjwt>=2.0.0",
|
||||
"pydantic==2.10.3",
|
||||
"requests>=2.32.3",
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
|
Loading…
x
Reference in New Issue
Block a user