228 lines
6.9 KiB
Python
Raw Normal View History

from typing import Dict, Any, List, Optional, Union
import httpx
from urllib.parse import urlparse
import jwt
from datetime import datetime, UTC
import asyncio
from dataclasses import dataclass
from .exceptions import AuthenticationError
2024-11-18 18:41:23 -05:00
from pydantic import BaseModel, Field
import logging
logger = logging.getLogger(__name__)
@dataclass
class QueryResult:
"""Structured query result"""
content: str
doc_id: str
score: Optional[float]
metadata: Dict[str, Any]
2024-11-18 18:41:23 -05:00
# Request/Response Models
class Document(BaseModel):
id: str
name: str
type: str
source: str
uploaded_at: str
size: str
redaction_level: str
stats: Dict[str, Union[int, str]] = Field(
default_factory=lambda: {
"ai_queries": 0,
"time_saved": "0h",
"last_accessed": ""
}
)
accessed_by: List[Dict[str, str]] = Field(default_factory=list)
sensitive_content: Optional[List[str]] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
s3_bucket: Optional[str] = None
s3_key: Optional[str] = None
class DataBridge:
"""
DataBridge client for document ingestion and querying.
Usage:
db = DataBridge("databridge://owner123:token@databridge.local")
doc_id = await db.ingest_document("content", {"title": "My Doc"})
results = await db.query("What is...")
"""
def __init__(
self,
uri: str,
timeout: int = 30,
max_retries: int = 3
):
self._timeout = timeout
self._max_retries = max_retries
self._client = httpx.AsyncClient(timeout=timeout)
self._setup_auth(uri)
def _setup_auth(self, uri: str) -> None:
"""Setup authentication from URI"""
try:
parsed = urlparse(uri)
if not parsed.netloc:
raise ValueError("Invalid URI format")
2024-11-17 15:54:27 -05:00
split_uri = parsed.netloc.split('@')
self._base_url = f"{"http" if "localhost" in split_uri[1] else "https"}://{split_uri[1]}"
auth_parts = split_uri[0].split(':')
if len(auth_parts) != 2:
raise ValueError("URI must include owner_id and auth_token")
if '.' in auth_parts[0]:
self._owner_id = auth_parts[0] # dev_id.app_id format
self._auth_token = auth_parts[1]
else:
self._owner_id = auth_parts[0] # eu_id format
self._auth_token = auth_parts[1]
# Validate token structure (not signature)
try:
decoded = jwt.decode(self._auth_token, options={"verify_signature": False})
self._token_expiry = datetime.fromtimestamp(decoded['exp'], UTC)
except jwt.InvalidTokenError as e:
raise ValueError(f"Invalid auth token format: {str(e)}")
except Exception as e:
raise AuthenticationError(f"Failed to setup authentication: {str(e)}")
async def _make_request(
self,
method: str,
endpoint: str,
data: Dict[str, Any] = None,
retry_count: int = 0
) -> Dict[str, Any]:
"""Make authenticated HTTP request with retries"""
# if datetime.now(UTC) > self._token_expiry:
# raise AuthenticationError("Authentication token has expired")
headers = {
"X-Owner-ID": self._owner_id,
"X-Auth-Token": self._auth_token,
"Content-Type": "application/json"
}
try:
response = await self._client.request(
method,
2024-11-17 15:54:27 -05:00
f"{self._base_url}/{endpoint.lstrip('/')}",
json=data,
headers=headers
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError("Authentication failed: " + str(e))
elif e.response.status_code >= 500 and retry_count < self._max_retries:
await asyncio.sleep(2 ** retry_count) # Exponential backoff
return await self._make_request(method, endpoint, data, retry_count + 1)
else:
raise ConnectionError(f"Request failed: {e.response.text}")
except Exception as e:
raise ConnectionError(f"Request failed: {str(e)}")
async def ingest_document(
self,
content: Union[str, bytes],
2024-11-18 18:41:23 -05:00
metadata: Optional[Dict[str, Any]] = None,
filename: Optional[str] = None
) -> str:
"""
Ingest a document into DataBridge.
Args:
content: Document content (string or bytes)
metadata: Optional document metadata
2024-11-18 18:41:23 -05:00
filename: Optional filename - defaults to doc_id if not provided
Returns:
Document ID of the ingested document
"""
2024-11-18 18:41:23 -05:00
metadata = metadata or {}
if filename:
metadata["filename"] = filename
if isinstance(content, bytes):
import base64
content = base64.b64encode(content).decode()
2024-11-18 18:41:23 -05:00
metadata = metadata
metadata["is_base64"] = True
response = await self._make_request(
"POST",
"ingest",
{
"content": content,
"metadata": metadata
}
)
return response["document_id"]
async def query(
self,
query: str,
k: int = 4,
filters: Optional[Dict[str, Any]] = None
) -> List[QueryResult]:
"""
Query documents in DataBridge.
Args:
query: Query string
k: Number of results to return
filters: Optional metadata filters
Returns:
List of QueryResult objects
"""
response = await self._make_request(
"POST",
"query",
{
"query": query,
"k": k,
"filters": filters
}
)
return [
QueryResult(
content=result["content"],
doc_id=result["doc_id"],
score=result.get("score"),
metadata=result.get("metadata", {})
)
for result in response["results"]
]
2024-11-18 18:41:23 -05:00
async def get_documents(self) -> List[Document]:
"""Get all documents"""
response = await self._make_request("GET", "documents")
return [Document(**doc) for doc in response]
async def close(self):
"""Close the HTTP client"""
await self._client.aclose()
async def __aenter__(self):
"""Async context manager entry"""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit"""
await self.close()
def __repr__(self) -> str:
"""Safe string representation"""
return f"DataBridge(owner_id='{self._owner_id}')"