Add OpenTelemetry and shell (#5)

Authored by Adityavardhan Agrawal on 2024-12-31 10:22:25 +05:30; committed by GitHub
parent 3ad55129b7
commit 3e4a9999ad
5 changed files with 544 additions and 21 deletions


@@ -5,6 +5,7 @@ from fastapi import FastAPI, Form, HTTPException, Depends, Header, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import jwt
import logging
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from core.completion.openai_completion import OpenAICompletionModel
from core.embedding.ollama_embedding_model import OllamaEmbeddingModel
from core.models.request import (
@@ -18,6 +19,7 @@ from core.parser.combined_parser import CombinedParser
from core.completion.base_completion import CompletionResponse
from core.parser.unstructured_parser import UnstructuredAPIParser
from core.services.document_service import DocumentService
from core.services.telemetry import TelemetryService
from core.config import get_settings
from core.database.mongo_database import MongoDatabase
from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
@@ -29,6 +31,12 @@ from core.completion.ollama_completion import OllamaCompletionModel
app = FastAPI(title="DataBridge API")
logger = logging.getLogger(__name__)
# Initialize telemetry
telemetry = TelemetryService()
# Add OpenTelemetry instrumentation
FastAPIInstrumentor.instrument_app(app)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
@@ -88,7 +96,7 @@ match settings.PARSER_PROVIDER:
)
case "unstructured":
parser = UnstructuredAPIParser(
unstructured_api_key=settings.UNSTRUCTURED_API_KEY,
api_key=settings.UNSTRUCTURED_API_KEY,
chunk_size=settings.CHUNK_SIZE,
chunk_overlap=settings.CHUNK_OVERLAP,
)
@@ -169,7 +177,13 @@ async def ingest_text(
) -> Document:
"""Ingest a text document."""
try:
return await document_service.ingest_text(request, auth)
async with telemetry.track_operation(
operation_type="ingest_text",
user_id=auth.entity_id,
tokens_used=len(request.content.split()), # Approximate token count
metadata=request.metadata.model_dump() if request.metadata else None,
):
return await document_service.ingest_text(request, auth)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@@ -177,14 +191,23 @@ async def ingest_text(
@app.post("/ingest/file", response_model=Document)
async def ingest_file(
file: UploadFile,
metadata: str = Form("{}"), # JSON string of metadata
metadata: str = Form("{}"),
auth: AuthContext = Depends(verify_token),
) -> Document:
"""Ingest a file document."""
try:
metadata_dict = json.loads(metadata)
doc = await document_service.ingest_file(file, metadata_dict, auth)
return doc # TODO: Might be lighter on network to just send the document ID.
async with telemetry.track_operation(
operation_type="ingest_file",
user_id=auth.entity_id,
metadata={
"filename": file.filename,
"content_type": file.content_type,
"metadata": metadata_dict,
},
):
doc = await document_service.ingest_file(file, metadata_dict, auth)
return doc
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except json.JSONDecodeError:
@@ -194,17 +217,27 @@ async def ingest_file(
@app.post("/retrieve/chunks", response_model=List[ChunkResult])
async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
"""Retrieve relevant chunks."""
return await document_service.retrieve_chunks(
request.query, auth, request.filters, request.k, request.min_score
)
async with telemetry.track_operation(
operation_type="retrieve_chunks",
user_id=auth.entity_id,
metadata=request.model_dump(),
):
return await document_service.retrieve_chunks(
request.query, auth, request.filters, request.k, request.min_score
)
@app.post("/retrieve/docs", response_model=List[DocumentResult])
async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
"""Retrieve relevant documents."""
return await document_service.retrieve_docs(
request.query, auth, request.filters, request.k, request.min_score
)
async with telemetry.track_operation(
operation_type="retrieve_docs",
user_id=auth.entity_id,
metadata=request.model_dump(),
):
return await document_service.retrieve_docs(
request.query, auth, request.filters, request.k, request.min_score
)
@app.post("/query", response_model=CompletionResponse)
@@ -212,15 +245,27 @@ async def query_completion(
request: CompletionQueryRequest, auth: AuthContext = Depends(verify_token)
):
"""Generate completion using relevant chunks as context."""
return await document_service.query(
request.query,
auth,
request.filters,
request.k,
request.min_score,
request.max_tokens,
request.temperature,
)
async with telemetry.track_operation(
operation_type="query",
user_id=auth.entity_id,
metadata=request.model_dump(),
) as span:
response = await document_service.query(
request.query,
auth,
request.filters,
request.k,
request.min_score,
request.max_tokens,
request.temperature,
)
if isinstance(response, dict) and "usage" in response:
usage = response["usage"]
if isinstance(usage, dict):
span.set_attribute("tokens.completion", usage.get("completion_tokens", 0))
span.set_attribute("tokens.prompt", usage.get("prompt_tokens", 0))
span.set_attribute("tokens.total", usage.get("total_tokens", 0))
return response
@app.get("/documents", response_model=List[Document])
@@ -246,3 +291,53 @@ async def get_document(document_id: str, auth: AuthContext = Depends(verify_token)):
except HTTPException as e:
logger.error(f"Error getting document: {e}")
raise e
# Usage tracking endpoints
@app.get("/usage/stats")
async def get_usage_stats(auth: AuthContext = Depends(verify_token)) -> Dict[str, int]:
"""Get usage statistics for the authenticated user."""
async with telemetry.track_operation(operation_type="get_usage_stats", user_id=auth.entity_id):
        # Every caller, admin or not, currently receives their own usage totals.
        return telemetry.get_user_usage(auth.entity_id)
@app.get("/usage/recent")
async def get_recent_usage(
auth: AuthContext = Depends(verify_token),
operation_type: Optional[str] = None,
since: Optional[datetime] = None,
status: Optional[str] = None,
) -> List[Dict]:
"""Get recent usage records."""
async with telemetry.track_operation(
operation_type="get_recent_usage",
user_id=auth.entity_id,
metadata={
"operation_type": operation_type,
"since": since.isoformat() if since else None,
"status": status,
},
):
if not auth.permissions or "admin" not in auth.permissions:
records = telemetry.get_recent_usage(
user_id=auth.entity_id, operation_type=operation_type, since=since, status=status
)
else:
records = telemetry.get_recent_usage(
operation_type=operation_type, since=since, status=status
)
return [
{
"timestamp": record.timestamp,
"operation_type": record.operation_type,
"tokens_used": record.tokens_used,
"user_id": record.user_id,
"duration_ms": record.duration_ms,
"status": record.status,
"metadata": record.metadata,
}
for record in records
]
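These endpoints can be exercised with any HTTP client. Below is a minimal client-side sketch using requests; the base URL and the bearer-token header are illustrative assumptions, not part of this diff:

import requests  # hypothetical client; any HTTP library works

BASE_URL = "http://localhost:8000"  # assumed local server address
HEADERS = {"Authorization": "Bearer <jwt>"}  # assumed scheme expected by verify_token

# Aggregate token totals per operation type for the authenticated user.
stats = requests.get(f"{BASE_URL}/usage/stats", headers=HEADERS).json()

# Recent usage records, filtered server-side by operation type and status.
recent = requests.get(
    f"{BASE_URL}/usage/recent",
    params={"operation_type": "query", "status": "success"},
    headers=HEADERS,
).json()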

core/services/telemetry.py (new file, 298 lines added)

@@ -0,0 +1,298 @@
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
import threading
from collections import defaultdict
import time
from contextlib import asynccontextmanager
import os
import json
from pathlib import Path
from opentelemetry import trace, metrics
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.resources import Resource
from opentelemetry.trace import Status, StatusCode
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.metrics.export import (
PeriodicExportingMetricReader,
MetricExporter,
AggregationTemporality,
MetricsData,
)
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
class FileSpanExporter:
def __init__(self, log_dir: str):
self.log_dir = Path(log_dir)
self.log_dir.mkdir(parents=True, exist_ok=True)
self.trace_file = self.log_dir / "traces.log"
def export(self, spans):
with open(self.trace_file, "a") as f:
for span in spans:
f.write(json.dumps(self._format_span(span)) + "\n")
return True
def shutdown(self):
pass
def _format_span(self, span):
return {
"name": span.name,
"trace_id": format(span.context.trace_id, "x"),
"span_id": format(span.context.span_id, "x"),
"parent_id": format(span.parent.span_id, "x") if span.parent else None,
"start_time": span.start_time,
"end_time": span.end_time,
"attributes": dict(span.attributes),
"status": span.status.status_code.name,
}
class FileMetricExporter(MetricExporter):
"""File metric exporter for OpenTelemetry."""
def __init__(self, log_dir: str):
self.log_dir = Path(log_dir)
self.log_dir.mkdir(parents=True, exist_ok=True)
self.metrics_file = self.log_dir / "metrics.log"
super().__init__()
def export(self, metrics_data: MetricsData, **kwargs) -> bool:
"""Export metrics data to a file.
Args:
metrics_data: The metrics data to export.
Returns:
True if the export was successful, False otherwise.
"""
try:
with open(self.metrics_file, "a") as f:
for resource_metrics in metrics_data.resource_metrics:
for scope_metrics in resource_metrics.scope_metrics:
for metric in scope_metrics.metrics:
f.write(json.dumps(self._format_metric(metric)) + "\n")
return True
except Exception:
return False
def shutdown(self, timeout_millis: float = 30_000, **kwargs) -> bool:
"""Shuts down the exporter.
Args:
timeout_millis: Time to wait for the export to complete in milliseconds.
Returns:
True if the shutdown succeeded, False otherwise.
"""
return True
def force_flush(self, timeout_millis: float = 10_000) -> bool:
"""Force flush the exporter.
Args:
timeout_millis: Time to wait for the flush to complete in milliseconds.
Returns:
True if the flush succeeded, False otherwise.
"""
return True
def _preferred_temporality(self) -> Dict:
"""Returns the preferred temporality for each instrument kind."""
return {
"counter": AggregationTemporality.CUMULATIVE,
"up_down_counter": AggregationTemporality.CUMULATIVE,
"observable_counter": AggregationTemporality.CUMULATIVE,
"observable_up_down_counter": AggregationTemporality.CUMULATIVE,
"histogram": AggregationTemporality.CUMULATIVE,
"observable_gauge": AggregationTemporality.CUMULATIVE,
}
def _format_metric(self, metric):
return {
"name": metric.name,
"description": metric.description,
"unit": metric.unit,
"data": self._format_data(metric.data),
}
def _format_data(self, data):
if hasattr(data, "data_points"):
return {
"data_points": [
{
"attributes": dict(point.attributes),
"value": point.value if hasattr(point, "value") else None,
"count": point.count if hasattr(point, "count") else None,
"sum": point.sum if hasattr(point, "sum") else None,
"timestamp": point.time_unix_nano,
}
for point in data.data_points
]
}
return {}
@dataclass
class UsageRecord:
timestamp: datetime
operation_type: str
tokens_used: int
user_id: str
duration_ms: float
status: str
metadata: Optional[Dict] = None
class TelemetryService:
_instance = None
_lock = threading.Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialize()
return cls._instance
def _initialize(self):
self._usage_records: List[UsageRecord] = []
self._user_totals = defaultdict(lambda: defaultdict(int))
self._lock = threading.Lock()
# Initialize OpenTelemetry
resource = Resource.create({"service.name": "databridge-core"})
# Create logs directory
log_dir = Path("logs/telemetry")
log_dir.mkdir(parents=True, exist_ok=True)
# Initialize tracing
tracer_provider = TracerProvider(resource=resource)
# Use file exporter for local development
if os.getenv("ENVIRONMENT", "development") == "development":
span_processor = BatchSpanProcessor(FileSpanExporter(str(log_dir)))
else:
span_processor = BatchSpanProcessor(OTLPSpanExporter())
tracer_provider.add_span_processor(span_processor)
trace.set_tracer_provider(tracer_provider)
self.tracer = trace.get_tracer(__name__)
# Initialize metrics
if os.getenv("ENVIRONMENT", "development") == "development":
metric_reader = PeriodicExportingMetricReader(
FileMetricExporter(str(log_dir)),
export_interval_millis=60000, # Export every minute
)
else:
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
metrics.set_meter_provider(meter_provider)
self.meter = metrics.get_meter(__name__)
# Create metrics
self.operation_counter = self.meter.create_counter(
"databridge.operations",
description="Number of operations performed",
)
self.token_counter = self.meter.create_counter(
"databridge.tokens",
description="Number of tokens processed",
)
self.operation_duration = self.meter.create_histogram(
"databridge.operation.duration",
description="Duration of operations",
unit="ms",
)
@asynccontextmanager
async def track_operation(
self,
operation_type: str,
user_id: str,
tokens_used: int = 0,
metadata: Optional[Dict[str, Any]] = None,
):
"""
Context manager for tracking operations with both usage metrics and OpenTelemetry
"""
start_time = time.time()
status = "success"
current_span = trace.get_current_span()
try:
# Add operation attributes to the current span
current_span.set_attribute("operation.type", operation_type)
current_span.set_attribute("user.id", user_id)
if metadata:
for key, value in metadata.items():
current_span.set_attribute(f"metadata.{key}", str(value))
yield current_span
except Exception as e:
status = "error"
current_span.set_status(Status(StatusCode.ERROR))
current_span.record_exception(e)
raise
finally:
duration = (time.time() - start_time) * 1000 # Convert to milliseconds
# Record metrics
self.operation_counter.add(1, {"operation": operation_type, "status": status})
if tokens_used > 0:
self.token_counter.add(tokens_used, {"operation": operation_type})
self.operation_duration.record(duration, {"operation": operation_type})
# Record usage
record = UsageRecord(
timestamp=datetime.now(),
operation_type=operation_type,
tokens_used=tokens_used,
user_id=user_id,
duration_ms=duration,
status=status,
metadata=metadata,
)
with self._lock:
self._usage_records.append(record)
self._user_totals[user_id][operation_type] += tokens_used
def get_user_usage(self, user_id: str) -> Dict[str, int]:
"""Get usage statistics for a user."""
with self._lock:
return dict(self._user_totals[user_id])
def get_recent_usage(
self,
user_id: Optional[str] = None,
operation_type: Optional[str] = None,
since: Optional[datetime] = None,
status: Optional[str] = None,
) -> List[UsageRecord]:
"""Get recent usage records with optional filtering."""
with self._lock:
records = self._usage_records.copy()
# Apply filters
if user_id:
records = [r for r in records if r.user_id == user_id]
if operation_type:
records = [r for r in records if r.operation_type == operation_type]
if since:
records = [r for r in records if r.timestamp >= since]
if status:
records = [r for r in records if r.status == status]
return records
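Since TelemetryService is a process-wide singleton, it can also be driven outside the FastAPI layer. A minimal sketch; the operation name, user id, and span attribute below are made up for illustration:

import asyncio

from core.services.telemetry import TelemetryService

async def main():
    telemetry = TelemetryService()  # __new__ hands back the same instance every time
    async with telemetry.track_operation(
        operation_type="demo_op",  # hypothetical operation name
        user_id="user_123",  # hypothetical user id
        tokens_used=42,
    ) as span:
        # Outside an instrumented request this span may be non-recording,
        # in which case set_attribute is a harmless no-op.
        span.set_attribute("demo.note", "work happens here")

    print(telemetry.get_user_usage("user_123"))  # {'demo_op': 42}
    print(telemetry.get_recent_usage(status="success"))

asyncio.run(main())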


@@ -243,3 +243,7 @@ xlrd==2.0.1
XlsxWriter==3.2.0
xxhash==3.4.1
yarl==1.9.4
opentelemetry-api>=1.21.0
opentelemetry-sdk>=1.21.0
opentelemetry-instrumentation-fastapi>=0.42b0
opentelemetry-exporter-otlp>=1.21.0
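For a local checkout, the new dependencies can be installed directly; an equivalent pip invocation (assuming a plain virtualenv) would be:

pip install "opentelemetry-api>=1.21.0" "opentelemetry-sdk>=1.21.0" \
    "opentelemetry-instrumentation-fastapi>=0.42b0" "opentelemetry-exporter-otlp>=1.21.0"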


@@ -205,7 +205,7 @@ class DataBridge:
"""
request = {"query": query, "filters": filters, "k": k, "min_score": min_score}
response = self._request("POST", "search/chunks", request)
response = self._request("POST", "retrieve/chunks", request)
return [ChunkResult(**r) for r in response]
def retrieve_docs(

shell.py (new file, 126 lines added)

@@ -0,0 +1,126 @@
#!/usr/bin/env python3
"""
DataBridge interactive CLI.
Assumes a DataBridge server is running.
Usage:
python shell.py <uri>
Example: python shell.py "http://test_user:token@localhost:8000"
This provides the exact same interface as the Python SDK:
db.ingest_text("content", metadata={...})
db.ingest_file("path/to/file")
db.query("what are the key findings?")
etc...
"""
import sys
from pathlib import Path
# Add local SDK to path before other imports
_SDK_PATH = str(Path(__file__).parent / "sdks" / "python")
if _SDK_PATH not in sys.path:
sys.path.insert(0, _SDK_PATH)
from databridge import DataBridge # noqa: E402
class DB:
def __init__(self, uri: str):
"""Initialize DataBridge with URI"""
# Convert databridge:// to http:// for localhost
if "localhost" in uri or "127.0.0.1" in uri:
uri = uri.replace("databridge://", "http://")
self.uri = uri
self._client = DataBridge(self.uri, is_local="localhost" in uri or "127.0.0.1" in uri)
def ingest_text(self, content: str, metadata: dict = None) -> dict:
"""Ingest text content into DataBridge"""
doc = self._client.ingest_text(content, metadata=metadata or {})
return doc.model_dump()
def ingest_file(self, file_path: str, metadata: dict = None, content_type: str = None) -> dict:
"""Ingest a file into DataBridge"""
file_path = Path(file_path)
doc = self._client.ingest_file(
file_path, filename=file_path.name, content_type=content_type, metadata=metadata or {}
)
return doc.model_dump()
def retrieve_chunks(
self, query: str, filters: dict = None, k: int = 4, min_score: float = 0.0
) -> list:
"""Search for relevant chunks"""
results = self._client.retrieve_chunks(
query, filters=filters or {}, k=k, min_score=min_score
)
return [r.model_dump() for r in results]
def retrieve_docs(
self, query: str, filters: dict = None, k: int = 4, min_score: float = 0.0
) -> list:
"""Retrieve relevant documents"""
results = self._client.retrieve_docs(query, filters=filters or {}, k=k, min_score=min_score)
return [r.model_dump() for r in results]
def query(
self,
query: str,
filters: dict = None,
k: int = 4,
min_score: float = 0.0,
max_tokens: int = None,
temperature: float = None,
) -> dict:
"""Generate completion using relevant chunks as context"""
response = self._client.query(
query,
filters=filters or {},
k=k,
min_score=min_score,
max_tokens=max_tokens,
temperature=temperature,
)
return response.model_dump()
def list_documents(self, skip: int = 0, limit: int = 100, filters: dict = None) -> list:
"""List accessible documents"""
docs = self._client.list_documents(skip=skip, limit=limit, filters=filters or {})
return [doc.model_dump() for doc in docs]
def get_document(self, document_id: str) -> dict:
"""Get document metadata by ID"""
doc = self._client.get_document(document_id)
return doc.model_dump()
def close(self):
"""Close the client connection"""
self._client.close()
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Error: URI argument required")
print(__doc__)
sys.exit(1)
# Create DB instance with provided URI
db = DB(sys.argv[1])
# Start an interactive Python shell with 'db' already imported
import code
import readline # Enable arrow key history
import rlcompleter # noqa: F401 # Enable tab completion
readline.parse_and_bind("tab: complete")
# Create the interactive shell
shell = code.InteractiveConsole(locals())
# Print welcome message
print("\nDataBridge CLI ready to use. The 'db' object is available with all SDK methods.")
print("Example: db.ingest_text('hello world')")
print("Type help(db) for documentation.")
# Start the shell
shell.interact(banner="")
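An illustrative session against a locally running server (the URI is the one from the docstring above; commands and outputs are placeholders):

$ python shell.py "http://test_user:token@localhost:8000"

DataBridge CLI ready to use. The 'db' object is available with all SDK methods.
Example: db.ingest_text('hello world')
Type help(db) for documentation.
>>> doc = db.ingest_text("hello world", metadata={"topic": "demo"})
>>> db.retrieve_chunks("hello", k=2)
>>> db.query("what do the ingested documents say?")
>>> db.close()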