import uuid
import logging
import os
from datetime import datetime, UTC
from typing import Dict, Any, List, Optional, Annotated, Union

import jwt
from fastapi import FastAPI, HTTPException, Depends, Header, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from pymongo import MongoClient

from .vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from .embedding_model.openai_embedding_model import OpenAIEmbeddingModel
from .parser.unstructured_parser import UnstructuredAPIParser
from .planner.simple_planner import SimpleRAGPlanner
from .document import DocumentChunk, Permission, Source, SystemMetadata, AuthContext, AuthType
from .utils.aws_utils import get_s3_client, upload_from_encoded_string

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="DataBridge API",
    description="REST API for DataBridge document ingestion and querying",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
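# NOTE: wildcard origins combined with allow_credentials=True is a permissive
# development default; production deployments typically pin allow_origins to
# known frontend hosts (reading them from an env var would be one option; this
# service does not currently do so).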


class DataBridgeException(HTTPException):
    def __init__(self, detail: str, status_code: int = 400):
        super().__init__(status_code=status_code, detail=detail)


class AuthenticationError(DataBridgeException):
    def __init__(self, detail: str = "Authentication failed"):
        super().__init__(detail=detail, status_code=status.HTTP_401_UNAUTHORIZED)


class ServiceConfig:
    """Service-wide configuration and component management"""
    def __init__(self):
        self.jwt_secret = os.getenv("JWT_SECRET_KEY")
        if not self.jwt_secret:
            raise ValueError("JWT_SECRET_KEY environment variable not set")

        # Required environment variables
        required_vars = {
            "MONGODB_URI": "MongoDB connection string",
            "OPENAI_API_KEY": "OpenAI API key",
            "UNSTRUCTURED_API_KEY": "Unstructured API key"
        }

        missing = [f"{var} ({desc})" for var, desc in required_vars.items() if not os.getenv(var)]
        if missing:
            raise ValueError(f"Missing required environment variables: {', '.join(missing)}")

        # Initialize core components
        self._init_components()

    def _init_components(self):
        """Initialize service components"""
        try:
            # NOTE: despite the name, this is a collection handle, not a database.
            self.database = (
                MongoClient(os.getenv("MONGODB_URI"))
                .get_database(os.getenv("DB_NAME", "DataBridgeTest"))
                .get_collection(os.getenv("COLLECTION_NAME", "test"))
            )

            self.vector_store = MongoDBAtlasVectorStore(
                connection_string=os.getenv("MONGODB_URI"),
                database_name=os.getenv("DB_NAME", "DataBridgeTest"),
                collection_name=os.getenv("COLLECTION_NAME", "test")
            )

            self.embedding_model = OpenAIEmbeddingModel(
                api_key=os.getenv("OPENAI_API_KEY"),
                model_name=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
            )

            self.parser = UnstructuredAPIParser(
                api_key=os.getenv("UNSTRUCTURED_API_KEY"),
                chunk_size=int(os.getenv("CHUNK_SIZE", "1000")),
                chunk_overlap=int(os.getenv("CHUNK_OVERLAP", "200"))
            )

            self.planner = SimpleRAGPlanner(
                default_k=int(os.getenv("DEFAULT_K", "4"))
            )
        except Exception as e:
            raise ValueError(f"Failed to initialize components: {str(e)}")

    async def verify_token(self, token: str, owner_id: str) -> AuthContext:
        """Verify JWT token and return auth context"""
        try:
            payload = jwt.decode(token, self.jwt_secret, algorithms=["HS256"])

            # PyJWT already rejects an expired "exp" claim during decode; this
            # defensive re-check also covers tokens that carry no "exp" option.
            if "exp" in payload and datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC):
                raise AuthenticationError("Token has expired")

            # Check if this is a developer token
            if "." in owner_id:  # dev_id.app_id format
                dev_id, app_id = owner_id.split(".", 1)
                return AuthContext(
                    type=AuthType.DEVELOPER,
                    dev_id=dev_id,
                    app_id=app_id
                )
            else:  # User token
                return AuthContext(
                    type=AuthType.USER,
                    eu_id=owner_id
                )

        except jwt.ExpiredSignatureError:
            raise AuthenticationError("Token has expired")
        except jwt.InvalidTokenError:
            raise AuthenticationError("Invalid token")
        except AuthenticationError:
            # Re-raise as-is so the message is not re-wrapped below.
            raise
        except Exception as e:
            raise AuthenticationError(f"Authentication failed: {str(e)}")

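# Illustrative only: a compatible client token can be minted with PyJWT. Only
# the signature and the "exp" claim are verified above; any other claims are
# assumptions, not a contract this service enforces.
#
#   from datetime import timedelta
#   token = jwt.encode(
#       {"exp": datetime.now(UTC) + timedelta(hours=1)},
#       os.environ["JWT_SECRET_KEY"],
#       algorithm="HS256",
#   )
#   # Developer caller:  X-Owner-ID: dev_123.app_456   X-Auth-Token: <token>
#   # End-user caller:   X-Owner-ID: user_789          X-Auth-Token: <token>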

# Initialize service
service = ServiceConfig()
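# Example environment (illustrative values; the variable names are exactly the
# ones ServiceConfig reads):
#   JWT_SECRET_KEY=change-me
#   MONGODB_URI=mongodb+srv://user:pass@cluster0.example.mongodb.net
#   OPENAI_API_KEY=sk-...
#   UNSTRUCTURED_API_KEY=...
#   DB_NAME=DataBridgeTest                    # optional
#   COLLECTION_NAME=test                      # optional
#   EMBEDDING_MODEL=text-embedding-3-small    # optional
#   CHUNK_SIZE=1000                           # optional
#   CHUNK_OVERLAP=200                         # optional
#   DEFAULT_K=4                               # optional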


# Request/Response Models
class Document(BaseModel):
    id: str
    name: str
    type: str
    source: str
    uploaded_at: str
    size: str
    redaction_level: str
    stats: Dict[str, Union[int, str]] = Field(
        default_factory=lambda: {
            "ai_queries": 0,
            "time_saved": "0h",
            "last_accessed": ""
        }
    )
    accessed_by: List[Dict[str, str]] = Field(default_factory=list)
    sensitive_content: Optional[List[str]] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    s3_bucket: Optional[str] = None
    s3_key: Optional[str] = None

    @classmethod
    def from_mongo(cls, data: Dict[str, Any]) -> "Document":
        """Create a Document from a raw MongoDB document."""
        return cls(
            id=str(data.get("_id")),
            name=data.get("system_metadata", {}).get("filename") or "Untitled",
            type="document",  # Default type for now
            source=data.get("source"),
            uploaded_at=str(data.get("_id").generation_time),  # ObjectId embeds a timestamp
            size="N/A",  # Size not stored currently
            redaction_level="none",  # Default redaction level
            stats={
                "ai_queries": 0,
                "time_saved": "0h",
                "last_accessed": ""
            },
            accessed_by=[],
            metadata=data.get("metadata", {}),
            s3_bucket=data.get("system_metadata", {}).get("s3_bucket"),
            s3_key=data.get("system_metadata", {}).get("s3_key")
        )
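# Shape of the stored document that from_mongo and get_documents rely on
# (illustrative; only fields this module actually reads are shown):
#   {
#       "_id": ObjectId("..."),
#       "source": "...",
#       "metadata": {...},
#       "permissions": {...},
#       "system_metadata": {"doc_id": "...", "filename": "...", "s3_bucket": "...",
#                           "s3_key": "...", "dev_id": "...", "eu_id": "..."}
#   }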


class IngestRequest(BaseModel):
    content: str = Field(..., description="Document content (text or base64)")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata")
    eu_id: Optional[str] = Field(None, description="End user ID when developer ingests for user")
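# Example /ingest payload (illustrative; per the field description, "content"
# may be plain text or a base64-encoded file):
#   {
#       "content": "Quarterly revenue grew 12%...",
#       "metadata": {"filename": "q3_report.txt"},
#       "eu_id": "user_789"
#   }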


class QueryRequest(BaseModel):
    query: str = Field(..., description="Query string")
    k: Optional[int] = Field(default=4, description="Number of results to return")
    filters: Optional[Dict[str, Any]] = Field(default=None, description="Optional metadata filters")
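# Example /query payload (illustrative; "filters" is passed through to the
# vector store as metadata filters):
#   {"query": "What was Q3 revenue growth?", "k": 4, "filters": {"filename": "q3_report.txt"}}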


class IngestResponse(BaseModel):
    document_id: str = Field(..., description="Ingested document ID")
    message: str = Field(default="Document ingested successfully")


class QueryResponse(BaseModel):
    results: List[Dict[str, Any]] = Field(..., description="Query results")
    total_results: int = Field(..., description="Total number of results")

# Authentication dependency
async def verify_auth(
    owner_id: Annotated[str, Header(alias="X-Owner-ID")],
    auth_token: Annotated[str, Header(alias="X-Auth-Token")]
) -> AuthContext:
    """Verify authentication headers"""
    return await service.verify_token(auth_token, owner_id)

# Error handler middleware
@app.middleware("http")
async def error_handler(request: Request, call_next):
    try:
        return await call_next(request)
    except DataBridgeException as e:
        return JSONResponse(
            status_code=e.status_code,
            content={"error": e.detail}
        )
    except Exception:
        logger.exception("Unexpected error")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"error": "Internal server error"}
        )

@app.get("/documents", response_model=List[Document])
async def get_documents(auth: AuthContext = Depends(verify_auth)) -> List[Document]:
    """Get all documents visible to the caller."""
    query_filter = {
        "$or": [
            {"system_metadata.dev_id": auth.dev_id},  # Dev's own docs
            {"permissions": {"$in": [auth.app_id]}}   # Docs the app has access to
        ]
    } if auth.type == AuthType.DEVELOPER else {"system_metadata.eu_id": auth.eu_id}

    # A single find() returns each matching document once, even when both
    # $or branches match, so no explicit deduplication is needed.
    return [Document.from_mongo(doc) for doc in service.database.find(query_filter)]
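# Example (illustrative; hostname and IDs are placeholders):
#   curl http://localhost:8000/documents \
#     -H "X-Owner-ID: dev_123.app_456" -H "X-Auth-Token: $TOKEN"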


@app.post("/ingest", response_model=IngestResponse)
async def ingest_document(
    request: IngestRequest,
    auth: AuthContext = Depends(verify_auth)
) -> IngestResponse:
    """Ingest a document into DataBridge."""
    logger.info(f"Ingesting document for {auth.type}")

    # Generate document ID for all chunks.
    doc_id = str(uuid.uuid4())
    s3_client = get_s3_client()
    s3_bucket, s3_key = upload_from_encoded_string(s3_client, request.content, doc_id)

    # Set up system metadata.
    system_metadata = SystemMetadata(doc_id=doc_id, s3_bucket=s3_bucket, s3_key=s3_key)
    if request.metadata.get("filename"):
        system_metadata.filename = request.metadata["filename"]
    if auth.type == AuthType.DEVELOPER:
        system_metadata.dev_id = auth.dev_id
        system_metadata.app_id = auth.app_id
        if request.eu_id:
            system_metadata.eu_id = request.eu_id
    else:
        system_metadata.eu_id = auth.eu_id

    # Parse into chunks and embed them.
    chunk_texts = service.parser.parse(request.content, request.metadata)
    embeddings = await service.embedding_model.embed_for_ingestion(chunk_texts)

    # Create chunks.
    chunks = []
    for text, embedding in zip(chunk_texts, embeddings):
        # Set source and permissions based on context.
        if auth.type == AuthType.DEVELOPER:
            source = Source.APP
            permissions = (
                {auth.app_id: {Permission.READ, Permission.WRITE, Permission.DELETE}}
                if request.eu_id else {}
            )
        else:
            source = Source.SELF_UPLOADED
            permissions = {}

        chunk = DocumentChunk(
            content=text,
            embedding=embedding,
            metadata=request.metadata,
            system_metadata=system_metadata,
            source=source,
            permissions=permissions
        )
        chunks.append(chunk)

    # Store in vector store.
    if not service.vector_store.store_embeddings(chunks):
        raise DataBridgeException(
            "Failed to store embeddings",
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )

    return IngestResponse(document_id=doc_id)
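# Example (illustrative):
#   curl -X POST http://localhost:8000/ingest \
#     -H "X-Owner-ID: dev_123.app_456" -H "X-Auth-Token: $TOKEN" \
#     -H "Content-Type: application/json" \
#     -d '{"content": "hello world", "metadata": {"filename": "hello.txt"}}'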


@app.post("/query", response_model=QueryResponse)
async def query_documents(
    request: QueryRequest,
    auth: AuthContext = Depends(verify_auth)
) -> QueryResponse:
    """
    Query documents in DataBridge.
    All configuration and credentials are handled server-side.
    """
    logger.info(f"Processing query for {auth.type}")

    # Create retrieval plan
    plan = service.planner.plan_retrieval(request.query, k=request.k)
    query_embedding = await service.embedding_model.embed_for_query(request.query)

    # Query vector store
    chunks = service.vector_store.query_similar(
        query_embedding,
        k=plan["k"],
        auth=auth,
        filters=request.filters
    )

    results = [
        {
            "content": chunk.content,
            "doc_id": chunk.system_metadata.doc_id,
            "score": chunk.score,
            "metadata": chunk.metadata
        }
        for chunk in chunks
    ]

    return QueryResponse(
        results=results,
        total_results=len(results)
    )
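# Example (illustrative):
#   curl -X POST http://localhost:8000/query \
#     -H "X-Owner-ID: user_789" -H "X-Auth-Token: $TOKEN" \
#     -H "Content-Type: application/json" \
#     -d '{"query": "What was Q3 revenue growth?", "k": 4}'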


# Health check endpoint
@app.get("/health")
async def health_check():
    """Check service health"""
    try:
        # Verify MongoDB connection
        service.vector_store.collection.find_one({})
        return {"status": "healthy"}
    except Exception as e:
        raise DataBridgeException(
            f"Service unhealthy: {str(e)}",
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE
        )
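# Example (illustrative):
#   curl http://localhost:8000/health
#   # -> {"status": "healthy"}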


# Startup and shutdown events
# (on_event still works, but newer FastAPI versions recommend lifespan handlers.)
@app.on_event("startup")
async def startup_event():
    """Verify all connections on startup"""
    logger.info("Starting DataBridge service")
    await health_check()


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on shutdown"""
    logger.info("Shutting down DataBridge service")