add document retrieval endpoint

This commit is contained in:
Arnav Agrawal 2024-11-18 18:41:23 -05:00
parent 56fe944326
commit ab4fd6def2
4 changed files with 111 additions and 7 deletions

View File

@ -2,12 +2,14 @@ import uuid
from fastapi import FastAPI, HTTPException, Depends, Header, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from typing import Dict, Any, List, Optional, Annotated
from typing import Dict, Any, List, Optional, Annotated, Union
from pydantic import BaseModel, Field
import jwt
import os
from datetime import datetime, UTC
import logging
from pymongo import MongoClient
from .vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from .embedding_model.openai_embedding_model import OpenAIEmbeddingModel
from .parser.unstructured_parser import UnstructuredAPIParser
@ -70,6 +72,7 @@ class ServiceConfig:
def _init_components(self):
"""Initialize service components"""
try:
self.database = MongoClient(os.getenv("MONGODB_URI")).get_database(os.getenv("DB_NAME", "DataBridgeTest")).get_collection(os.getenv("COLLECTION_NAME", "test"))
self.vector_store = MongoDBAtlasVectorStore(
connection_string=os.getenv("MONGODB_URI"),
database_name=os.getenv("DB_NAME", "DataBridgeTest"),
@ -126,6 +129,51 @@ service = ServiceConfig()
# Request/Response Models
class Document(BaseModel):
id: str
name: str
type: str
source: str
uploaded_at: str
size: str
redaction_level: str
stats: Dict[str, Union[int, str]] = Field(
default_factory=lambda: {
"ai_queries": 0,
"time_saved": "0h",
"last_accessed": ""
}
)
accessed_by: List[Dict[str, str]] = Field(default_factory=list)
sensitive_content: Optional[List[str]] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
s3_bucket: Optional[str] = None
s3_key: Optional[str] = None
@classmethod
def from_mongo(cls, data: Dict[str, Any]) -> "Document":
"""Create from MongoDB document"""
# Convert MongoDB document to Document model
return cls(
id=str(data.get("_id")),
name=data.get("system_metadata", {}).get("filename") or "Untitled",
type="document", # Default type for now
source=data.get("source"),
uploaded_at=str(data.get("_id").generation_time), # MongoDB ObjectId contains timestamp
size="N/A", # Size not stored currently
redaction_level="none", # Default redaction level
stats={
"ai_queries": 0,
"time_saved": "0h",
"last_accessed": ""
},
accessed_by=[],
metadata=data.get("metadata", {}),
s3_bucket=data.get("system_metadata", {}).get("s3_bucket"),
s3_key=data.get("system_metadata", {}).get("s3_key")
)
class IngestRequest(BaseModel):
content: str = Field(..., description="Document content (text or base64)")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata")
@ -174,6 +222,20 @@ async def error_handler(request: Request, call_next):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"error": "Internal server error"}
)
@app.get("/documents", response_model=List[Document])
async def get_documents(auth: AuthContext = Depends(verify_auth)) -> List[Document]:
"""Get all documents"""
filter = {
"$or": [
{"system_metadata.dev_id": auth.dev_id}, # Dev's own docs
{"permissions": {"$in": [auth.app_id]}} # Docs app has access to
]
} if auth.type == AuthType.DEVELOPER else {"system_metadata.eu_id": auth.eu_id}
documents = {doc["_id"]: doc for doc in service.database.find(filter)}.values()
return [Document.from_mongo(doc) for doc in documents]
@app.post("/ingest", response_model=IngestResponse)
@ -191,6 +253,8 @@ async def ingest_document(
# Set up system metadata.
system_metadata = SystemMetadata(doc_id=doc_id, s3_bucket=s3_bucket, s3_key=s3_key)
if request.metadata.get("filename"):
system_metadata.filename = request.metadata["filename"]
if auth.type == AuthType.DEVELOPER:
system_metadata.dev_id = auth.dev_id
system_metadata.app_id = auth.app_id

View File

@ -39,6 +39,7 @@ class SystemMetadata:
doc_id: str = None
s3_bucket: str = None
s3_key: str = None
filename: Optional[str] = None
class DocumentChunk:

View File

@ -6,6 +6,7 @@ from datetime import datetime, UTC
import asyncio
from dataclasses import dataclass
from .exceptions import AuthenticationError
from pydantic import BaseModel, Field
import logging
logger = logging.getLogger(__name__)
@ -19,6 +20,28 @@ class QueryResult:
score: Optional[float]
metadata: Dict[str, Any]
# Request/Response Models
class Document(BaseModel):
id: str
name: str
type: str
source: str
uploaded_at: str
size: str
redaction_level: str
stats: Dict[str, Union[int, str]] = Field(
default_factory=lambda: {
"ai_queries": 0,
"time_saved": "0h",
"last_accessed": ""
}
)
accessed_by: List[Dict[str, str]] = Field(default_factory=list)
sensitive_content: Optional[List[str]] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
s3_bucket: Optional[str] = None
s3_key: Optional[str] = None
class DataBridge:
"""
@ -111,7 +134,8 @@ class DataBridge:
async def ingest_document(
self,
content: Union[str, bytes],
metadata: Optional[Dict[str, Any]] = None
metadata: Optional[Dict[str, Any]] = None,
filename: Optional[str] = None
) -> str:
"""
Ingest a document into DataBridge.
@ -119,19 +143,20 @@ class DataBridge:
Args:
content: Document content (string or bytes)
metadata: Optional document metadata
content_type: Type of the content being ingested
filename: Optional filename - defaults to doc_id if not provided
Returns:
Document ID of the ingested document
"""
metadata = metadata or {}
if filename:
metadata["filename"] = filename
if isinstance(content, bytes):
import base64
content = base64.b64encode(content).decode()
metadata = metadata or {}
metadata = metadata
metadata["is_base64"] = True
metadata = metadata or {}
response = await self._make_request(
"POST",
"ingest",
@ -179,6 +204,11 @@ class DataBridge:
)
for result in response["results"]
]
async def get_documents(self) -> List[Document]:
"""Get all documents"""
response = await self._make_request("GET", "documents")
return [Document(**doc) for doc in response]
async def close(self):
"""Close the HTTP client"""

View File

@ -188,12 +188,21 @@ async def example_batch():
await db.close()
async def example_get_documents():
"""Example of getting documents"""
print("\n=== Get Documents Example ===")
db = DataBridge(create_developer_test_uri())
documents = await db.get_documents()
print(documents)
async def main():
"""Run all examples"""
try:
await example_text()
await example_pdf()
await example_batch()
await example_get_documents()
except Exception as e:
print(f"× Main error: {str(e)}")