2024-11-16 01:48:15 -05:00

199 lines
5.9 KiB
Python

from typing import Dict, Any, List, Optional, Union
import httpx
from urllib.parse import urlparse
import jwt
from datetime import datetime, UTC
import asyncio
from dataclasses import dataclass
from .exceptions import AuthenticationError
from .types import ContentType
import logging
logger = logging.getLogger(__name__)
@dataclass
class QueryResult:
"""Structured query result"""
content: str
doc_id: str
chunk_id: str
score: Optional[float]
metadata: Dict[str, Any]
class DataBridge:
"""
DataBridge client for document ingestion and querying.
Usage:
db = DataBridge("databridge://owner123:token@databridge.local")
doc_id = await db.ingest_document("content", {"title": "My Doc"})
results = await db.query("What is...")
"""
def __init__(
self,
uri: str,
base_url: str = "https://api.databridge.ai/v1",
timeout: int = 30,
max_retries: int = 3
):
self._base_url = base_url.rstrip('/')
self._timeout = timeout
self._max_retries = max_retries
self._client = httpx.AsyncClient(timeout=timeout)
self._setup_auth(uri)
def _setup_auth(self, uri: str) -> None:
"""Setup authentication from URI"""
try:
parsed = urlparse(uri)
if not parsed.netloc:
raise ValueError("Invalid URI format")
auth_parts = parsed.netloc.split('@')[0].split(':')
if len(auth_parts) != 2:
raise ValueError("URI must include owner_id and auth_token")
self._owner_id = auth_parts[0]
self._auth_token = auth_parts[1]
# Validate token structure (not signature)
try:
decoded = jwt.decode(self._auth_token, options={"verify_signature": False})
self._token_expiry = datetime.fromtimestamp(decoded['exp'], UTC)
except jwt.InvalidTokenError as e:
raise ValueError(f"Invalid auth token format: {str(e)}")
except Exception as e:
raise AuthenticationError(f"Failed to setup authentication: {str(e)}")
async def _make_request(
self,
method: str,
endpoint: str,
data: Dict[str, Any] = None,
retry_count: int = 0
) -> Dict[str, Any]:
"""Make authenticated HTTP request with retries"""
# if datetime.now(UTC) > self._token_expiry:
# raise AuthenticationError("Authentication token has expired")
headers = {
"X-Owner-ID": self._owner_id,
"X-Auth-Token": self._auth_token,
"Content-Type": "application/json"
}
try:
response = await self._client.request(
method,
f"http://localhost:8000/{endpoint.lstrip('/')}",
json=data,
headers=headers
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError("Authentication failed: " + str(e))
elif e.response.status_code >= 500 and retry_count < self._max_retries:
await asyncio.sleep(2 ** retry_count) # Exponential backoff
return await self._make_request(method, endpoint, data, retry_count + 1)
else:
raise ConnectionError(f"Request failed: {e.response.text}")
except Exception as e:
raise ConnectionError(f"Request failed: {str(e)}")
async def ingest_document(
self,
content: Union[str, bytes],
metadata: Optional[Dict[str, Any]] = None,
content_type: ContentType = ContentType.TEXT
) -> str:
"""
Ingest a document into DataBridge.
Args:
content: Document content (string or bytes)
metadata: Optional document metadata
content_type: Type of the content being ingested
Returns:
Document ID of the ingested document
"""
if isinstance(content, bytes):
import base64
content = base64.b64encode(content).decode()
metadata = metadata or {}
metadata["is_base64"] = True
metadata = metadata or {}
metadata["content_type"] = content_type
response = await self._make_request(
"POST",
"ingest",
{
"content": content,
"metadata": metadata
}
)
return response["document_id"]
async def query(
self,
query: str,
k: int = 4,
filters: Optional[Dict[str, Any]] = None
) -> List[QueryResult]:
"""
Query documents in DataBridge.
Args:
query: Query string
k: Number of results to return
filters: Optional metadata filters
Returns:
List of QueryResult objects
"""
response = await self._make_request(
"POST",
"query",
{
"query": query,
"k": k,
"filters": filters
}
)
return [
QueryResult(
content=result["content"],
doc_id=result["doc_id"],
chunk_id=result["chunk_id"],
score=result.get("score"),
metadata=result.get("metadata", {})
)
for result in response["results"]
]
async def close(self):
"""Close the HTTP client"""
await self._client.aclose()
async def __aenter__(self):
"""Async context manager entry"""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit"""
await self.close()
def __repr__(self) -> str:
"""Safe string representation"""
return f"DataBridge(owner_id='{self._owner_id}')"