morphik-core/core/database/postgres_database.py

import json
from typing import List, Optional, Dict, Any
from datetime import datetime, UTC
import logging
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import Column, String, Index, select, text
from sqlalchemy.dialects.postgresql import JSONB
from .base_database import BaseDatabase
from ..models.documents import Document
from ..models.auth import AuthContext
from ..models.graph import Graph, Entity, Relationship
logger = logging.getLogger(__name__)

Base = declarative_base()


class DocumentModel(Base):
    """SQLAlchemy model for document metadata."""

    __tablename__ = "documents"

    external_id = Column(String, primary_key=True)
    owner = Column(JSONB)
    content_type = Column(String)
    filename = Column(String, nullable=True)
    doc_metadata = Column(JSONB, default=dict)
    storage_info = Column(JSONB, default=dict)
    system_metadata = Column(JSONB, default=dict)
    additional_metadata = Column(JSONB, default=dict)
    access_control = Column(JSONB, default=dict)
    chunk_ids = Column(JSONB, default=list)
    storage_files = Column(JSONB, default=list)

    # Create indexes
    __table_args__ = (
        Index("idx_owner_id", "owner", postgresql_using="gin"),
        Index("idx_access_control", "access_control", postgresql_using="gin"),
        Index("idx_system_metadata", "system_metadata", postgresql_using="gin"),
    )


class GraphModel(Base):
    """SQLAlchemy model for graph data."""

    __tablename__ = "graphs"

    id = Column(String, primary_key=True)
    name = Column(String, unique=True, index=True)
    entities = Column(JSONB, default=list)
    relationships = Column(JSONB, default=list)
    graph_metadata = Column(JSONB, default=dict)  # Renamed from 'metadata' to avoid conflict
    document_ids = Column(JSONB, default=list)
    filters = Column(JSONB, nullable=True)
    created_at = Column(String)  # ISO format string
    updated_at = Column(String)  # ISO format string
    owner = Column(JSONB)
    access_control = Column(JSONB, default=dict)

    # Create indexes
    __table_args__ = (
        Index("idx_graph_name", "name"),
        Index("idx_graph_owner", "owner", postgresql_using="gin"),
        Index("idx_graph_access_control", "access_control", postgresql_using="gin"),
    )


def _serialize_datetime(obj: Any) -> Any:
    """Helper function to serialize datetime objects to ISO format strings."""
    if isinstance(obj, datetime):
        return obj.isoformat()
    elif isinstance(obj, dict):
        return {key: _serialize_datetime(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [_serialize_datetime(item) for item in obj]
    return obj
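
# Illustrative example (not part of the original module): the helper walks nested
# dicts and lists and converts every datetime it finds into an ISO-8601 string, e.g.
#   _serialize_datetime({"system_metadata": {"created_at": datetime(2025, 3, 17, tzinfo=UTC)}})
#   -> {"system_metadata": {"created_at": "2025-03-17T00:00:00+00:00"}}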


class PostgresDatabase(BaseDatabase):
    """PostgreSQL implementation for document metadata storage."""

    def __init__(
        self,
        uri: str,
    ):
        """Initialize PostgreSQL connection for document storage."""
        self.engine = create_async_engine(uri)
        self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
        self._initialized = False

    async def initialize(self):
        """Initialize database tables and indexes."""
        if self._initialized:
            return True

        try:
            logger.info("Initializing PostgreSQL database tables and indexes...")
            # Create ORM models
            async with self.engine.begin() as conn:
                # Explicitly create all tables with checkfirst=True to avoid errors if tables already exist
                await conn.run_sync(lambda conn: Base.metadata.create_all(conn, checkfirst=True))

                # No need to manually create graphs table again since SQLAlchemy does it
                logger.info("Created database tables successfully")

                # Create caches table if it doesn't exist (kept as direct SQL for backward compatibility)
                await conn.execute(
                    text(
                        """
                        CREATE TABLE IF NOT EXISTS caches (
                            name TEXT PRIMARY KEY,
                            metadata JSONB NOT NULL,
                            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                            updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
                        )
                        """
                    )
                )

                # Check if storage_files column exists
                result = await conn.execute(
                    text(
                        """
                        SELECT column_name
                        FROM information_schema.columns
                        WHERE table_name = 'documents' AND column_name = 'storage_files'
                        """
                    )
                )
                if not result.first():
                    # Add storage_files column to documents table
                    await conn.execute(
                        text(
                            """
                            ALTER TABLE documents
                            ADD COLUMN IF NOT EXISTS storage_files JSONB DEFAULT '[]'::jsonb
                            """
                        )
                    )
                    logger.info("Added storage_files column to documents table")

            logger.info("PostgreSQL tables and indexes created successfully")
            self._initialized = True
            return True
        except Exception as e:
            logger.error(f"Error creating PostgreSQL tables and indexes: {str(e)}")
            return False

    async def store_document(self, document: Document) -> bool:
        """Store document metadata."""
        try:
            doc_dict = document.model_dump()

            # Rename metadata to doc_metadata
            if "metadata" in doc_dict:
                doc_dict["doc_metadata"] = doc_dict.pop("metadata")
            doc_dict["doc_metadata"]["external_id"] = doc_dict["external_id"]

            # Ensure system metadata
            if "system_metadata" not in doc_dict:
                doc_dict["system_metadata"] = {}
            doc_dict["system_metadata"]["created_at"] = datetime.now(UTC)
            doc_dict["system_metadata"]["updated_at"] = datetime.now(UTC)

            # Convert storage_files to the expected format for storage. model_dump() may
            # already have turned nested models into plain dicts, so only call
            # model_dump() on entries that still provide it.
            if "storage_files" in doc_dict and doc_dict["storage_files"]:
                doc_dict["storage_files"] = [
                    file.model_dump() if hasattr(file, "model_dump") else file
                    for file in doc_dict["storage_files"]
                ]

            # Serialize datetime objects to ISO format strings
            doc_dict = _serialize_datetime(doc_dict)

            async with self.async_session() as session:
                doc_model = DocumentModel(**doc_dict)
                session.add(doc_model)
                await session.commit()
            return True
        except Exception as e:
            logger.error(f"Error storing document metadata: {str(e)}")
            return False

    async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]:
        """Retrieve document metadata by ID if user has access."""
        try:
            async with self.async_session() as session:
                # Build access filter
                access_filter = self._build_access_filter(auth)

                # Query document
                query = (
                    select(DocumentModel)
                    .where(DocumentModel.external_id == document_id)
                    .where(text(f"({access_filter})"))
                )

                result = await session.execute(query)
                doc_model = result.scalar_one_or_none()

                if doc_model:
                    # Convert doc_metadata back to metadata
                    doc_dict = {
                        "external_id": doc_model.external_id,
                        "owner": doc_model.owner,
                        "content_type": doc_model.content_type,
                        "filename": doc_model.filename,
                        "metadata": doc_model.doc_metadata,
                        "storage_info": doc_model.storage_info,
                        "system_metadata": doc_model.system_metadata,
                        "additional_metadata": doc_model.additional_metadata,
                        "access_control": doc_model.access_control,
                        "chunk_ids": doc_model.chunk_ids,
                        "storage_files": doc_model.storage_files or [],
                    }
                    return Document(**doc_dict)
                return None
        except Exception as e:
            logger.error(f"Error retrieving document metadata: {str(e)}")
            return None

    async def get_document_by_filename(self, filename: str, auth: AuthContext) -> Optional[Document]:
        """Retrieve document metadata by filename if user has access.

        If multiple documents have the same filename, returns the most recently updated one.
        """
        try:
            async with self.async_session() as session:
                # Build access filter
                access_filter = self._build_access_filter(auth)

                # Query document, ordering by updated_at in system_metadata so the most
                # recently updated match is returned; limit to a single row so a shared
                # filename cannot raise MultipleResultsFound
                query = (
                    select(DocumentModel)
                    .where(DocumentModel.filename == filename)
                    .where(text(f"({access_filter})"))
                    .order_by(text("system_metadata->>'updated_at' DESC"))
                    .limit(1)
                )

                result = await session.execute(query)
                doc_model = result.scalar_one_or_none()

                if doc_model:
                    # Convert doc_metadata back to metadata
                    doc_dict = {
                        "external_id": doc_model.external_id,
                        "owner": doc_model.owner,
                        "content_type": doc_model.content_type,
                        "filename": doc_model.filename,
                        "metadata": doc_model.doc_metadata,
                        "storage_info": doc_model.storage_info,
                        "system_metadata": doc_model.system_metadata,
                        "additional_metadata": doc_model.additional_metadata,
                        "access_control": doc_model.access_control,
                        "chunk_ids": doc_model.chunk_ids,
                        "storage_files": doc_model.storage_files or [],
                    }
                    return Document(**doc_dict)
                return None
        except Exception as e:
            logger.error(f"Error retrieving document metadata by filename: {str(e)}")
            return None

    async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext) -> List[Document]:
        """
        Retrieve multiple documents by their IDs in a single batch operation.

        Only returns documents the user has access to.

        Args:
            document_ids: List of document IDs to retrieve
            auth: Authentication context

        Returns:
            List of Document objects that were found and user has access to
        """
        try:
            if not document_ids:
                return []

            async with self.async_session() as session:
                # Build access filter
                access_filter = self._build_access_filter(auth)

                # Query documents with both document IDs and access check in a single query
                query = (
                    select(DocumentModel)
                    .where(DocumentModel.external_id.in_(document_ids))
                    .where(text(f"({access_filter})"))
                )

                logger.info(f"Batch retrieving {len(document_ids)} documents with a single query")

                # Execute batch query
                result = await session.execute(query)
                doc_models = result.scalars().all()

                documents = []
                for doc_model in doc_models:
                    # Convert doc_metadata back to metadata
                    doc_dict = {
                        "external_id": doc_model.external_id,
                        "owner": doc_model.owner,
                        "content_type": doc_model.content_type,
                        "filename": doc_model.filename,
                        "metadata": doc_model.doc_metadata,
                        "storage_info": doc_model.storage_info,
                        "system_metadata": doc_model.system_metadata,
                        "additional_metadata": doc_model.additional_metadata,
                        "access_control": doc_model.access_control,
                        "chunk_ids": doc_model.chunk_ids,
                        "storage_files": doc_model.storage_files or [],
                    }
                    documents.append(Document(**doc_dict))

                logger.info(f"Found {len(documents)} documents in batch retrieval")
                return documents
        except Exception as e:
            logger.error(f"Error batch retrieving documents: {str(e)}")
            return []

    async def get_documents(
        self,
        auth: AuthContext,
        skip: int = 0,
        limit: int = 10000,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Document]:
        """List documents the user has access to."""
        try:
            async with self.async_session() as session:
                # Build query
                access_filter = self._build_access_filter(auth)
                metadata_filter = self._build_metadata_filter(filters)

                query = select(DocumentModel).where(text(f"({access_filter})"))
                if metadata_filter:
                    query = query.where(text(metadata_filter))

                query = query.offset(skip).limit(limit)

                result = await session.execute(query)
                doc_models = result.scalars().all()

                return [
                    Document(
                        external_id=doc.external_id,
                        owner=doc.owner,
                        content_type=doc.content_type,
                        filename=doc.filename,
                        metadata=doc.doc_metadata,
                        storage_info=doc.storage_info,
                        system_metadata=doc.system_metadata,
                        additional_metadata=doc.additional_metadata,
                        access_control=doc.access_control,
                        chunk_ids=doc.chunk_ids,
                        storage_files=doc.storage_files or [],
                    )
                    for doc in doc_models
                ]
        except Exception as e:
            logger.error(f"Error listing documents: {str(e)}")
            return []

    async def update_document(
        self, document_id: str, updates: Dict[str, Any], auth: AuthContext
    ) -> bool:
        """Update document metadata if user has write access."""
        try:
            if not await self.check_access(document_id, auth, "write"):
                return False

            # Update system metadata
            updates.setdefault("system_metadata", {})
            updates["system_metadata"]["updated_at"] = datetime.now(UTC)

            # Serialize datetime objects to ISO format strings
            updates = _serialize_datetime(updates)

            async with self.async_session() as session:
                result = await session.execute(
                    select(DocumentModel).where(DocumentModel.external_id == document_id)
                )
                doc_model = result.scalar_one_or_none()

                if doc_model:
                    for key, value in updates.items():
                        setattr(doc_model, key, value)
                    await session.commit()
                    return True
                return False
        except Exception as e:
            logger.error(f"Error updating document metadata: {str(e)}")
            return False

    async def delete_document(self, document_id: str, auth: AuthContext) -> bool:
        """Delete document if user has admin access."""
        try:
            if not await self.check_access(document_id, auth, "admin"):
                return False

            async with self.async_session() as session:
                result = await session.execute(
                    select(DocumentModel).where(DocumentModel.external_id == document_id)
                )
                doc_model = result.scalar_one_or_none()

                if doc_model:
                    await session.delete(doc_model)
                    await session.commit()
                    return True
                return False
        except Exception as e:
            logger.error(f"Error deleting document: {str(e)}")
            return False

    async def find_authorized_and_filtered_documents(
        self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """Find document IDs matching filters and access permissions."""
        try:
            async with self.async_session() as session:
                # Build query
                access_filter = self._build_access_filter(auth)
                metadata_filter = self._build_metadata_filter(filters)

                logger.debug(f"Access filter: {access_filter}")
                logger.debug(f"Metadata filter: {metadata_filter}")
                logger.debug(f"Original filters: {filters}")

                query = select(DocumentModel.external_id).where(text(f"({access_filter})"))
                if metadata_filter:
                    query = query.where(text(metadata_filter))

                logger.debug(f"Final query: {query}")

                result = await session.execute(query)
                doc_ids = [row[0] for row in result.all()]
                logger.debug(f"Found document IDs: {doc_ids}")
                return doc_ids
        except Exception as e:
            logger.error(f"Error finding authorized documents: {str(e)}")
            return []

    async def check_access(
        self, document_id: str, auth: AuthContext, required_permission: str = "read"
    ) -> bool:
        """Check if user has required permission for document."""
        try:
            async with self.async_session() as session:
                result = await session.execute(
                    select(DocumentModel).where(DocumentModel.external_id == document_id)
                )
                doc_model = result.scalar_one_or_none()

                if not doc_model:
                    return False

                # Check owner access
                owner = doc_model.owner
                if owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id:
                    return True

                # Check permission-specific access
                access_control = doc_model.access_control
                permission_map = {"read": "readers", "write": "writers", "admin": "admins"}
                permission_set = permission_map.get(required_permission)

                if not permission_set:
                    return False

                return auth.entity_id in access_control.get(permission_set, [])
        except Exception as e:
            logger.error(f"Error checking document access: {str(e)}")
            return False

    def _build_access_filter(self, auth: AuthContext) -> str:
        """Build PostgreSQL filter for access control."""
        filters = [
            f"owner->>'id' = '{auth.entity_id}'",
            f"access_control->'readers' ? '{auth.entity_id}'",
            f"access_control->'writers' ? '{auth.entity_id}'",
            f"access_control->'admins' ? '{auth.entity_id}'",
        ]

        if auth.entity_type == "DEVELOPER" and auth.app_id:
            # Add app-specific access for developers
            filters.append(f"access_control->'app_access' ? '{auth.app_id}'")

        return " OR ".join(filters)

    def _build_metadata_filter(self, filters: Dict[str, Any]) -> str:
        """Build PostgreSQL filter for metadata."""
        if not filters:
            return ""

        filter_conditions = []
        for key, value in filters.items():
            # Convert boolean values to string 'true' or 'false'
            if isinstance(value, bool):
                value = str(value).lower()

            # Use proper SQL escaping for string values
            if isinstance(value, str):
                # Replace single quotes with double single quotes to escape them
                value = value.replace("'", "''")

            filter_conditions.append(f"doc_metadata->>'{key}' = '{value}'")

        return " AND ".join(filter_conditions)

    async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
        """Store metadata for a cache in PostgreSQL.

        Args:
            name: Name of the cache
            metadata: Cache metadata including model info and storage location

        Returns:
            bool: Whether the operation was successful
        """
        try:
            async with self.async_session() as session:
                await session.execute(
                    text(
                        """
                        INSERT INTO caches (name, metadata, updated_at)
                        VALUES (:name, :metadata, CURRENT_TIMESTAMP)
                        ON CONFLICT (name)
                        DO UPDATE SET
                            metadata = :metadata,
                            updated_at = CURRENT_TIMESTAMP
                        """
                    ),
                    {"name": name, "metadata": json.dumps(metadata)},
                )
                await session.commit()
            return True
        except Exception as e:
            logger.error(f"Failed to store cache metadata: {e}")
            return False

    async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]:
        """Get metadata for a cache from PostgreSQL.

        Args:
            name: Name of the cache

        Returns:
            Optional[Dict[str, Any]]: Cache metadata if found, None otherwise
        """
        try:
            async with self.async_session() as session:
                result = await session.execute(
                    text("SELECT metadata FROM caches WHERE name = :name"), {"name": name}
                )
                row = result.first()
                return row[0] if row else None
        except Exception as e:
            logger.error(f"Failed to get cache metadata: {e}")
            return None

    async def store_graph(self, graph: Graph) -> bool:
        """Store a graph in PostgreSQL.

        This method stores the graph metadata, entities, and relationships
        in a PostgreSQL table.

        Args:
            graph: Graph to store

        Returns:
            bool: Whether the operation was successful
        """
        # Ensure database is initialized
        if not self._initialized:
            await self.initialize()

        try:
            # First serialize the graph model to dict
            graph_dict = graph.model_dump()

            # Change 'metadata' to 'graph_metadata' to match our model
            if "metadata" in graph_dict:
                graph_dict["graph_metadata"] = graph_dict.pop("metadata")

            # Serialize datetime objects to ISO format strings
            graph_dict = _serialize_datetime(graph_dict)

            # Store the graph metadata in PostgreSQL
            async with self.async_session() as session:
                # Store graph metadata in our table
                graph_model = GraphModel(**graph_dict)
                session.add(graph_model)
                await session.commit()
                logger.info(
                    f"Stored graph '{graph.name}' with {len(graph.entities)} entities "
                    f"and {len(graph.relationships)} relationships"
                )

            return True
        except Exception as e:
            logger.error(f"Error storing graph: {str(e)}")
            return False

    async def get_graph(self, name: str, auth: AuthContext) -> Optional[Graph]:
        """Get a graph by name.

        Args:
            name: Name of the graph
            auth: Authentication context

        Returns:
            Optional[Graph]: Graph if found and accessible, None otherwise
        """
        # Ensure database is initialized
        if not self._initialized:
            await self.initialize()

        try:
            async with self.async_session() as session:
                # Build access filter
                access_filter = self._build_access_filter(auth)

                # Query graph
                query = (
                    select(GraphModel)
                    .where(GraphModel.name == name)
                    .where(text(f"({access_filter})"))
                )

                result = await session.execute(query)
                graph_model = result.scalar_one_or_none()

                if graph_model:
                    # Convert to Graph model
                    graph_dict = {
                        "id": graph_model.id,
                        "name": graph_model.name,
                        "entities": graph_model.entities,
                        "relationships": graph_model.relationships,
                        "metadata": graph_model.graph_metadata,  # Reference the renamed column
                        "document_ids": graph_model.document_ids,
                        "filters": graph_model.filters,
                        "created_at": graph_model.created_at,
                        "updated_at": graph_model.updated_at,
                        "owner": graph_model.owner,
                        "access_control": graph_model.access_control,
                    }
                    return Graph(**graph_dict)
                return None
        except Exception as e:
            logger.error(f"Error retrieving graph: {str(e)}")
            return None

    async def list_graphs(self, auth: AuthContext) -> List[Graph]:
        """List all graphs the user has access to.

        Args:
            auth: Authentication context

        Returns:
            List[Graph]: List of graphs
        """
        # Ensure database is initialized
        if not self._initialized:
            await self.initialize()

        try:
            async with self.async_session() as session:
                # Build access filter
                access_filter = self._build_access_filter(auth)

                # Query graphs
                query = select(GraphModel).where(text(f"({access_filter})"))

                result = await session.execute(query)
                graph_models = result.scalars().all()

                return [
                    Graph(
                        id=graph.id,
                        name=graph.name,
                        entities=graph.entities,
                        relationships=graph.relationships,
                        metadata=graph.graph_metadata,  # Reference the renamed column
                        document_ids=graph.document_ids,
                        filters=graph.filters,
                        created_at=graph.created_at,
                        updated_at=graph.updated_at,
                        owner=graph.owner,
                        access_control=graph.access_control,
                    )
                    for graph in graph_models
                ]
        except Exception as e:
            logger.error(f"Error listing graphs: {str(e)}")
            return []
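

# A minimal usage sketch (not part of the original module), assuming a local
# PostgreSQL instance reachable at the example DSN below. It exercises only the
# cache-metadata helpers, which require no AuthContext:
#
#     import asyncio
#
#     async def _demo():
#         db = PostgresDatabase("postgresql+asyncpg://user:pass@localhost:5432/morphik")
#         await db.initialize()
#         await db.store_cache_metadata("demo-cache", {"model": "example", "storage": "local"})
#         print(await db.get_cache_metadata("demo-cache"))
#
#     asyncio.run(_demo())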