morphik-core/core/database/postgres_database.py
Adityavardhan Agrawal 273dfcc5e7
Add PostgreSQL support (#13)
Co-authored-by: Arnav Agrawal <aa779@cornell.edu>
2025-01-04 08:11:09 -05:00

316 lines
12 KiB
Python

import logging
from datetime import UTC, date, datetime
from typing import Any, Dict, List, Optional

from sqlalchemy import Column, Index, String, select, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.sql.elements import TextClause

from ..models.auth import AuthContext
from ..models.documents import Document
from .base_database import BaseDatabase
# Logger named after this module's import path.
logger = logging.getLogger(name=__name__)
# Declarative base class shared by the ORM models defined in this module.
Base = declarative_base()
class DocumentModel(Base):
    """SQLAlchemy model for document metadata.

    Each row is one document in the ``documents`` table. Apart from the
    string columns, every field is stored as PostgreSQL JSONB.
    """

    __tablename__ = "documents"

    # Primary key: the document's external identifier.
    external_id = Column(String, primary_key=True)
    # JSON object for the owner; this module reads its "type" and "id" keys.
    owner = Column(JSONB)
    content_type = Column(String)
    filename = Column(String, nullable=True)
    # Named ``doc_metadata`` because ``metadata`` is reserved on declarative models.
    doc_metadata = Column(JSONB, default=dict)
    storage_info = Column(JSONB, default=dict)
    system_metadata = Column(JSONB, default=dict)
    additional_metadata = Column(JSONB, default=dict)
    # JSON object whose "readers"/"writers"/"admins"/"app_access" entries are
    # consulted by the access checks in this module.
    access_control = Column(JSONB, default=dict)
    chunk_ids = Column(JSONB, default=list)

    # GIN indexes over the JSONB columns.
    # NOTE(review): default GIN jsonb_ops indexes serve top-level @>, ?, ?|, ?&;
    # expressions like ``owner->>'id' = ...`` or ``access_control->'readers' ? ...``
    # probably cannot use them — confirm with EXPLAIN before relying on them.
    __table_args__ = (
        Index("idx_owner_id", "owner", postgresql_using="gin"),
        Index("idx_access_control", "access_control", postgresql_using="gin"),
        Index("idx_system_metadata", "system_metadata", postgresql_using="gin"),
    )
def _serialize_datetime(obj: Any) -> Any:
    """Recursively convert date/datetime values to ISO 8601 strings.

    Dicts are rebuilt with converted values (keys are left untouched). Lists and
    tuples both become lists, which is how JSON represents either. Any other
    value is returned unchanged.
    """
    # datetime subclasses date, so this one check covers both. A plain ``date``
    # left as-is cannot be encoded by the standard json serializer.
    if isinstance(obj, date):
        return obj.isoformat()
    if isinstance(obj, dict):
        return {key: _serialize_datetime(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_serialize_datetime(item) for item in obj]
    return obj
class PostgresDatabase(BaseDatabase):
    """PostgreSQL implementation for document metadata storage.

    Documents are persisted as ``DocumentModel`` rows. Every user-controlled value
    that reaches SQL (entity/app ids, metadata filter keys and values) is sent as
    a bound parameter and is never interpolated into the SQL text.
    """

    def __init__(
        self,
        uri: str,
    ):
        """Initialize PostgreSQL connection for document storage.

        Args:
            uri: Database URL handed to ``create_async_engine``; it must name an
                async-capable driver.
        """
        self.engine = create_async_engine(uri)
        # expire_on_commit=False keeps loaded attributes readable after commit
        # without an (awaitable) refresh.
        self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)

    async def initialize(self):
        """Create the tables and indexes declared on ``Base`` if missing.

        Returns:
            True on success, False if creation raised (the error is logged).
        """
        try:
            async with self.engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)
            logger.info("PostgreSQL tables and indexes created successfully")
            return True
        except Exception as e:
            logger.exception("Error creating PostgreSQL tables and indexes: %s", e)
            return False

    async def store_document(self, document: Document) -> bool:
        """Insert a new row for ``document``.

        ``system_metadata.created_at`` and ``updated_at`` are both set to the same
        current UTC timestamp. All datetimes are converted to ISO strings so they
        fit in JSONB.

        Returns:
            True if the row was committed, False on any error (logged).
        """
        try:
            doc_dict = document.model_dump()

            # The column is ``doc_metadata`` because ``metadata`` is reserved on
            # declarative models.
            if "metadata" in doc_dict:
                doc_dict["doc_metadata"] = doc_dict.pop("metadata")

            # Tolerate a missing *or* None system_metadata.
            system_metadata = dict(doc_dict.get("system_metadata") or {})
            now = datetime.now(UTC)
            system_metadata["created_at"] = now
            system_metadata["updated_at"] = now
            doc_dict["system_metadata"] = system_metadata

            doc_dict = _serialize_datetime(doc_dict)

            async with self.async_session() as session:
                session.add(DocumentModel(**doc_dict))
                await session.commit()
            return True
        except Exception as e:
            logger.exception("Error storing document metadata: %s", e)
            return False

    async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]:
        """Return the document with ``document_id`` if ``auth`` may see it.

        Returns:
            The ``Document``, or None if it does not exist, is not visible to
            ``auth``, or the query failed (logged).
        """
        try:
            async with self.async_session() as session:
                query = (
                    select(DocumentModel)
                    .where(DocumentModel.external_id == document_id)
                    .where(self._build_access_filter(auth))
                )
                result = await session.execute(query)
                doc_model = result.scalar_one_or_none()
                return self._to_document(doc_model) if doc_model is not None else None
        except Exception as e:
            logger.exception("Error retrieving document metadata: %s", e)
            return None

    async def get_documents(
        self,
        auth: AuthContext,
        skip: int = 0,
        limit: int = 100,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Document]:
        """List documents visible to ``auth``, one page at a time.

        Args:
            auth: Caller whose access rights restrict the result.
            skip: Number of matching rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            filters: Optional exact-match conditions on ``doc_metadata`` keys.

        Returns:
            Matching documents ordered by ``external_id``; [] on error (logged).
        """
        try:
            async with self.async_session() as session:
                query = select(DocumentModel).where(self._build_access_filter(auth))
                metadata_filter = self._build_metadata_filter(filters)
                if metadata_filter is not None:
                    query = query.where(metadata_filter)
                # Without ORDER BY, PostgreSQL may return rows in any order, making
                # OFFSET/LIMIT pages overlap or skip rows between calls.
                query = query.order_by(DocumentModel.external_id).offset(skip).limit(limit)

                result = await session.execute(query)
                return [self._to_document(doc) for doc in result.scalars().all()]
        except Exception as e:
            logger.exception("Error listing documents: %s", e)
            return []

    async def update_document(
        self, document_id: str, updates: Dict[str, Any], auth: AuthContext
    ) -> bool:
        """Apply ``updates`` to a document if ``auth`` has write access.

        A ``metadata`` key is written to the ``doc_metadata`` column, mirroring
        ``store_document``. ``system_metadata`` entries are merged into the stored
        system metadata rather than replacing it, and ``updated_at`` is refreshed.
        The caller's ``updates`` dict is not modified.

        Returns:
            True if committed; False if access is denied, the document is missing,
            or an error occurred (logged).
        """
        try:
            if not await self.check_access(document_id, auth, "write"):
                return False

            # Work on copies so neither ``updates`` nor its nested dicts are mutated.
            pending = dict(updates)
            if "metadata" in pending:
                pending["doc_metadata"] = pending.pop("metadata")
            system_updates = dict(pending.pop("system_metadata", None) or {})
            system_updates["updated_at"] = datetime.now(UTC)

            pending = _serialize_datetime(pending)
            system_updates = _serialize_datetime(system_updates)

            async with self.async_session() as session:
                result = await session.execute(
                    select(DocumentModel).where(DocumentModel.external_id == document_id)
                )
                doc_model = result.scalar_one_or_none()
                if doc_model is None:
                    return False

                for key, value in pending.items():
                    setattr(doc_model, key, value)
                # Merge so existing fields such as created_at survive. A new dict is
                # assigned (not mutated in place) so the ORM detects the change.
                doc_model.system_metadata = {
                    **(doc_model.system_metadata or {}),
                    **system_updates,
                }
                await session.commit()
                return True
        except Exception as e:
            logger.exception("Error updating document metadata: %s", e)
            return False

    async def delete_document(self, document_id: str, auth: AuthContext) -> bool:
        """Delete a document if ``auth`` has admin access.

        Returns:
            True if a row was deleted; False if access is denied, the document is
            missing, or an error occurred (logged).
        """
        try:
            if not await self.check_access(document_id, auth, "admin"):
                return False

            async with self.async_session() as session:
                result = await session.execute(
                    select(DocumentModel).where(DocumentModel.external_id == document_id)
                )
                doc_model = result.scalar_one_or_none()
                if doc_model is None:
                    return False
                await session.delete(doc_model)
                await session.commit()
                return True
        except Exception as e:
            logger.exception("Error deleting document: %s", e)
            return False

    async def find_authorized_and_filtered_documents(
        self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """Return external ids of documents visible to ``auth`` matching ``filters``.

        Returns:
            The matching ids, or [] on error (logged).
        """
        try:
            async with self.async_session() as session:
                query = select(DocumentModel.external_id).where(self._build_access_filter(auth))
                metadata_filter = self._build_metadata_filter(filters)
                if metadata_filter is not None:
                    query = query.where(metadata_filter)

                result = await session.execute(query)
                return list(result.scalars().all())
        except Exception as e:
            logger.exception("Error finding authorized documents: %s", e)
            return []

    async def check_access(
        self, document_id: str, auth: AuthContext, required_permission: str = "read"
    ) -> bool:
        """Check whether ``auth`` holds ``required_permission`` on a document.

        The owner (matching both "type" and "id") always has access. Otherwise
        ``required_permission`` ("read", "write" or "admin") selects the
        "readers", "writers" or "admins" list in ``access_control``.

        Returns:
            True if access is granted. False if it is not, if the permission name
            is unknown, if the document is missing, or on error (logged).
        """
        try:
            async with self.async_session() as session:
                result = await session.execute(
                    select(DocumentModel).where(DocumentModel.external_id == document_id)
                )
                doc_model = result.scalar_one_or_none()
                if doc_model is None:
                    return False

                # JSONB columns may be NULL; treat that as "no entries".
                owner = doc_model.owner or {}
                if owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id:
                    return True

                permission_map = {"read": "readers", "write": "writers", "admin": "admins"}
                permission_set = permission_map.get(required_permission)
                if permission_set is None:
                    return False

                access_control = doc_model.access_control or {}
                return auth.entity_id in (access_control.get(permission_set) or [])
        except Exception as e:
            logger.exception("Error checking document access: %s", e)
            return False

    @staticmethod
    def _to_document(doc_model: DocumentModel) -> Document:
        """Convert an ORM row to a ``Document``, mapping doc_metadata back to metadata."""
        return Document(
            external_id=doc_model.external_id,
            owner=doc_model.owner,
            content_type=doc_model.content_type,
            filename=doc_model.filename,
            metadata=doc_model.doc_metadata,
            storage_info=doc_model.storage_info,
            system_metadata=doc_model.system_metadata,
            additional_metadata=doc_model.additional_metadata,
            access_control=doc_model.access_control,
            chunk_ids=doc_model.chunk_ids,
        )

    def _build_access_filter(self, auth: AuthContext) -> TextClause:
        """Build a parenthesized SQL condition selecting rows ``auth`` may read.

        A row matches if the entity id owns it or appears in its readers, writers
        or admins list. DEVELOPER entities with an ``app_id`` also match rows whose
        ``app_access`` list contains that app id. Ids are bound parameters.
        """
        # NOTE(review): this matches the owner by id only, whereas check_access
        # also compares the owner "type" — confirm the difference is intended.
        conditions = [
            "owner->>'id' = :entity_id",
            "access_control->'readers' ? :entity_id",
            "access_control->'writers' ? :entity_id",
            "access_control->'admins' ? :entity_id",
        ]
        params: Dict[str, Any] = {"entity_id": auth.entity_id}

        if auth.entity_type == "DEVELOPER" and auth.app_id:
            # Add app-specific access for developers
            conditions.append("access_control->'app_access' ? :app_id")
            params["app_id"] = auth.app_id

        # The explicit parentheses keep the OR group intact when SQLAlchemy ANDs
        # this text clause with other WHERE criteria (it does not group text()).
        return text("(" + " OR ".join(conditions) + ")").bindparams(**params)

    def _build_metadata_filter(self, filters: Optional[Dict[str, Any]]) -> Optional[TextClause]:
        """Build an AND of exact-match conditions on ``doc_metadata`` keys.

        Returns:
            A parenthesized text clause with bound key/value parameters, or None
            when ``filters`` is empty or None.
        """
        if not filters:
            return None

        conditions = []
        params: Dict[str, Any] = {}
        for index, (key, value) in enumerate(filters.items()):
            key_param = f"meta_key_{index}"
            value_param = f"meta_value_{index}"
            conditions.append(f"doc_metadata->>:{key_param} = :{value_param}")
            params[key_param] = str(key)
            # ``->>`` renders JSON booleans as 'true'/'false', so str(True) ('True')
            # would never match; other values are compared via their str() form.
            if isinstance(value, bool):
                params[value_param] = "true" if value else "false"
            else:
                params[value_param] = str(value)

        return text("(" + " AND ".join(conditions) + ")").bindparams(**params)