mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
Add open telemetry and shell (#5)
parent 3ad55129b7
commit 3e4a9999ad

135 core/api.py
@@ -5,6 +5,7 @@ from fastapi import FastAPI, Form, HTTPException, Depends, Header, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 import jwt
 import logging
+from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
 from core.completion.openai_completion import OpenAICompletionModel
 from core.embedding.ollama_embedding_model import OllamaEmbeddingModel
 from core.models.request import (
@@ -18,6 +19,7 @@ from core.parser.combined_parser import CombinedParser
 from core.completion.base_completion import CompletionResponse
 from core.parser.unstructured_parser import UnstructuredAPIParser
 from core.services.document_service import DocumentService
+from core.services.telemetry import TelemetryService
 from core.config import get_settings
 from core.database.mongo_database import MongoDatabase
 from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
@@ -29,6 +31,12 @@ from core.completion.ollama_completion import OllamaCompletionModel
 app = FastAPI(title="DataBridge API")
 logger = logging.getLogger(__name__)
 
+# Initialize telemetry
+telemetry = TelemetryService()
+
+# Add OpenTelemetry instrumentation
+FastAPIInstrumentor.instrument_app(app)
+
 # Add CORS middleware
 app.add_middleware(
     CORSMiddleware,
@@ -88,7 +96,7 @@ match settings.PARSER_PROVIDER:
         )
     case "unstructured":
         parser = UnstructuredAPIParser(
-            unstructured_api_key=settings.UNSTRUCTURED_API_KEY,
+            api_key=settings.UNSTRUCTURED_API_KEY,
             chunk_size=settings.CHUNK_SIZE,
             chunk_overlap=settings.CHUNK_OVERLAP,
         )
@@ -169,7 +177,13 @@ async def ingest_text(
 ) -> Document:
     """Ingest a text document."""
     try:
-        return await document_service.ingest_text(request, auth)
+        async with telemetry.track_operation(
+            operation_type="ingest_text",
+            user_id=auth.entity_id,
+            tokens_used=len(request.content.split()),  # Approximate token count
+            metadata=request.metadata.model_dump() if request.metadata else None,
+        ):
+            return await document_service.ingest_text(request, auth)
     except PermissionError as e:
         raise HTTPException(status_code=403, detail=str(e))
 
@@ -177,14 +191,23 @@ async def ingest_text(
 @app.post("/ingest/file", response_model=Document)
 async def ingest_file(
     file: UploadFile,
-    metadata: str = Form("{}"),  # JSON string of metadata
+    metadata: str = Form("{}"),
     auth: AuthContext = Depends(verify_token),
 ) -> Document:
     """Ingest a file document."""
     try:
         metadata_dict = json.loads(metadata)
-        doc = await document_service.ingest_file(file, metadata_dict, auth)
-        return doc  # TODO: Might be lighter on network to just send the document ID.
+        async with telemetry.track_operation(
+            operation_type="ingest_file",
+            user_id=auth.entity_id,
+            metadata={
+                "filename": file.filename,
+                "content_type": file.content_type,
+                "metadata": metadata_dict,
+            },
+        ):
+            doc = await document_service.ingest_file(file, metadata_dict, auth)
+            return doc
     except PermissionError as e:
         raise HTTPException(status_code=403, detail=str(e))
     except json.JSONDecodeError:
@@ -194,17 +217,27 @@ async def ingest_file(
 @app.post("/retrieve/chunks", response_model=List[ChunkResult])
 async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
     """Retrieve relevant chunks."""
-    return await document_service.retrieve_chunks(
-        request.query, auth, request.filters, request.k, request.min_score
-    )
+    async with telemetry.track_operation(
+        operation_type="retrieve_chunks",
+        user_id=auth.entity_id,
+        metadata=request.model_dump(),
+    ):
+        return await document_service.retrieve_chunks(
+            request.query, auth, request.filters, request.k, request.min_score
+        )
 
 
 @app.post("/retrieve/docs", response_model=List[DocumentResult])
 async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
     """Retrieve relevant documents."""
-    return await document_service.retrieve_docs(
-        request.query, auth, request.filters, request.k, request.min_score
-    )
+    async with telemetry.track_operation(
+        operation_type="retrieve_docs",
+        user_id=auth.entity_id,
+        metadata=request.model_dump(),
+    ):
+        return await document_service.retrieve_docs(
+            request.query, auth, request.filters, request.k, request.min_score
+        )
 
 
 @app.post("/query", response_model=CompletionResponse)
@@ -212,15 +245,27 @@ async def query_completion(
     request: CompletionQueryRequest, auth: AuthContext = Depends(verify_token)
 ):
     """Generate completion using relevant chunks as context."""
-    return await document_service.query(
-        request.query,
-        auth,
-        request.filters,
-        request.k,
-        request.min_score,
-        request.max_tokens,
-        request.temperature,
-    )
+    async with telemetry.track_operation(
+        operation_type="query",
+        user_id=auth.entity_id,
+        metadata=request.model_dump(),
+    ) as span:
+        response = await document_service.query(
+            request.query,
+            auth,
+            request.filters,
+            request.k,
+            request.min_score,
+            request.max_tokens,
+            request.temperature,
+        )
+        if isinstance(response, dict) and "usage" in response:
+            usage = response["usage"]
+            if isinstance(usage, dict):
+                span.set_attribute("tokens.completion", usage.get("completion_tokens", 0))
+                span.set_attribute("tokens.prompt", usage.get("prompt_tokens", 0))
+                span.set_attribute("tokens.total", usage.get("total_tokens", 0))
+        return response
 
 
 @app.get("/documents", response_model=List[Document])
@@ -246,3 +291,53 @@ async def get_document(document_id: str, auth: AuthContext = Depends(verify_token)):
     except HTTPException as e:
         logger.error(f"Error getting document: {e}")
         raise e
+
+
+# Usage tracking endpoints
+@app.get("/usage/stats")
+async def get_usage_stats(auth: AuthContext = Depends(verify_token)) -> Dict[str, int]:
+    """Get usage statistics for the authenticated user."""
+    async with telemetry.track_operation(operation_type="get_usage_stats", user_id=auth.entity_id):
+        if not auth.permissions or "admin" not in auth.permissions:
+            return telemetry.get_user_usage(auth.entity_id)
+        return telemetry.get_user_usage(auth.entity_id)
+
+
+@app.get("/usage/recent")
+async def get_recent_usage(
+    auth: AuthContext = Depends(verify_token),
+    operation_type: Optional[str] = None,
+    since: Optional[datetime] = None,
+    status: Optional[str] = None,
+) -> List[Dict]:
+    """Get recent usage records."""
+    async with telemetry.track_operation(
+        operation_type="get_recent_usage",
+        user_id=auth.entity_id,
+        metadata={
+            "operation_type": operation_type,
+            "since": since.isoformat() if since else None,
+            "status": status,
+        },
+    ):
+        if not auth.permissions or "admin" not in auth.permissions:
+            records = telemetry.get_recent_usage(
+                user_id=auth.entity_id, operation_type=operation_type, since=since, status=status
+            )
+        else:
+            records = telemetry.get_recent_usage(
+                operation_type=operation_type, since=since, status=status
+            )
+
+        return [
+            {
+                "timestamp": record.timestamp,
+                "operation_type": record.operation_type,
+                "tokens_used": record.tokens_used,
+                "user_id": record.user_id,
+                "duration_ms": record.duration_ms,
+                "status": record.status,
+                "metadata": record.metadata,
+            }
+            for record in records
+        ]
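For reference, a minimal sketch of how the new usage endpoints might be exercised once this change is deployed; the base URL and token below are assumptions, not part of the commit.

# Hypothetical client calls against the new endpoints (base URL and JWT assumed).
import requests

BASE = "http://localhost:8000"  # assumed local dev server
HEADERS = {"Authorization": "Bearer <your-jwt>"}  # placeholder token

# Per-user token totals, keyed by operation type.
stats = requests.get(f"{BASE}/usage/stats", headers=HEADERS).json()
print(stats)

# Recent usage records, optionally filtered via query parameters.
recent = requests.get(
    f"{BASE}/usage/recent",
    params={"operation_type": "query", "status": "success"},
    headers=HEADERS,
).json()
for rec in recent:
    print(rec["timestamp"], rec["operation_type"], rec["duration_ms"], rec["status"])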
298 core/services/telemetry.py Normal file
@@ -0,0 +1,298 @@
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
import threading
from collections import defaultdict
import time
from contextlib import asynccontextmanager
import os
import json
from pathlib import Path

from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.resources import Resource
from opentelemetry.trace import Status, StatusCode
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics.export import (
    PeriodicExportingMetricReader,
    MetricExporter,
    AggregationTemporality,
    MetricsData,
)
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter


class FileSpanExporter:
    def __init__(self, log_dir: str):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.trace_file = self.log_dir / "traces.log"

    def export(self, spans):
        with open(self.trace_file, "a") as f:
            for span in spans:
                f.write(json.dumps(self._format_span(span)) + "\n")
        return True

    def shutdown(self):
        pass

    def _format_span(self, span):
        return {
            "name": span.name,
            "trace_id": format(span.context.trace_id, "x"),
            "span_id": format(span.context.span_id, "x"),
            "parent_id": format(span.parent.span_id, "x") if span.parent else None,
            "start_time": span.start_time,
            "end_time": span.end_time,
            "attributes": dict(span.attributes),
            "status": span.status.status_code.name,
        }


class FileMetricExporter(MetricExporter):
    """File metric exporter for OpenTelemetry."""

    def __init__(self, log_dir: str):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.metrics_file = self.log_dir / "metrics.log"
        super().__init__()

    def export(self, metrics_data: MetricsData, **kwargs) -> bool:
        """Export metrics data to a file.

        Args:
            metrics_data: The metrics data to export.

        Returns:
            True if the export was successful, False otherwise.
        """
        try:
            with open(self.metrics_file, "a") as f:
                for resource_metrics in metrics_data.resource_metrics:
                    for scope_metrics in resource_metrics.scope_metrics:
                        for metric in scope_metrics.metrics:
                            f.write(json.dumps(self._format_metric(metric)) + "\n")
            return True
        except Exception:
            return False

    def shutdown(self, timeout_millis: float = 30_000, **kwargs) -> bool:
        """Shuts down the exporter.

        Args:
            timeout_millis: Time to wait for the export to complete in milliseconds.

        Returns:
            True if the shutdown succeeded, False otherwise.
        """
        return True

    def force_flush(self, timeout_millis: float = 10_000) -> bool:
        """Force flush the exporter.

        Args:
            timeout_millis: Time to wait for the flush to complete in milliseconds.

        Returns:
            True if the flush succeeded, False otherwise.
        """
        return True

    def _preferred_temporality(self) -> Dict:
        """Returns the preferred temporality for each instrument kind."""
        return {
            "counter": AggregationTemporality.CUMULATIVE,
            "up_down_counter": AggregationTemporality.CUMULATIVE,
            "observable_counter": AggregationTemporality.CUMULATIVE,
            "observable_up_down_counter": AggregationTemporality.CUMULATIVE,
            "histogram": AggregationTemporality.CUMULATIVE,
            "observable_gauge": AggregationTemporality.CUMULATIVE,
        }

    def _format_metric(self, metric):
        return {
            "name": metric.name,
            "description": metric.description,
            "unit": metric.unit,
            "data": self._format_data(metric.data),
        }

    def _format_data(self, data):
        if hasattr(data, "data_points"):
            return {
                "data_points": [
                    {
                        "attributes": dict(point.attributes),
                        "value": point.value if hasattr(point, "value") else None,
                        "count": point.count if hasattr(point, "count") else None,
                        "sum": point.sum if hasattr(point, "sum") else None,
                        "timestamp": point.time_unix_nano,
                    }
                    for point in data.data_points
                ]
            }
        return {}


@dataclass
class UsageRecord:
    timestamp: datetime
    operation_type: str
    tokens_used: int
    user_id: str
    duration_ms: float
    status: str
    metadata: Optional[Dict] = None


class TelemetryService:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        self._usage_records: List[UsageRecord] = []
        self._user_totals = defaultdict(lambda: defaultdict(int))
        self._lock = threading.Lock()

        # Initialize OpenTelemetry
        resource = Resource.create({"service.name": "databridge-core"})

        # Create logs directory
        log_dir = Path("logs/telemetry")
        log_dir.mkdir(parents=True, exist_ok=True)

        # Initialize tracing
        tracer_provider = TracerProvider(resource=resource)

        # Use file exporter for local development
        if os.getenv("ENVIRONMENT", "development") == "development":
            span_processor = BatchSpanProcessor(FileSpanExporter(str(log_dir)))
        else:
            span_processor = BatchSpanProcessor(OTLPSpanExporter())

        tracer_provider.add_span_processor(span_processor)
        trace.set_tracer_provider(tracer_provider)
        self.tracer = trace.get_tracer(__name__)

        # Initialize metrics
        if os.getenv("ENVIRONMENT", "development") == "development":
            metric_reader = PeriodicExportingMetricReader(
                FileMetricExporter(str(log_dir)),
                export_interval_millis=60000,  # Export every minute
            )
        else:
            metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())

        meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
        metrics.set_meter_provider(meter_provider)
        self.meter = metrics.get_meter(__name__)

        # Create metrics
        self.operation_counter = self.meter.create_counter(
            "databridge.operations",
            description="Number of operations performed",
        )
        self.token_counter = self.meter.create_counter(
            "databridge.tokens",
            description="Number of tokens processed",
        )
        self.operation_duration = self.meter.create_histogram(
            "databridge.operation.duration",
            description="Duration of operations",
            unit="ms",
        )

    @asynccontextmanager
    async def track_operation(
        self,
        operation_type: str,
        user_id: str,
        tokens_used: int = 0,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """
        Context manager for tracking operations with both usage metrics and OpenTelemetry
        """
        start_time = time.time()
        status = "success"
        current_span = trace.get_current_span()

        try:
            # Add operation attributes to the current span
            current_span.set_attribute("operation.type", operation_type)
            current_span.set_attribute("user.id", user_id)
            if metadata:
                for key, value in metadata.items():
                    current_span.set_attribute(f"metadata.{key}", str(value))

            yield current_span

        except Exception as e:
            status = "error"
            current_span.set_status(Status(StatusCode.ERROR))
            current_span.record_exception(e)
            raise
        finally:
            duration = (time.time() - start_time) * 1000  # Convert to milliseconds

            # Record metrics
            self.operation_counter.add(1, {"operation": operation_type, "status": status})
            if tokens_used > 0:
                self.token_counter.add(tokens_used, {"operation": operation_type})
            self.operation_duration.record(duration, {"operation": operation_type})

            # Record usage
            record = UsageRecord(
                timestamp=datetime.now(),
                operation_type=operation_type,
                tokens_used=tokens_used,
                user_id=user_id,
                duration_ms=duration,
                status=status,
                metadata=metadata,
            )

            with self._lock:
                self._usage_records.append(record)
                self._user_totals[user_id][operation_type] += tokens_used

    def get_user_usage(self, user_id: str) -> Dict[str, int]:
        """Get usage statistics for a user."""
        with self._lock:
            return dict(self._user_totals[user_id])

    def get_recent_usage(
        self,
        user_id: Optional[str] = None,
        operation_type: Optional[str] = None,
        since: Optional[datetime] = None,
        status: Optional[str] = None,
    ) -> List[UsageRecord]:
        """Get recent usage records with optional filtering."""
        with self._lock:
            records = self._usage_records.copy()

        # Apply filters
        if user_id:
            records = [r for r in records if r.user_id == user_id]
        if operation_type:
            records = [r for r in records if r.operation_type == operation_type]
        if since:
            records = [r for r in records if r.timestamp >= since]
        if status:
            records = [r for r in records if r.status == status]

        return records
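A minimal usage sketch of the service above, assuming it runs inside an asyncio event loop; the operation name, user ID, and attribute values are illustrative.

# Hypothetical direct use of TelemetryService (names are illustrative).
import asyncio

from core.services.telemetry import TelemetryService


async def main():
    telemetry = TelemetryService()  # singleton: repeated constructions return one instance

    # Wrap any awaitable work; metrics, a usage record, and span attributes
    # are emitted even if the body raises.
    async with telemetry.track_operation(
        operation_type="demo_operation",
        user_id="user-123",
        tokens_used=42,
        metadata={"source": "example"},
    ) as span:
        span.set_attribute("custom.attribute", "value")

    print(telemetry.get_user_usage("user-123"))  # expected: {"demo_operation": 42}


asyncio.run(main())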
@@ -243,3 +243,7 @@ xlrd==2.0.1
 XlsxWriter==3.2.0
 xxhash==3.4.1
 yarl==1.9.4
+opentelemetry-api>=1.21.0
+opentelemetry-sdk>=1.21.0
+opentelemetry-instrumentation-fastapi>=0.42b0
+opentelemetry-exporter-otlp>=1.21.0
@@ -205,7 +205,7 @@ class DataBridge:
         """
         request = {"query": query, "filters": filters, "k": k, "min_score": min_score}
 
-        response = self._request("POST", "search/chunks", request)
+        response = self._request("POST", "retrieve/chunks", request)
         return [ChunkResult(**r) for r in response]
 
     def retrieve_docs(
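The rename above only changes the HTTP path; the Python-level call is unchanged. A hedged sketch, reusing the constructor arguments shown in shell.py below:

# Hypothetical SDK call, now routed to POST /retrieve/chunks.
from databridge import DataBridge

db = DataBridge("http://test_user:token@localhost:8000", is_local=True)
chunks = db.retrieve_chunks("what are the key findings?", k=4, min_score=0.0)
db.close()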
126 shell.py Normal file
@@ -0,0 +1,126 @@
#!/usr/bin/env python3
"""
DataBridge interactive CLI.
Assumes a DataBridge server is running.

Usage:
    python shell.py <uri>
    Example: python shell.py "http://test_user:token@localhost:8000"

This provides the exact same interface as the Python SDK:
    db.ingest_text("content", metadata={...})
    db.ingest_file("path/to/file")
    db.query("what are the key findings?")
    etc...
"""

import sys
from pathlib import Path

# Add local SDK to path before other imports
_SDK_PATH = str(Path(__file__).parent / "sdks" / "python")
if _SDK_PATH not in sys.path:
    sys.path.insert(0, _SDK_PATH)

from databridge import DataBridge  # noqa: E402


class DB:
    def __init__(self, uri: str):
        """Initialize DataBridge with URI"""
        # Convert databridge:// to http:// for localhost
        if "localhost" in uri or "127.0.0.1" in uri:
            uri = uri.replace("databridge://", "http://")
        self.uri = uri
        self._client = DataBridge(self.uri, is_local="localhost" in uri or "127.0.0.1" in uri)

    def ingest_text(self, content: str, metadata: dict = None) -> dict:
        """Ingest text content into DataBridge"""
        doc = self._client.ingest_text(content, metadata=metadata or {})
        return doc.model_dump()

    def ingest_file(self, file_path: str, metadata: dict = None, content_type: str = None) -> dict:
        """Ingest a file into DataBridge"""
        file_path = Path(file_path)
        doc = self._client.ingest_file(
            file_path, filename=file_path.name, content_type=content_type, metadata=metadata or {}
        )
        return doc.model_dump()

    def retrieve_chunks(
        self, query: str, filters: dict = None, k: int = 4, min_score: float = 0.0
    ) -> list:
        """Search for relevant chunks"""
        results = self._client.retrieve_chunks(
            query, filters=filters or {}, k=k, min_score=min_score
        )
        return [r.model_dump() for r in results]

    def retrieve_docs(
        self, query: str, filters: dict = None, k: int = 4, min_score: float = 0.0
    ) -> list:
        """Retrieve relevant documents"""
        results = self._client.retrieve_docs(query, filters=filters or {}, k=k, min_score=min_score)
        return [r.model_dump() for r in results]

    def query(
        self,
        query: str,
        filters: dict = None,
        k: int = 4,
        min_score: float = 0.0,
        max_tokens: int = None,
        temperature: float = None,
    ) -> dict:
        """Generate completion using relevant chunks as context"""
        response = self._client.query(
            query,
            filters=filters or {},
            k=k,
            min_score=min_score,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return response.model_dump()

    def list_documents(self, skip: int = 0, limit: int = 100, filters: dict = None) -> list:
        """List accessible documents"""
        docs = self._client.list_documents(skip=skip, limit=limit, filters=filters or {})
        return [doc.model_dump() for doc in docs]

    def get_document(self, document_id: str) -> dict:
        """Get document metadata by ID"""
        doc = self._client.get_document(document_id)
        return doc.model_dump()

    def close(self):
        """Close the client connection"""
        self._client.close()


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Error: URI argument required")
        print(__doc__)
        sys.exit(1)

    # Create DB instance with provided URI
    db = DB(sys.argv[1])

    # Start an interactive Python shell with 'db' already imported
    import code
    import readline  # Enable arrow key history
    import rlcompleter  # noqa: F401  # Enable tab completion

    readline.parse_and_bind("tab: complete")

    # Create the interactive shell
    shell = code.InteractiveConsole(locals())

    # Print welcome message
    print("\nDataBridge CLI ready to use. The 'db' object is available with all SDK methods.")
    print("Example: db.ingest_text('hello world')")
    print("Type help(db) for documentation.")

    # Start the shell
    shell.interact(banner="")
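An illustrative session with the shell above; the URI and typed commands are examples, not captured output.

$ python shell.py "http://test_user:token@localhost:8000"

DataBridge CLI ready to use. The 'db' object is available with all SDK methods.
Example: db.ingest_text('hello world')
Type help(db) for documentation.
>>> doc = db.ingest_text("hello world", metadata={"topic": "demo"})
>>> db.retrieve_chunks("hello", k=2)
>>> db.close()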