1393 lines
50 KiB
Python

import asyncio
import json
from datetime import datetime, UTC, timedelta
from pathlib import Path
import sys
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, Form, HTTPException, Depends, Header, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
import jwt
import logging
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from core.limits_utils import check_and_increment_limits
from core.models.request import GenerateUriRequest, RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, UpdateGraphRequest, BatchIngestResponse
from core.models.completion import ChunkSource, CompletionResponse
from core.models.documents import Document, DocumentResult, ChunkResult
from core.models.graph import Graph
from core.models.auth import AuthContext, EntityType
from core.models.prompts import validate_prompt_overrides_with_http_exception
from core.parser.morphik_parser import MorphikParser
from core.services.document_service import DocumentService
from core.services.telemetry import TelemetryService
from core.config import get_settings
from core.database.mongo_database import MongoDatabase
from core.database.postgres_database import PostgresDatabase
from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from core.vector_store.multi_vector_store import MultiVectorStore
from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
from core.storage.s3_storage import S3Storage
from core.storage.local_storage import LocalStorage
from core.reranker.flag_reranker import FlagReranker
from core.cache.llama_cache_factory import LlamaCacheFactory
import tomli
# Initialize FastAPI app
app = FastAPI(title="Morphik API")
logger = logging.getLogger(__name__)
# Add health check endpoints
@app.get("/health")
async def health_check():
"""Basic health check endpoint."""
return {"status": "healthy"}
@app.get("/health/ready")
async def readiness_check():
"""Readiness check that verifies the application is initialized."""
return {
"status": "ready",
"components": {
"database": settings.DATABASE_PROVIDER,
"vector_store": settings.VECTOR_STORE_PROVIDER,
"embedding": settings.EMBEDDING_PROVIDER,
"completion": settings.COMPLETION_PROVIDER,
"storage": settings.STORAGE_PROVIDER,
},
}
# Initialize telemetry
telemetry = TelemetryService()
# Add OpenTelemetry instrumentation
FastAPIInstrumentor.instrument_app(app)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Initialize service
settings = get_settings()
# Initialize database
match settings.DATABASE_PROVIDER:
case "postgres":
if not settings.POSTGRES_URI:
raise ValueError("PostgreSQL URI is required for PostgreSQL database")
database = PostgresDatabase(uri=settings.POSTGRES_URI)
case "mongodb":
if not settings.MONGODB_URI:
raise ValueError("MongoDB URI is required for MongoDB database")
database = MongoDatabase(
uri=settings.MONGODB_URI,
db_name=settings.DATABRIDGE_DB,
collection_name=settings.DOCUMENTS_COLLECTION,
)
case _:
raise ValueError(f"Unsupported database provider: {settings.DATABASE_PROVIDER}")
@app.on_event("startup")
async def initialize_database():
"""Initialize database tables and indexes on application startup."""
logger.info("Initializing database...")
success = await database.initialize()
if success:
logger.info("Database initialization successful")
else:
logger.error("Database initialization failed")
# We don't raise an exception here to allow the app to continue starting
# even if there are initialization errors
@app.on_event("startup")
async def initialize_vector_store():
"""Initialize vector store tables and indexes on application startup."""
# First initialize the primary vector store (PGVectorStore if using pgvector)
logger.info("Initializing primary vector store...")
if hasattr(vector_store, 'initialize'):
success = await vector_store.initialize()
if success:
logger.info("Primary vector store initialization successful")
else:
logger.error("Primary vector store initialization failed")
else:
logger.warning("Primary vector store does not have an initialize method")
# Then initialize the multivector store if enabled
if settings.ENABLE_COLPALI and colpali_vector_store:
logger.info("Initializing multivector store...")
# Handle both synchronous and asynchronous initialize methods
if hasattr(colpali_vector_store.initialize, '__awaitable__'):
success = await colpali_vector_store.initialize()
else:
success = colpali_vector_store.initialize()
if success:
logger.info("Multivector store initialization successful")
else:
logger.error("Multivector store initialization failed")
# Initialize vector store
match settings.VECTOR_STORE_PROVIDER:
case "mongodb":
vector_store = MongoDBAtlasVectorStore(
uri=settings.MONGODB_URI,
database_name=settings.DATABRIDGE_DB,
collection_name=settings.CHUNKS_COLLECTION,
index_name=settings.VECTOR_INDEX_NAME,
)
case "pgvector":
if not settings.POSTGRES_URI:
raise ValueError("PostgreSQL URI is required for pgvector store")
from core.vector_store.pgvector_store import PGVectorStore
vector_store = PGVectorStore(
uri=settings.POSTGRES_URI,
)
case _:
raise ValueError(f"Unsupported vector store provider: {settings.VECTOR_STORE_PROVIDER}")
# Initialize storage
match settings.STORAGE_PROVIDER:
case "local":
storage = LocalStorage(storage_path=settings.STORAGE_PATH)
case "aws-s3":
if not settings.AWS_ACCESS_KEY or not settings.AWS_SECRET_ACCESS_KEY:
raise ValueError("AWS credentials are required for S3 storage")
storage = S3Storage(
aws_access_key=settings.AWS_ACCESS_KEY,
aws_secret_key=settings.AWS_SECRET_ACCESS_KEY,
region_name=settings.AWS_REGION,
default_bucket=settings.S3_BUCKET,
)
case _:
raise ValueError(f"Unsupported storage provider: {settings.STORAGE_PROVIDER}")
# Initialize parser
parser = MorphikParser(
chunk_size=settings.CHUNK_SIZE,
chunk_overlap=settings.CHUNK_OVERLAP,
use_unstructured_api=settings.USE_UNSTRUCTURED_API,
unstructured_api_key=settings.UNSTRUCTURED_API_KEY,
assemblyai_api_key=settings.ASSEMBLYAI_API_KEY,
anthropic_api_key=settings.ANTHROPIC_API_KEY,
use_contextual_chunking=settings.USE_CONTEXTUAL_CHUNKING,
)
# Initialize embedding model
# Import here to avoid circular imports
from core.embedding.litellm_embedding import LiteLLMEmbeddingModel
# Create a LiteLLM model using the registered model config
embedding_model = LiteLLMEmbeddingModel(
model_key=settings.EMBEDDING_MODEL,
)
logger.info(f"Initialized LiteLLM embedding model with model key: {settings.EMBEDDING_MODEL}")
# Initialize completion model
# Import here to avoid circular imports
from core.completion.litellm_completion import LiteLLMCompletionModel
# Create a LiteLLM model using the registered model config
completion_model = LiteLLMCompletionModel(
model_key=settings.COMPLETION_MODEL,
)
logger.info(f"Initialized LiteLLM completion model with model key: {settings.COMPLETION_MODEL}")
# Initialize reranker
reranker = None
if settings.USE_RERANKING:
match settings.RERANKER_PROVIDER:
case "flag":
reranker = FlagReranker(
model_name=settings.RERANKER_MODEL,
device=settings.RERANKER_DEVICE,
use_fp16=settings.RERANKER_USE_FP16,
query_max_length=settings.RERANKER_QUERY_MAX_LENGTH,
passage_max_length=settings.RERANKER_PASSAGE_MAX_LENGTH,
)
case _:
raise ValueError(f"Unsupported reranker provider: {settings.RERANKER_PROVIDER}")
# Initialize cache factory
cache_factory = LlamaCacheFactory(Path(settings.STORAGE_PATH))
# Initialize ColPali embedding model if enabled
colpali_embedding_model = ColpaliEmbeddingModel() if settings.ENABLE_COLPALI else None
colpali_vector_store = (
MultiVectorStore(uri=settings.POSTGRES_URI) if settings.ENABLE_COLPALI else None
)
# Initialize document service with configured components
document_service = DocumentService(
storage=storage,
database=database,
vector_store=vector_store,
embedding_model=embedding_model,
completion_model=completion_model,
parser=parser,
reranker=reranker,
cache_factory=cache_factory,
enable_colpali=settings.ENABLE_COLPALI,
colpali_embedding_model=colpali_embedding_model,
colpali_vector_store=colpali_vector_store,
)
async def verify_token(authorization: str = Header(None)) -> AuthContext:
"""Verify JWT Bearer token or return dev context if dev_mode is enabled."""
# Check if dev mode is enabled
if settings.dev_mode:
return AuthContext(
entity_type=EntityType(settings.dev_entity_type),
entity_id=settings.dev_entity_id,
permissions=set(settings.dev_permissions),
user_id=settings.dev_entity_id, # In dev mode, entity_id is also the user_id
)
# Normal token verification flow
if not authorization:
raise HTTPException(
status_code=401,
detail="Missing authorization header",
headers={"WWW-Authenticate": "Bearer"},
)
try:
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = authorization[7:] # Remove "Bearer "
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
if datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC):
raise HTTPException(status_code=401, detail="Token expired")
# Support both "type" and "entity_type" fields for compatibility
entity_type_field = payload.get("type") or payload.get("entity_type")
if not entity_type_field:
raise HTTPException(status_code=401, detail="Missing entity type in token")
return AuthContext(
entity_type=EntityType(entity_type_field),
entity_id=payload["entity_id"],
app_id=payload.get("app_id"),
permissions=set(payload.get("permissions", ["read"])),
user_id=payload.get("user_id", payload["entity_id"]), # Use user_id if available, fallback to entity_id
)
except jwt.InvalidTokenError as e:
raise HTTPException(status_code=401, detail=str(e))
@app.post("/ingest/text", response_model=Document)
async def ingest_text(
request: IngestTextRequest,
auth: AuthContext = Depends(verify_token),
) -> Document:
"""
Ingest a text document.
Args:
request: IngestTextRequest containing:
- content: Text content to ingest
- filename: Optional filename to help determine content type
- metadata: Optional metadata dictionary
- rules: Optional list of rules. Each rule should be either:
- MetadataExtractionRule: {"type": "metadata_extraction", "schema": {...}}
- NaturalLanguageRule: {"type": "natural_language", "prompt": "..."}
auth: Authentication context
Returns:
Document: Metadata of ingested document
"""
try:
async with telemetry.track_operation(
operation_type="ingest_text",
user_id=auth.entity_id,
tokens_used=len(request.content.split()), # Approximate token count
metadata={
"metadata": request.metadata,
"rules": request.rules,
"use_colpali": request.use_colpali,
},
):
return await document_service.ingest_text(
content=request.content,
filename=request.filename,
metadata=request.metadata,
rules=request.rules,
use_colpali=request.use_colpali,
auth=auth,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/ingest/file", response_model=Document)
async def ingest_file(
file: UploadFile,
metadata: str = Form("{}"),
rules: str = Form("[]"),
auth: AuthContext = Depends(verify_token),
use_colpali: Optional[bool] = None,
) -> Document:
"""
Ingest a file document.
Args:
file: File to ingest
metadata: JSON string of metadata
rules: JSON string of rules list. Each rule should be either:
- MetadataExtractionRule: {"type": "metadata_extraction", "schema": {...}}
- NaturalLanguageRule: {"type": "natural_language", "prompt": "..."}
auth: Authentication context
Returns:
Document: Metadata of ingested document
"""
try:
metadata_dict = json.loads(metadata)
rules_list = json.loads(rules)
use_colpali = bool(use_colpali)
async with telemetry.track_operation(
operation_type="ingest_file",
user_id=auth.entity_id,
metadata={
"filename": file.filename,
"content_type": file.content_type,
"metadata": metadata_dict,
"rules": rules_list,
"use_colpali": use_colpali,
},
):
logger.debug(f"API: Ingesting file with use_colpali: {use_colpali}")
return await document_service.ingest_file(
file=file,
metadata=metadata_dict,
auth=auth,
rules=rules_list,
use_colpali=use_colpali,
)
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/ingest/files", response_model=BatchIngestResponse)
async def batch_ingest_files(
files: List[UploadFile] = File(...),
metadata: str = Form("{}"),
rules: str = Form("[]"),
use_colpali: Optional[bool] = Form(None),
parallel: bool = Form(True),
auth: AuthContext = Depends(verify_token),
) -> BatchIngestResponse:
"""
Batch ingest multiple files.
Args:
files: List of files to ingest
metadata: JSON string of metadata (either a single dict or list of dicts)
rules: JSON string of rules list. Can be either:
- A single list of rules to apply to all files
- A list of rule lists, one per file
use_colpali: Whether to use ColPali-style embedding
parallel: Whether to process files in parallel
auth: Authentication context
Returns:
BatchIngestResponse containing:
- documents: List of successfully ingested documents
- errors: List of errors encountered during ingestion
"""
if not files:
raise HTTPException(
status_code=400,
detail="No files provided for batch ingestion"
)
try:
metadata_value = json.loads(metadata)
rules_list = json.loads(rules)
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
# Validate metadata if it's a list
if isinstance(metadata_value, list) and len(metadata_value) != len(files):
raise HTTPException(
status_code=400,
detail=f"Number of metadata items ({len(metadata_value)}) must match number of files ({len(files)})"
)
# Validate rules if it's a list of lists
if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list):
if len(rules_list) != len(files):
raise HTTPException(
status_code=400,
detail=f"Number of rule lists ({len(rules_list)}) must match number of files ({len(files)})"
)
documents = []
errors = []
async with telemetry.track_operation(
operation_type="batch_ingest",
user_id=auth.entity_id,
metadata={
"file_count": len(files),
"metadata_type": "list" if isinstance(metadata_value, list) else "single",
"rules_type": "per_file" if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list) else "shared",
},
):
if parallel:
tasks = []
for i, file in enumerate(files):
metadata_item = metadata_value[i] if isinstance(metadata_value, list) else metadata_value
file_rules = rules_list[i] if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list) else rules_list
task = document_service.ingest_file(
file=file,
metadata=metadata_item,
auth=auth,
rules=file_rules,
use_colpali=use_colpali
)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(results):
if isinstance(result, Exception):
errors.append({
"filename": files[i].filename,
"error": str(result)
})
else:
documents.append(result)
else:
for i, file in enumerate(files):
try:
metadata_item = metadata_value[i] if isinstance(metadata_value, list) else metadata_value
file_rules = rules_list[i] if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list) else rules_list
doc = await document_service.ingest_file(
file=file,
metadata=metadata_item,
auth=auth,
rules=file_rules,
use_colpali=use_colpali
)
documents.append(doc)
except Exception as e:
errors.append({
"filename": file.filename,
"error": str(e)
})
return BatchIngestResponse(documents=documents, errors=errors)
@app.post("/retrieve/chunks", response_model=List[ChunkResult])
async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
"""Retrieve relevant chunks."""
try:
async with telemetry.track_operation(
operation_type="retrieve_chunks",
user_id=auth.entity_id,
metadata={
"k": request.k,
"min_score": request.min_score,
"use_reranking": request.use_reranking,
"use_colpali": request.use_colpali,
},
):
return await document_service.retrieve_chunks(
request.query,
auth,
request.filters,
request.k,
request.min_score,
request.use_reranking,
request.use_colpali,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/retrieve/docs", response_model=List[DocumentResult])
async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
"""Retrieve relevant documents."""
try:
async with telemetry.track_operation(
operation_type="retrieve_docs",
user_id=auth.entity_id,
metadata={
"k": request.k,
"min_score": request.min_score,
"use_reranking": request.use_reranking,
"use_colpali": request.use_colpali,
},
):
return await document_service.retrieve_docs(
request.query,
auth,
request.filters,
request.k,
request.min_score,
request.use_reranking,
request.use_colpali,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/batch/documents", response_model=List[Document])
async def batch_get_documents(document_ids: List[str], auth: AuthContext = Depends(verify_token)):
"""Retrieve multiple documents by their IDs in a single batch operation."""
try:
async with telemetry.track_operation(
operation_type="batch_get_documents",
user_id=auth.entity_id,
metadata={
"document_count": len(document_ids),
},
):
return await document_service.batch_retrieve_documents(document_ids, auth)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/batch/chunks", response_model=List[ChunkResult])
async def batch_get_chunks(chunk_ids: List[ChunkSource], auth: AuthContext = Depends(verify_token)):
"""Retrieve specific chunks by their document ID and chunk number in a single batch operation."""
try:
async with telemetry.track_operation(
operation_type="batch_get_chunks",
user_id=auth.entity_id,
metadata={
"chunk_count": len(chunk_ids),
},
):
return await document_service.batch_retrieve_chunks(chunk_ids, auth)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/query", response_model=CompletionResponse)
async def query_completion(
request: CompletionQueryRequest, auth: AuthContext = Depends(verify_token)
):
"""Generate completion using relevant chunks as context.
When graph_name is provided, the query will leverage the knowledge graph
to enhance retrieval by finding relevant entities and their connected documents.
"""
try:
# Validate prompt overrides before proceeding
if request.prompt_overrides:
validate_prompt_overrides_with_http_exception(request.prompt_overrides, operation_type="query")
# Check query limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "query", 1)
async with telemetry.track_operation(
operation_type="query",
user_id=auth.entity_id,
metadata={
"k": request.k,
"min_score": request.min_score,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
"use_reranking": request.use_reranking,
"use_colpali": request.use_colpali,
"graph_name": request.graph_name,
"hop_depth": request.hop_depth,
"include_paths": request.include_paths,
},
):
return await document_service.query(
request.query,
auth,
request.filters,
request.k,
request.min_score,
request.max_tokens,
request.temperature,
request.use_reranking,
request.use_colpali,
request.graph_name,
request.hop_depth,
request.include_paths,
request.prompt_overrides,
)
except ValueError as e:
validate_prompt_overrides_with_http_exception(operation_type="query", error=e)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/documents", response_model=List[Document])
async def list_documents(
auth: AuthContext = Depends(verify_token),
skip: int = 0,
limit: int = 10000,
filters: Optional[Dict[str, Any]] = None,
):
"""List accessible documents."""
return await document_service.db.get_documents(auth, skip, limit, filters)
@app.get("/documents/{document_id}", response_model=Document)
async def get_document(document_id: str, auth: AuthContext = Depends(verify_token)):
"""Get document by ID."""
try:
doc = await document_service.db.get_document(document_id, auth)
logger.debug(f"Found document: {doc}")
if not doc:
raise HTTPException(status_code=404, detail="Document not found")
return doc
except HTTPException as e:
logger.error(f"Error getting document: {e}")
raise e
@app.delete("/documents/{document_id}")
async def delete_document(document_id: str, auth: AuthContext = Depends(verify_token)):
"""
Delete a document and all associated data.
This endpoint deletes a document and all its associated data, including:
- Document metadata
- Document content in storage
- Document chunks and embeddings in vector store
Args:
document_id: ID of the document to delete
auth: Authentication context (must have write access to the document)
Returns:
Deletion status
"""
try:
async with telemetry.track_operation(
operation_type="delete_document",
user_id=auth.entity_id,
metadata={"document_id": document_id},
):
success = await document_service.delete_document(document_id, auth)
if not success:
raise HTTPException(status_code=404, detail="Document not found or delete failed")
return {"status": "success", "message": f"Document {document_id} deleted successfully"}
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.get("/documents/filename/{filename}", response_model=Document)
async def get_document_by_filename(filename: str, auth: AuthContext = Depends(verify_token)):
"""Get document by filename."""
try:
doc = await document_service.db.get_document_by_filename(filename, auth)
logger.debug(f"Found document by filename: {doc}")
if not doc:
raise HTTPException(status_code=404, detail=f"Document with filename '{filename}' not found")
return doc
except HTTPException as e:
logger.error(f"Error getting document by filename: {e}")
raise e
@app.post("/documents/{document_id}/update_text", response_model=Document)
async def update_document_text(
document_id: str,
request: IngestTextRequest,
update_strategy: str = "add",
auth: AuthContext = Depends(verify_token)
):
"""
Update a document with new text content using the specified strategy.
Args:
document_id: ID of the document to update
request: Text content and metadata for the update
update_strategy: Strategy for updating the document (default: 'add')
Returns:
Document: Updated document metadata
"""
try:
async with telemetry.track_operation(
operation_type="update_document_text",
user_id=auth.entity_id,
metadata={
"document_id": document_id,
"update_strategy": update_strategy,
"use_colpali": request.use_colpali,
"has_filename": request.filename is not None,
},
):
doc = await document_service.update_document(
document_id=document_id,
auth=auth,
content=request.content,
file=None,
filename=request.filename,
metadata=request.metadata,
rules=request.rules,
update_strategy=update_strategy,
use_colpali=request.use_colpali,
)
if not doc:
raise HTTPException(status_code=404, detail="Document not found or update failed")
return doc
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/documents/{document_id}/update_file", response_model=Document)
async def update_document_file(
document_id: str,
file: UploadFile,
metadata: str = Form("{}"),
rules: str = Form("[]"),
update_strategy: str = Form("add"),
use_colpali: Optional[bool] = None,
auth: AuthContext = Depends(verify_token)
):
"""
Update a document with content from a file using the specified strategy.
Args:
document_id: ID of the document to update
file: File to add to the document
metadata: JSON string of metadata to merge with existing metadata
rules: JSON string of rules to apply to the content
update_strategy: Strategy for updating the document (default: 'add')
use_colpali: Whether to use multi-vector embedding
Returns:
Document: Updated document metadata
"""
try:
metadata_dict = json.loads(metadata)
rules_list = json.loads(rules)
async with telemetry.track_operation(
operation_type="update_document_file",
user_id=auth.entity_id,
metadata={
"document_id": document_id,
"filename": file.filename,
"content_type": file.content_type,
"update_strategy": update_strategy,
"use_colpali": use_colpali,
},
):
doc = await document_service.update_document(
document_id=document_id,
auth=auth,
content=None,
file=file,
filename=file.filename,
metadata=metadata_dict,
rules=rules_list,
update_strategy=update_strategy,
use_colpali=use_colpali,
)
if not doc:
raise HTTPException(status_code=404, detail="Document not found or update failed")
return doc
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/documents/{document_id}/update_metadata", response_model=Document)
async def update_document_metadata(
document_id: str,
metadata: Dict[str, Any],
auth: AuthContext = Depends(verify_token)
):
"""
Update only a document's metadata.
Args:
document_id: ID of the document to update
metadata: New metadata to merge with existing metadata
Returns:
Document: Updated document metadata
"""
try:
async with telemetry.track_operation(
operation_type="update_document_metadata",
user_id=auth.entity_id,
metadata={
"document_id": document_id,
},
):
doc = await document_service.update_document(
document_id=document_id,
auth=auth,
content=None,
file=None,
filename=None,
metadata=metadata,
rules=[],
update_strategy="add",
use_colpali=None,
)
if not doc:
raise HTTPException(status_code=404, detail="Document not found or update failed")
return doc
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
# Usage tracking endpoints
@app.get("/usage/stats")
async def get_usage_stats(auth: AuthContext = Depends(verify_token)) -> Dict[str, int]:
"""Get usage statistics for the authenticated user."""
async with telemetry.track_operation(operation_type="get_usage_stats", user_id=auth.entity_id):
if not auth.permissions or "admin" not in auth.permissions:
return telemetry.get_user_usage(auth.entity_id)
return telemetry.get_user_usage(auth.entity_id)
@app.get("/usage/recent")
async def get_recent_usage(
auth: AuthContext = Depends(verify_token),
operation_type: Optional[str] = None,
since: Optional[datetime] = None,
status: Optional[str] = None,
) -> List[Dict]:
"""Get recent usage records."""
async with telemetry.track_operation(
operation_type="get_recent_usage",
user_id=auth.entity_id,
metadata={
"operation_type": operation_type,
"since": since.isoformat() if since else None,
"status": status,
},
):
if not auth.permissions or "admin" not in auth.permissions:
records = telemetry.get_recent_usage(
user_id=auth.entity_id, operation_type=operation_type, since=since, status=status
)
else:
records = telemetry.get_recent_usage(
operation_type=operation_type, since=since, status=status
)
return [
{
"timestamp": record.timestamp,
"operation_type": record.operation_type,
"tokens_used": record.tokens_used,
"user_id": record.user_id,
"duration_ms": record.duration_ms,
"status": record.status,
"metadata": record.metadata,
}
for record in records
]
# Cache endpoints
@app.post("/cache/create")
async def create_cache(
name: str,
model: str,
gguf_file: str,
filters: Optional[Dict[str, Any]] = None,
docs: Optional[List[str]] = None,
auth: AuthContext = Depends(verify_token),
) -> Dict[str, Any]:
"""Create a new cache with specified configuration."""
try:
# Check cache creation limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "cache", 1)
async with telemetry.track_operation(
operation_type="create_cache",
user_id=auth.entity_id,
metadata={
"name": name,
"model": model,
"gguf_file": gguf_file,
"filters": filters,
"docs": docs,
},
):
filter_docs = set(await document_service.db.get_documents(auth, filters=filters))
additional_docs = (
{
await document_service.db.get_document(document_id=doc_id, auth=auth)
for doc_id in docs
}
if docs
else set()
)
docs_to_add = list(filter_docs.union(additional_docs))
if not docs_to_add:
raise HTTPException(status_code=400, detail="No documents to add to cache")
response = await document_service.create_cache(
name, model, gguf_file, docs_to_add, filters
)
return response
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.get("/cache/{name}")
async def get_cache(name: str, auth: AuthContext = Depends(verify_token)) -> Dict[str, Any]:
"""Get cache configuration by name."""
try:
async with telemetry.track_operation(
operation_type="get_cache",
user_id=auth.entity_id,
metadata={"name": name},
):
exists = await document_service.load_cache(name)
return {"exists": exists}
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/cache/{name}/update")
async def update_cache(name: str, auth: AuthContext = Depends(verify_token)) -> Dict[str, bool]:
"""Update cache with new documents matching its filter."""
try:
async with telemetry.track_operation(
operation_type="update_cache",
user_id=auth.entity_id,
metadata={"name": name},
):
if name not in document_service.active_caches:
exists = await document_service.load_cache(name)
if not exists:
raise HTTPException(status_code=404, detail=f"Cache '{name}' not found")
cache = document_service.active_caches[name]
docs = await document_service.db.get_documents(auth, filters=cache.filters)
docs_to_add = [doc for doc in docs if doc.id not in cache.docs]
return cache.add_docs(docs_to_add)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/cache/{name}/add_docs")
async def add_docs_to_cache(
name: str, docs: List[str], auth: AuthContext = Depends(verify_token)
) -> Dict[str, bool]:
"""Add specific documents to the cache."""
try:
async with telemetry.track_operation(
operation_type="add_docs_to_cache",
user_id=auth.entity_id,
metadata={"name": name, "docs": docs},
):
cache = document_service.active_caches[name]
docs_to_add = [
await document_service.db.get_document(doc_id, auth)
for doc_id in docs
if doc_id not in cache.docs
]
return cache.add_docs(docs_to_add)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/cache/{name}/query")
async def query_cache(
name: str,
query: str,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
auth: AuthContext = Depends(verify_token),
) -> CompletionResponse:
"""Query the cache with a prompt."""
try:
# Check cache query limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "cache_query", 1)
async with telemetry.track_operation(
operation_type="query_cache",
user_id=auth.entity_id,
metadata={
"name": name,
"query": query,
"max_tokens": max_tokens,
"temperature": temperature,
},
):
cache = document_service.active_caches[name]
print(f"Cache state: {cache.state.n_tokens}", file=sys.stderr)
return cache.query(query) # , max_tokens, temperature)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/graph/create", response_model=Graph)
async def create_graph(
request: CreateGraphRequest,
auth: AuthContext = Depends(verify_token),
) -> Graph:
"""
Create a graph from documents.
This endpoint extracts entities and relationships from documents
matching the specified filters or document IDs and creates a graph.
Args:
request: CreateGraphRequest containing:
- name: Name of the graph to create
- filters: Optional metadata filters to determine which documents to include
- documents: Optional list of specific document IDs to include
auth: Authentication context
Returns:
Graph: The created graph object
"""
try:
# Validate prompt overrides before proceeding
if request.prompt_overrides:
validate_prompt_overrides_with_http_exception(request.prompt_overrides, operation_type="graph")
# Check graph creation limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "graph", 1)
async with telemetry.track_operation(
operation_type="create_graph",
user_id=auth.entity_id,
metadata={
"name": request.name,
"filters": request.filters,
"documents": request.documents,
},
):
return await document_service.create_graph(
name=request.name,
auth=auth,
filters=request.filters,
documents=request.documents,
prompt_overrides=request.prompt_overrides,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except ValueError as e:
validate_prompt_overrides_with_http_exception(operation_type="graph", error=e)
@app.get("/graph/{name}", response_model=Graph)
async def get_graph(
name: str,
auth: AuthContext = Depends(verify_token),
) -> Graph:
"""
Get a graph by name.
This endpoint retrieves a graph by its name if the user has access to it.
Args:
name: Name of the graph to retrieve
auth: Authentication context
Returns:
Graph: The requested graph object
"""
try:
async with telemetry.track_operation(
operation_type="get_graph",
user_id=auth.entity_id,
metadata={"name": name},
):
graph = await document_service.db.get_graph(name, auth)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph '{name}' not found")
return graph
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/graphs", response_model=List[Graph])
async def list_graphs(
auth: AuthContext = Depends(verify_token),
) -> List[Graph]:
"""
List all graphs the user has access to.
This endpoint retrieves all graphs the user has access to.
Args:
auth: Authentication context
Returns:
List[Graph]: List of graph objects
"""
try:
async with telemetry.track_operation(
operation_type="list_graphs",
user_id=auth.entity_id,
):
return await document_service.db.list_graphs(auth)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/graph/{name}/update", response_model=Graph)
async def update_graph(
name: str,
request: UpdateGraphRequest,
auth: AuthContext = Depends(verify_token),
) -> Graph:
"""
Update an existing graph with new documents.
This endpoint processes additional documents based on the original graph filters
and/or new filters/document IDs, extracts entities and relationships, and
updates the graph with new information.
Args:
name: Name of the graph to update
request: UpdateGraphRequest containing:
- additional_filters: Optional additional metadata filters to determine which new documents to include
- additional_documents: Optional list of additional document IDs to include
auth: Authentication context
Returns:
Graph: The updated graph object
"""
try:
# Validate prompt overrides before proceeding
if request.prompt_overrides:
validate_prompt_overrides_with_http_exception(request.prompt_overrides, operation_type="graph")
async with telemetry.track_operation(
operation_type="update_graph",
user_id=auth.entity_id,
metadata={
"name": name,
"additional_filters": request.additional_filters,
"additional_documents": request.additional_documents,
},
):
return await document_service.update_graph(
name=name,
auth=auth,
additional_filters=request.additional_filters,
additional_documents=request.additional_documents,
prompt_overrides=request.prompt_overrides,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except ValueError as e:
validate_prompt_overrides_with_http_exception(operation_type="graph", error=e)
except Exception as e:
logger.error(f"Error updating graph: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/local/generate_uri", include_in_schema=True)
async def generate_local_uri(
name: str = Form("admin"),
expiry_days: int = Form(30),
) -> Dict[str, str]:
"""Generate a local URI for development. This endpoint is unprotected."""
try:
# Clean name
name = name.replace(" ", "_").lower()
# Create payload
payload = {
"type": "developer",
"entity_id": name,
"permissions": ["read", "write", "admin"],
"exp": datetime.now(UTC) + timedelta(days=expiry_days),
}
# Generate token
token = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
# Read config for host/port
with open("morphik.toml", "rb") as f:
config = tomli.load(f)
base_url = f"{config['api']['host']}:{config['api']['port']}".replace(
"localhost", "127.0.0.1"
)
# Generate URI
uri = f"morphik://{name}:{token}@{base_url}"
return {"uri": uri}
except Exception as e:
logger.error(f"Error generating local URI: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/cloud/generate_uri", include_in_schema=True)
async def generate_cloud_uri(
request: GenerateUriRequest,
authorization: str = Header(None),
) -> Dict[str, str]:
"""Generate a URI for cloud hosted applications."""
try:
app_id = request.app_id
name = request.name
user_id = request.user_id
expiry_days = request.expiry_days
logger.debug(f"Generating cloud URI for app_id={app_id}, name={name}, user_id={user_id}")
# Verify authorization header before proceeding
if not authorization:
logger.warning("Missing authorization header")
raise HTTPException(
status_code=401,
detail="Missing authorization header",
headers={"WWW-Authenticate": "Bearer"},
)
# Verify the token is valid
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = authorization[7:] # Remove "Bearer "
try:
# Decode the token to ensure it's valid
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
# Only allow users to create apps for themselves (or admin)
token_user_id = payload.get("user_id")
logger.debug(f"Token user ID: {token_user_id}")
logger.debug(f"User ID: {user_id}")
if not (token_user_id == user_id or "admin" in payload.get("permissions", [])):
raise HTTPException(
status_code=403,
detail="You can only create apps for your own account unless you have admin permissions"
)
except jwt.InvalidTokenError as e:
raise HTTPException(status_code=401, detail=str(e))
# Import UserService here to avoid circular imports
from core.services.user_service import UserService
user_service = UserService()
# Initialize user service if needed
await user_service.initialize()
# Clean name
name = name.replace(" ", "_").lower()
# Check if the user is within app limit and generate URI
uri = await user_service.generate_cloud_uri(user_id, app_id, name, expiry_days)
if not uri:
logger.debug("Application limit reached for this account tier with user_id: %s", user_id)
raise HTTPException(
status_code=403,
detail="Application limit reached for this account tier"
)
return {"uri": uri, "app_id": app_id}
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
logger.error(f"Error generating cloud URI: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/user/upgrade", include_in_schema=True)
async def upgrade_user_tier(
user_id: str,
tier: str,
custom_limits: Optional[Dict[str, Any]] = None,
authorization: str = Header(None),
) -> Dict[str, Any]:
"""Upgrade a user to a higher tier."""
try:
# Verify admin authorization
if not authorization:
raise HTTPException(
status_code=401,
detail="Missing authorization header",
headers={"WWW-Authenticate": "Bearer"},
)
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = authorization[7:] # Remove "Bearer "
try:
# Decode token
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
# Only allow admins to upgrade users
if "admin" not in payload.get("permissions", []):
raise HTTPException(
status_code=403,
detail="Admin permission required"
)
except jwt.InvalidTokenError as e:
raise HTTPException(status_code=401, detail=str(e))
# Validate tier
from core.models.tiers import AccountTier
try:
account_tier = AccountTier(tier)
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid tier: {tier}")
# Upgrade user
from core.services.user_service import UserService
user_service = UserService()
# Initialize user service
await user_service.initialize()
# Update user tier
success = await user_service.update_user_tier(user_id, tier, custom_limits)
if not success:
raise HTTPException(
status_code=404,
detail="User not found or upgrade failed"
)
return {
"user_id": user_id,
"tier": tier,
"message": f"User upgraded to {tier} tier"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error upgrading user tier: {e}")
raise HTTPException(status_code=500, detail=str(e))