mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
316 lines
12 KiB
Python
316 lines
12 KiB
Python
import logging
from datetime import datetime, UTC
from typing import List, Optional, Dict, Any

from sqlalchemy import Column, String, Index, select, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.sql.elements import TextClause

from .base_database import BaseDatabase
from ..models.documents import Document
from ..models.auth import AuthContext
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Declarative base class that the ORM models defined in this module inherit
# from; its metadata is what PostgresDatabase.initialize() uses to create tables.
Base = declarative_base()
|
|
|
|
|
|
class DocumentModel(Base):
    """ORM mapping for the ``documents`` table (one row per document)."""

    __tablename__ = "documents"

    # Primary key: the document's external identifier.
    external_id = Column(String, primary_key=True)

    # JSONB owner descriptor; no column-level default is declared for it.
    owner = Column(JSONB)

    # Plain string attributes; only ``filename`` is explicitly nullable.
    content_type = Column(String)
    filename = Column(String, nullable=True)

    # JSONB payloads. Each gets a fresh empty container as its insert-time
    # default when no value is supplied.
    doc_metadata = Column(JSONB, default=dict)
    storage_info = Column(JSONB, default=dict)
    system_metadata = Column(JSONB, default=dict)
    additional_metadata = Column(JSONB, default=dict)
    access_control = Column(JSONB, default=dict)
    chunk_ids = Column(JSONB, default=list)

    # GIN indexes over the JSONB columns, declared as (index name, column)
    # pairs. A generator is used so no loop variable leaks into the class
    # namespace.
    __table_args__ = tuple(
        Index(index_name, column_name, postgresql_using="gin")
        for index_name, column_name in (
            ("idx_owner_id", "owner"),
            ("idx_access_control", "access_control"),
            ("idx_system_metadata", "system_metadata"),
        )
    )
|
|
|
|
|
|
def _serialize_datetime(obj: Any) -> Any:
    """Recursively replace ``datetime`` objects with ISO 8601 strings.

    Dicts, lists and tuples are walked recursively and rebuilt; any other
    value is returned unchanged. Tuples are handled so that a datetime nested
    inside one is not left unserialized.

    Args:
        obj: Arbitrary value, typically a (nested) dict about to be stored
            in a JSON column.

    Returns:
        A structure of the same shape in which every ``datetime`` has been
        replaced by ``datetime.isoformat()``. Dict keys are left as they are.
        Tuples (including named tuples) come back as plain tuples.
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, dict):
        return {key: _serialize_datetime(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_serialize_datetime(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(_serialize_datetime(item) for item in obj)
    return obj
|
|
|
|
|
|
class PostgresDatabase(BaseDatabase):
    """PostgreSQL implementation for document metadata storage.

    Every public coroutine catches unexpected exceptions, logs them and
    returns a falsy result (``False``, ``None`` or ``[]``) instead of raising.
    """

    # Maps a permission name to the ``access_control`` key listing the
    # entity ids that hold it.
    _PERMISSION_FIELDS = {"read": "readers", "write": "writers", "admin": "admins"}

    def __init__(
        self,
        uri: str,
    ):
        """Initialize PostgreSQL connection for document storage.

        Args:
            uri: Database URI passed to ``create_async_engine``.
        """
        self.engine = create_async_engine(uri)
        self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)

    async def initialize(self) -> bool:
        """Create the tables and indexes declared on ``Base``.

        Returns:
            ``True`` on success, ``False`` if creation failed.
        """
        try:
            async with self.engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)
            logger.info("PostgreSQL tables and indexes created successfully")
            return True
        except Exception as e:
            logger.error(f"Error creating PostgreSQL tables and indexes: {str(e)}")
            return False

    async def store_document(self, document: Document) -> bool:
        """Insert a new row holding the document's metadata.

        ``system_metadata.created_at`` and ``updated_at`` are both set to the
        same current UTC timestamp before the row is written.

        Returns:
            ``True`` if the row was committed, ``False`` on any error.
        """
        try:
            doc_dict = document.model_dump()

            # The model exposes "metadata" but the column is "doc_metadata"
            # ("metadata" is reserved on declarative classes).
            if "metadata" in doc_dict:
                doc_dict["doc_metadata"] = doc_dict.pop("metadata")

            # Tolerate a missing or None system_metadata, and use a single
            # timestamp so both fields agree.
            now = datetime.now(UTC)
            system_metadata = doc_dict.get("system_metadata") or {}
            system_metadata["created_at"] = now
            system_metadata["updated_at"] = now
            doc_dict["system_metadata"] = system_metadata

            # JSONB columns cannot hold datetime objects directly.
            doc_dict = _serialize_datetime(doc_dict)

            async with self.async_session() as session:
                session.add(DocumentModel(**doc_dict))
                await session.commit()
            return True

        except Exception as e:
            logger.error(f"Error storing document metadata: {str(e)}")
            return False

    async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]:
        """Retrieve document metadata by ID if the caller has access.

        Returns:
            The matching ``Document``, or ``None`` when it does not exist, is
            not visible to ``auth``, or an error occurred.
        """
        try:
            query = (
                select(DocumentModel)
                .where(DocumentModel.external_id == document_id)
                .where(self._build_access_filter(auth))
            )
            async with self.async_session() as session:
                result = await session.execute(query)
                doc_model = result.scalar_one_or_none()
                if doc_model is None:
                    return None
                return self._to_document(doc_model)

        except Exception as e:
            logger.error(f"Error retrieving document metadata: {str(e)}")
            return None

    async def get_documents(
        self,
        auth: AuthContext,
        skip: int = 0,
        limit: int = 100,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Document]:
        """List documents the caller has access to.

        Args:
            auth: Caller identity used to build the access filter.
            skip: Number of matching rows to skip.
            limit: Maximum number of rows to return.
            filters: Optional ``doc_metadata`` key/value equality filters,
                all of which must match.

        Returns:
            The matching documents ordered by ``external_id``, or ``[]`` on
            error.
        """
        try:
            query = select(DocumentModel).where(self._build_access_filter(auth))
            metadata_filter = self._build_metadata_filter(filters)
            if metadata_filter is not None:
                query = query.where(metadata_filter)

            # OFFSET/LIMIT without ORDER BY gives no guarantee about which
            # rows land on which page; order by the primary key to make
            # pagination stable.
            query = query.order_by(DocumentModel.external_id).offset(skip).limit(limit)

            async with self.async_session() as session:
                result = await session.execute(query)
                return [self._to_document(doc) for doc in result.scalars().all()]

        except Exception as e:
            logger.error(f"Error listing documents: {str(e)}")
            return []

    async def update_document(
        self, document_id: str, updates: Dict[str, Any], auth: AuthContext
    ) -> bool:
        """Update document metadata if the caller has write access.

        ``updates`` is not modified. A ``"metadata"`` key is written to the
        ``doc_metadata`` column. ``"system_metadata"`` is merged into the
        stored value (so keys such as ``created_at`` survive) and its
        ``updated_at`` is refreshed.

        Returns:
            ``True`` if the row was updated, ``False`` when access is denied,
            the document does not exist, or an error occurred.
        """
        try:
            if not await self.check_access(document_id, auth, "write"):
                return False

            # Shallow copy: never mutate the caller's dict.
            changes = dict(updates)

            # Same renaming as store_document; setting "metadata" on the
            # model would shadow the declarative attribute and not persist.
            if "metadata" in changes:
                changes["doc_metadata"] = changes.pop("metadata")

            system_updates = dict(changes.pop("system_metadata", None) or {})
            system_updates["updated_at"] = datetime.now(UTC)

            # JSONB columns cannot hold datetime objects directly.
            changes = _serialize_datetime(changes)
            system_updates = _serialize_datetime(system_updates)

            async with self.async_session() as session:
                result = await session.execute(
                    select(DocumentModel).where(DocumentModel.external_id == document_id)
                )
                doc_model = result.scalar_one_or_none()
                if doc_model is None:
                    return False

                for key, value in changes.items():
                    setattr(doc_model, key, value)

                # Assign a new dict (rather than mutating in place) so the
                # change to the JSONB attribute is detected.
                doc_model.system_metadata = {
                    **(doc_model.system_metadata or {}),
                    **system_updates,
                }
                await session.commit()
                return True

        except Exception as e:
            logger.error(f"Error updating document metadata: {str(e)}")
            return False

    async def delete_document(self, document_id: str, auth: AuthContext) -> bool:
        """Delete the document if the caller has admin access.

        Returns:
            ``True`` if the row was deleted, ``False`` when access is denied,
            the document does not exist, or an error occurred.
        """
        try:
            if not await self.check_access(document_id, auth, "admin"):
                return False

            async with self.async_session() as session:
                result = await session.execute(
                    select(DocumentModel).where(DocumentModel.external_id == document_id)
                )
                doc_model = result.scalar_one_or_none()
                if doc_model is None:
                    return False

                await session.delete(doc_model)
                await session.commit()
                return True

        except Exception as e:
            logger.error(f"Error deleting document: {str(e)}")
            return False

    async def find_authorized_and_filtered_documents(
        self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """Find IDs of documents matching the filters and access permissions.

        Returns:
            The ``external_id`` of every matching document, or ``[]`` on
            error.
        """
        try:
            query = select(DocumentModel.external_id).where(self._build_access_filter(auth))
            metadata_filter = self._build_metadata_filter(filters)
            if metadata_filter is not None:
                query = query.where(metadata_filter)

            async with self.async_session() as session:
                result = await session.execute(query)
                return [row[0] for row in result.all()]

        except Exception as e:
            logger.error(f"Error finding authorized documents: {str(e)}")
            return []

    async def check_access(
        self, document_id: str, auth: AuthContext, required_permission: str = "read"
    ) -> bool:
        """Check whether the caller holds a permission on a document.

        The owner (matching both ``type`` and ``id``) is granted every
        permission. Otherwise the caller's ``entity_id`` must appear in the
        ``access_control`` list for the requested permission.

        Args:
            document_id: ``external_id`` of the document.
            auth: Caller identity.
            required_permission: ``"read"``, ``"write"`` or ``"admin"``.

        Returns:
            ``True`` if access is granted; ``False`` if it is not, the
            document is missing, the permission name is unknown, or an error
            occurred.
        """
        try:
            async with self.async_session() as session:
                result = await session.execute(
                    select(DocumentModel).where(DocumentModel.external_id == document_id)
                )
                doc_model = result.scalar_one_or_none()

                if doc_model is None:
                    return False

                # Rows with a NULL owner / access_control are treated as
                # having no owner and no grants.
                owner = doc_model.owner or {}
                if owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id:
                    return True

                permission_field = self._PERMISSION_FIELDS.get(required_permission)
                if not permission_field:
                    return False

                access_control = doc_model.access_control or {}
                return auth.entity_id in access_control.get(permission_field, [])

        except Exception as e:
            logger.error(f"Error checking document access: {str(e)}")
            return False

    @staticmethod
    def _to_document(doc_model: DocumentModel) -> Document:
        """Convert a database row into a ``Document``.

        The ``doc_metadata`` column is mapped back to the ``metadata`` field.
        """
        return Document(
            external_id=doc_model.external_id,
            owner=doc_model.owner,
            content_type=doc_model.content_type,
            filename=doc_model.filename,
            metadata=doc_model.doc_metadata,
            storage_info=doc_model.storage_info,
            system_metadata=doc_model.system_metadata,
            additional_metadata=doc_model.additional_metadata,
            access_control=doc_model.access_control,
            chunk_ids=doc_model.chunk_ids,
        )

    def _build_access_filter(self, auth: AuthContext) -> TextClause:
        """Build the access-control predicate for ``auth``.

        A row matches when the caller is its owner (by id) or is listed under
        ``readers``, ``writers`` or ``admins``. For a ``DEVELOPER`` with an
        ``app_id``, rows whose ``app_access`` contains that app also match.

        Identity values are passed as bound parameters, never interpolated
        into the SQL text, so they cannot alter the statement.

        Returns:
            A parenthesized ``TextClause`` ready to pass to ``.where()``.
        """
        clauses = [
            "owner->>'id' = :auth_entity_id",
            "access_control->'readers' ? :auth_entity_id",
            "access_control->'writers' ? :auth_entity_id",
            "access_control->'admins' ? :auth_entity_id",
        ]
        # str() mirrors the textual comparison the JSON operators perform.
        params = {"auth_entity_id": str(auth.entity_id)}

        if auth.entity_type == "DEVELOPER" and auth.app_id:
            # Add app-specific access for developers
            clauses.append("access_control->'app_access' ? :auth_app_id")
            params["auth_app_id"] = str(auth.app_id)

        return text(f"({' OR '.join(clauses)})").bindparams(**params)

    def _build_metadata_filter(
        self, filters: Optional[Dict[str, Any]]
    ) -> Optional[TextClause]:
        """Build the ``doc_metadata`` equality predicate for ``filters``.

        Each key/value pair becomes ``doc_metadata->>key = value`` (compared
        as text); all pairs must match. Keys and values are passed as bound
        parameters, never interpolated into the SQL text.

        Returns:
            A parenthesized ``TextClause``, or ``None`` when there is nothing
            to filter on.
        """
        if not filters:
            return None

        conditions = []
        params = {}
        for position, (key, value) in enumerate(filters.items()):
            key_param = f"meta_key_{position}"
            value_param = f"meta_value_{position}"
            # The explicit CAST pins the key parameter to text so the
            # text-keyed form of ->> is selected.
            conditions.append(
                f"doc_metadata ->> CAST(:{key_param} AS TEXT) = :{value_param}"
            )
            params[key_param] = str(key)
            params[value_param] = str(value)

        return text(f"({' AND '.join(conditions)})").bindparams(**params)
|