morphik-core/core/database/postgres_database.py

1544 lines
63 KiB
Python

import json
import logging
from datetime import UTC, datetime
from typing import Any, Dict, List, Optional
from sqlalchemy import Column, Index, String, select, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from ..models.auth import AuthContext
from ..models.documents import Document, StorageFileInfo
from ..models.folders import Folder
from ..models.graph import Graph
from .base_database import BaseDatabase
logger = logging.getLogger(__name__)
Base = declarative_base()
class DocumentModel(Base):
    """SQLAlchemy model for document metadata.

    One row per document, keyed by ``external_id``. Apart from ``content_type``
    and ``filename``, every non-key column is JSONB. The ``doc_metadata`` column
    holds what the ``Document`` pydantic model calls ``metadata``.
    """

    __tablename__ = "documents"

    external_id = Column(String, primary_key=True)
    # Owner identity; this module reads it as a dict with "id" and "type" keys.
    owner = Column(JSONB)
    content_type = Column(String)
    filename = Column(String, nullable=True)
    # Document.metadata is stored here ("metadata" is a reserved attribute name on declarative models).
    doc_metadata = Column(JSONB, default=dict)
    storage_info = Column(JSONB, default=dict)
    # Server-managed fields, e.g. created_at / updated_at / folder_name / end_user_id.
    system_metadata = Column(JSONB, default=dict)
    additional_metadata = Column(JSONB, default=dict)
    # Sharing info: entity IDs under "readers" / "writers" / "admins", and
    # optionally "app_access" and "user_id" (see the access-filter builder).
    access_control = Column(JSONB, default=dict)
    chunk_ids = Column(JSONB, default=list)
    # List of StorageFileInfo dicts; added to pre-existing databases by initialize().
    storage_files = Column(JSONB, default=list)
    # Create indexes
    # NOTE(review): a GIN index on a whole JSONB column serves operators applied to
    # that column (e.g. `access_control ? 'x'`), but the visible filters use
    # `owner->>'id' = ...` and `access_control->'readers' ? ...`, which it may not
    # accelerate; expression indexes may be what is needed — confirm with EXPLAIN.
    __table_args__ = (
        Index("idx_owner_id", "owner", postgresql_using="gin"),
        Index("idx_access_control", "access_control", postgresql_using="gin"),
        Index("idx_system_metadata", "system_metadata", postgresql_using="gin"),
    )
class GraphModel(Base):
    """SQLAlchemy model for graph data.

    One row per knowledge graph. ``name`` is unique per owner (enforced by the
    ``idx_graph_owner_name`` unique index on ``name`` and ``owner->>'id'``), not
    globally. Timestamps are stored as ISO-8601 strings.
    """

    __tablename__ = "graphs"

    id = Column(String, primary_key=True)
    # Indexed via idx_graph_name in __table_args__. Do not also pass index=True:
    # SQLAlchemy would then create a second, identical index (ix_graphs_name)
    # that only adds write and storage overhead.
    name = Column(String)
    entities = Column(JSONB, default=list)
    relationships = Column(JSONB, default=list)
    # Graph.metadata is stored here ("metadata" is reserved on declarative models).
    graph_metadata = Column(JSONB, default=dict)
    # Scoping info such as folder_name and end_user_id.
    system_metadata = Column(JSONB, default=dict)
    document_ids = Column(JSONB, default=list)
    filters = Column(JSONB, nullable=True)
    created_at = Column(String)  # ISO format string
    updated_at = Column(String)  # ISO format string
    owner = Column(JSONB)
    access_control = Column(JSONB, default=dict)

    __table_args__ = (
        Index("idx_graph_name", "name"),
        Index("idx_graph_owner", "owner", postgresql_using="gin"),
        Index("idx_graph_access_control", "access_control", postgresql_using="gin"),
        Index("idx_graph_system_metadata", "system_metadata", postgresql_using="gin"),
        # Graph names are unique per owner ID, not globally.
        Index("idx_graph_owner_name", "name", text("(owner->>'id')"), unique=True),
    )
class FolderModel(Base):
    """SQLAlchemy model for folder data.

    A folder groups documents (``document_ids``) and carries optional ``rules``;
    ownership and sharing use the same ``owner`` / ``access_control`` JSONB
    layout as documents and graphs.
    """

    __tablename__ = "folders"

    id = Column(String, primary_key=True)
    # Indexed via idx_folder_name in __table_args__. Passing index=True as well
    # would create a second, identical index (ix_folders_name).
    name = Column(String)
    description = Column(String, nullable=True)
    owner = Column(JSONB)
    document_ids = Column(JSONB, default=list)
    system_metadata = Column(JSONB, default=dict)
    access_control = Column(JSONB, default=dict)
    rules = Column(JSONB, default=list)

    __table_args__ = (
        Index("idx_folder_name", "name"),
        Index("idx_folder_owner", "owner", postgresql_using="gin"),
        Index("idx_folder_access_control", "access_control", postgresql_using="gin"),
    )
def _serialize_datetime(obj: Any) -> Any:
"""Helper function to serialize datetime objects to ISO format strings."""
if isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(obj, dict):
return {key: _serialize_datetime(value) for key, value in obj.items()}
elif isinstance(obj, list):
return [_serialize_datetime(item) for item in obj]
return obj
class PostgresDatabase(BaseDatabase):
"""PostgreSQL implementation for document metadata storage."""
def __init__(
    self,
    uri: str,
):
    """Build the async engine and session factory for the given database URI.

    Pool tuning is read from the application settings; any ``DB_*`` attribute
    the settings object lacks falls back to the default given below. Schema
    creation is not done here — see ``initialize()``.
    """
    # Local import kept from the original layout (presumably to avoid an import cycle — TODO confirm).
    from core.config import get_settings

    cfg = get_settings()

    size = getattr(cfg, "DB_POOL_SIZE", 20)
    overflow = getattr(cfg, "DB_MAX_OVERFLOW", 30)
    recycle_seconds = getattr(cfg, "DB_POOL_RECYCLE", 3600)
    checkout_timeout = getattr(cfg, "DB_POOL_TIMEOUT", 10)
    ping_before_use = getattr(cfg, "DB_POOL_PRE_PING", True)

    logger.info(
        f"Initializing PostgreSQL connection pool with size={size}, "
        f"max_overflow={overflow}, pool_recycle={recycle_seconds}s"
    )

    self.engine = create_async_engine(
        uri,
        echo=False,  # set True while debugging to log every SQL statement
        pool_size=size,  # connections kept open in the pool
        max_overflow=overflow,  # extra connections allowed beyond pool_size
        pool_recycle=recycle_seconds,  # seconds before a pooled connection is replaced
        pool_timeout=checkout_timeout,  # seconds to wait for a free connection
        pool_pre_ping=ping_before_use,  # test a connection on checkout to discard stale ones
    )
    # expire_on_commit=False: ORM objects keep their loaded values after commit.
    self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
    self._initialized = False
async def initialize(self):
    """Create the tables, columns and indexes this backend relies on.

    Safe to call repeatedly: every DDL statement is guarded with
    ``IF NOT EXISTS`` / ``checkfirst``, and once a run has committed the
    instance is flagged so later calls return immediately.

    Returns:
        bool: ``True`` on success (or if already initialized); ``False`` if any
        statement failed, in which case the error is logged with its traceback.
    """
    if self._initialized:
        return True

    async def _add_jsonb_column_if_missing(conn, table: str, column: str, default_sql: str) -> None:
        """Add ``table.column`` as a JSONB column defaulting to ``default_sql`` when absent."""
        # Scope the lookup to current_schema(): without it, a same-named table in
        # another schema that already has the column would make us skip the
        # migration for the table this backend actually uses.
        result = await conn.execute(
            text(
                """
                SELECT column_name
                FROM information_schema.columns
                WHERE table_schema = current_schema()
                  AND table_name = :table_name
                  AND column_name = :column_name
                """
            ),
            {"table_name": table, "column_name": column},
        )
        if result.first():
            return
        # Identifiers cannot be bound parameters; only fixed literals are passed in.
        await conn.execute(
            text(f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS {column} JSONB DEFAULT {default_sql}")
        )
        logger.info(f"Added {column} column to {table} table")

    try:
        logger.info("Initializing PostgreSQL database tables and indexes...")
        async with self.engine.begin() as conn:
            # Create every table registered on Base (in this module: documents,
            # graphs, folders); checkfirst skips tables that already exist.
            await conn.run_sync(lambda sync_conn: Base.metadata.create_all(sync_conn, checkfirst=True))
            logger.info("Created database tables successfully")

            # caches has no ORM model and is still managed with direct SQL.
            await conn.execute(
                text(
                    """
                    CREATE TABLE IF NOT EXISTS caches (
                        name TEXT PRIMARY KEY,
                        metadata JSONB NOT NULL,
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                        updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
            )

            # Migration for databases created before documents.storage_files existed.
            await _add_jsonb_column_if_missing(conn, "documents", "storage_files", "'[]'::jsonb")

            # Expression index for system_metadata->>'folder_name' lookups on documents.
            await conn.execute(
                text(
                    """
                    CREATE INDEX IF NOT EXISTS idx_system_metadata_folder_name
                    ON documents ((system_metadata->>'folder_name'));
                    """
                )
            )

            # Legacy raw definition of folders; a no-op once create_all has made the table.
            await conn.execute(
                text(
                    """
                    CREATE TABLE IF NOT EXISTS folders (
                        id TEXT PRIMARY KEY,
                        name TEXT,
                        description TEXT,
                        owner JSONB,
                        document_ids JSONB DEFAULT '[]',
                        system_metadata JSONB DEFAULT '{}',
                        access_control JSONB DEFAULT '{}'
                    );
                    """
                )
            )

            # Migration for databases created before folders.rules existed.
            await _add_jsonb_column_if_missing(conn, "folders", "rules", "'[]'::jsonb")

            await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_folder_name ON folders (name);"))
            await conn.execute(text("CREATE INDEX IF NOT EXISTS idx_folder_owner ON folders USING gin (owner);"))
            await conn.execute(
                text("CREATE INDEX IF NOT EXISTS idx_folder_access_control ON folders USING gin (access_control);")
            )

            await conn.execute(
                text(
                    """
                    CREATE INDEX IF NOT EXISTS idx_system_metadata_end_user_id
                    ON documents ((system_metadata->>'end_user_id'));
                    """
                )
            )

            # Migration for databases created before graphs.system_metadata existed.
            await _add_jsonb_column_if_missing(conn, "graphs", "system_metadata", "'{}'::jsonb")

            await conn.execute(
                text(
                    """
                    CREATE INDEX IF NOT EXISTS idx_graph_system_metadata_folder_name
                    ON graphs ((system_metadata->>'folder_name'));
                    """
                )
            )
            await conn.execute(
                text(
                    """
                    CREATE INDEX IF NOT EXISTS idx_graph_system_metadata_end_user_id
                    ON graphs ((system_metadata->>'end_user_id'));
                    """
                )
            )
            logger.info("Created indexes for folder_name and end_user_id in system_metadata")

        # Only mark as initialized after engine.begin() has committed the DDL.
        logger.info("PostgreSQL tables and indexes created successfully")
        self._initialized = True
        return True
    except Exception as e:
        logger.exception(f"Error creating PostgreSQL tables and indexes: {str(e)}")
        return False
async def store_document(self, document: Document) -> bool:
    """Insert a new row for ``document`` into the documents table.

    The pydantic ``metadata`` field is stored in the ``doc_metadata`` column
    (with the document's ``external_id`` copied into it). ``created_at`` and
    ``updated_at`` in ``system_metadata`` are stamped with the current UTC time.
    Every datetime is converted to an ISO-8601 string so the JSONB columns can
    hold it.

    Returns:
        bool: ``True`` once the row is committed, ``False`` on any error (logged).
    """
    try:
        doc_dict = document.model_dump()

        # The ORM column is doc_metadata; the pydantic field is metadata.
        doc_metadata = doc_dict.pop("metadata", None) or {}
        doc_metadata["external_id"] = doc_dict["external_id"]
        doc_dict["doc_metadata"] = doc_metadata

        # One instant for both timestamps so a fresh document has created_at == updated_at.
        now = datetime.now(UTC)
        system_metadata = doc_dict.get("system_metadata") or {}
        system_metadata["created_at"] = now
        system_metadata["updated_at"] = now
        doc_dict["system_metadata"] = system_metadata

        # model_dump() already converts nested StorageFileInfo models to dicts, so
        # items are normally dicts here; only call model_dump() on items that are
        # still models (calling it on a dict raises AttributeError).
        doc_dict["storage_files"] = [
            item.model_dump() if hasattr(item, "model_dump") else item
            for item in (doc_dict.get("storage_files") or [])
        ]

        # Serialize datetime objects to ISO format strings
        doc_dict = _serialize_datetime(doc_dict)

        async with self.async_session() as session:
            session.add(DocumentModel(**doc_dict))
            await session.commit()
        return True
    except Exception as e:
        logger.error(f"Error storing document metadata: {str(e)}")
        return False
async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]:
    """Fetch one document by ``external_id`` if ``auth`` is allowed to see it.

    Returns ``None`` when no row matches, the row is outside the caller's access
    filter, or the lookup raises (the error is logged).
    """
    try:
        async with self.async_session() as session:
            visible = text(f"({self._build_access_filter(auth)})")
            stmt = select(DocumentModel).where(DocumentModel.external_id == document_id).where(visible)
            row = (await session.execute(stmt)).scalar_one_or_none()
            if row is None:
                return None

            # storage_files is persisted as a JSON list of dicts; rebuild the models.
            files = [
                StorageFileInfo(**entry) if isinstance(entry, dict) else entry
                for entry in (row.storage_files or [])
            ]
            return Document(
                external_id=row.external_id,
                owner=row.owner,
                content_type=row.content_type,
                filename=row.filename,
                metadata=row.doc_metadata,
                storage_info=row.storage_info,
                system_metadata=row.system_metadata,
                additional_metadata=row.additional_metadata,
                access_control=row.access_control,
                chunk_ids=row.chunk_ids,
                storage_files=files,
            )
    except Exception as e:
        logger.error(f"Error retrieving document metadata: {str(e)}")
        return None
async def get_document_by_filename(
    self, filename: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None
) -> Optional[Document]:
    """Retrieve document metadata by filename if user has access.

    If multiple documents have the same filename, returns the most recently
    updated one (by ``system_metadata->>'updated_at'``; rows without that key
    sort last).

    Args:
        filename: The filename to search for.
        auth: Authentication context.
        system_filters: Optional system metadata filters (e.g. folder_name, end_user_id).

    Returns:
        Optional[Document]: the matching document, or ``None`` if none is visible
        or an error occurred (logged).
    """
    try:
        async with self.async_session() as session:
            access_filter = self._build_access_filter(auth)
            system_metadata_filter = self._build_system_metadata_filter(system_filters)

            query = (
                select(DocumentModel)
                # Bound parameter: no manual quote escaping needed.
                .where(DocumentModel.filename == filename)
                .where(text(f"({access_filter})"))
            )
            if system_metadata_filter:
                query = query.where(text(f"({system_metadata_filter})"))
            # Several rows may share a filename: take the newest one only. Postgres
            # puts NULLs first under DESC, so push rows lacking updated_at last.
            query = query.order_by(text("system_metadata->>'updated_at' DESC NULLS LAST")).limit(1)

            logger.debug(f"Querying document by filename with system filters: {system_filters}")
            result = await session.execute(query)
            doc_model = result.scalars().first()
            if doc_model is None:
                return None

            # storage_files is persisted as a JSON list of dicts; rebuild the models.
            storage_files = [
                StorageFileInfo(**file_info) if isinstance(file_info, dict) else file_info
                for file_info in (doc_model.storage_files or [])
            ]
            return Document(
                external_id=doc_model.external_id,
                owner=doc_model.owner,
                content_type=doc_model.content_type,
                filename=doc_model.filename,
                metadata=doc_model.doc_metadata,
                storage_info=doc_model.storage_info,
                system_metadata=doc_model.system_metadata,
                additional_metadata=doc_model.additional_metadata,
                access_control=doc_model.access_control,
                chunk_ids=doc_model.chunk_ids,
                storage_files=storage_files,
            )
    except Exception as e:
        logger.error(f"Error retrieving document metadata by filename: {str(e)}")
        return None
async def get_documents_by_id(
    self,
    document_ids: List[str],
    auth: AuthContext,
    system_filters: Optional[Dict[str, Any]] = None,
) -> List[Document]:
    """
    Retrieve multiple documents by their IDs in a single batch operation.

    Only returns documents the user has access to. Can filter by system
    metadata fields like folder_name and end_user_id. IDs that are missing or
    not visible are silently omitted; result order is whatever Postgres returns.

    Args:
        document_ids: List of document IDs to retrieve
        auth: Authentication context
        system_filters: Optional filters for system metadata fields

    Returns:
        List of Document objects that were found and user has access to
        (empty on error, which is logged).
    """
    try:
        if not document_ids:
            return []

        async with self.async_session() as session:
            access_filter = self._build_access_filter(auth)
            system_metadata_filter = self._build_system_metadata_filter(system_filters)

            # The IDs come from the caller: bind them as IN (...) parameters
            # instead of splicing raw strings into the SQL text.
            query = (
                select(DocumentModel)
                .where(DocumentModel.external_id.in_(document_ids))
                .where(text(f"({access_filter})"))
            )
            if system_metadata_filter:
                query = query.where(text(f"({system_metadata_filter})"))

            logger.info(f"Batch retrieving {len(document_ids)} documents with a single query")
            result = await session.execute(query)
            doc_models = result.scalars().all()

            documents = [
                Document(
                    external_id=doc_model.external_id,
                    owner=doc_model.owner,
                    content_type=doc_model.content_type,
                    filename=doc_model.filename,
                    metadata=doc_model.doc_metadata,
                    storage_info=doc_model.storage_info,
                    system_metadata=doc_model.system_metadata,
                    additional_metadata=doc_model.additional_metadata,
                    access_control=doc_model.access_control,
                    chunk_ids=doc_model.chunk_ids,
                    storage_files=doc_model.storage_files or [],
                )
                for doc_model in doc_models
            ]

            logger.info(f"Found {len(documents)} documents in batch retrieval")
            return documents
    except Exception as e:
        logger.error(f"Error batch retrieving documents: {str(e)}")
        return []
async def get_documents(
    self,
    auth: AuthContext,
    skip: int = 0,
    limit: int = 10000,
    filters: Optional[Dict[str, Any]] = None,
    system_filters: Optional[Dict[str, Any]] = None,
) -> List[Document]:
    """List the documents ``auth`` can see, one page at a time.

    Args:
        auth: Caller whose access filter restricts the rows.
        skip: Number of matching rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).
        filters: Optional ``doc_metadata`` equality / IN filters.
        system_filters: Optional ``system_metadata`` filters (e.g. folder_name).

    Returns:
        List[Document]: the requested page; empty on error (logged).

    NOTE(review): the query has no ORDER BY, so Postgres does not guarantee
    that successive skip/limit pages are stable or disjoint.
    """
    try:
        async with self.async_session() as session:
            where_parts = [f"({self._build_access_filter(auth)})"]
            for optional_clause in (
                self._build_metadata_filter(filters),
                self._build_system_metadata_filter(system_filters),
            ):
                if optional_clause:
                    where_parts.append(f"({optional_clause})")

            stmt = (
                select(DocumentModel)
                .where(text(" AND ".join(where_parts)))
                .offset(skip)
                .limit(limit)
            )
            rows = (await session.execute(stmt)).scalars().all()

            page: List[Document] = []
            for row in rows:
                page.append(
                    Document(
                        external_id=row.external_id,
                        owner=row.owner,
                        content_type=row.content_type,
                        filename=row.filename,
                        metadata=row.doc_metadata,
                        storage_info=row.storage_info,
                        system_metadata=row.system_metadata,
                        additional_metadata=row.additional_metadata,
                        access_control=row.access_control,
                        chunk_ids=row.chunk_ids,
                        storage_files=row.storage_files or [],
                    )
                )
            return page
    except Exception as e:
        logger.error(f"Error listing documents: {str(e)}")
        return []
async def update_document(self, document_id: str, updates: Dict[str, Any], auth: AuthContext) -> bool:
    """Apply ``updates`` to a document the caller has write access to.

    ``system_metadata`` in ``updates`` is merged over the stored value instead of
    replacing it, and ``updated_at`` is always refreshed. A ``metadata`` key is
    written to the ``doc_metadata`` column. The caller's ``updates`` dict (and any
    nested ``system_metadata`` dict) is left unmodified.

    Returns:
        bool: ``True`` if the row was updated and committed; ``False`` when access
        is denied, the document is missing, or an error occurs (logged).
    """
    try:
        if not await self.check_access(document_id, auth, "write"):
            return False

        # Loaded only so its system_metadata can be merged with the update.
        existing_doc = await self.get_document(document_id, auth)
        if not existing_doc:
            return False

        # Work on copies so neither the caller's dict nor its nested
        # system_metadata dict is mutated below.
        updates = dict(updates)
        merged_system_metadata = dict(existing_doc.system_metadata or {})
        merged_system_metadata.update(updates.get("system_metadata") or {})
        if existing_doc.system_metadata:
            logger.debug("Merged system_metadata during document update, preserving existing fields")
        merged_system_metadata["updated_at"] = datetime.now(UTC)
        updates["system_metadata"] = merged_system_metadata

        # Serialize datetime objects to ISO format strings
        updates = _serialize_datetime(updates)

        async with self.async_session() as session:
            result = await session.execute(select(DocumentModel).where(DocumentModel.external_id == document_id))
            doc_model = result.scalar_one_or_none()
            if not doc_model:
                return False

            logger.info(f"Document update: updating fields {list(updates.keys())}")

            # The pydantic field is "metadata"; the ORM column is "doc_metadata".
            if "metadata" in updates and "doc_metadata" not in updates:
                logger.info("Converting 'metadata' to 'doc_metadata' for database update")
                updates["doc_metadata"] = updates.pop("metadata")

            for key, value in updates.items():
                if key == "storage_files" and isinstance(value, list):
                    # Items may be pydantic models (v2 model_dump / v1 dict) or plain dicts.
                    serialized_value = [
                        _serialize_datetime(
                            item.model_dump()
                            if hasattr(item, "model_dump")
                            else (item.dict() if hasattr(item, "dict") else item)
                        )
                        for item in value
                    ]
                    logger.debug("Serializing storage_files before setting attribute")
                    setattr(doc_model, key, serialized_value)
                else:
                    logger.debug(f"Setting document attribute {key} = {value}")
                    setattr(doc_model, key, value)

            await session.commit()
            logger.info(f"Document {document_id} updated successfully")
            return True
    except Exception as e:
        logger.error(f"Error updating document metadata: {str(e)}")
        return False
async def delete_document(self, document_id: str, auth: AuthContext) -> bool:
    """Delete the document's row when the caller holds write permission.

    Returns:
        bool: ``True`` only if a row was found and deleted; ``False`` when access
        is denied, no row exists, or an error occurs (logged).
    """
    try:
        allowed = await self.check_access(document_id, auth, "write")
        if not allowed:
            return False

        async with self.async_session() as session:
            lookup = await session.execute(select(DocumentModel).where(DocumentModel.external_id == document_id))
            target = lookup.scalar_one_or_none()
            if target is None:
                return False
            await session.delete(target)
            await session.commit()
            return True
    except Exception as e:
        logger.error(f"Error deleting document: {str(e)}")
        return False
async def find_authorized_and_filtered_documents(
    self,
    auth: AuthContext,
    filters: Optional[Dict[str, Any]] = None,
    system_filters: Optional[Dict[str, Any]] = None,
) -> List[str]:
    """Return the external_ids of visible documents that match the given filters.

    The caller's access filter is always applied; ``filters`` (doc_metadata) and
    ``system_filters`` (system_metadata) are AND-ed in when they produce a clause.
    Returns an empty list on error (logged).
    """
    try:
        async with self.async_session() as session:
            access_filter = self._build_access_filter(auth)
            metadata_filter = self._build_metadata_filter(filters)
            system_metadata_filter = self._build_system_metadata_filter(system_filters)

            for message in (
                f"Access filter: {access_filter}",
                f"Metadata filter: {metadata_filter}",
                f"System metadata filter: {system_metadata_filter}",
                f"Original filters: {filters}",
                f"System filters: {system_filters}",
            ):
                logger.debug(message)

            conditions = [f"({access_filter})"]
            conditions.extend(f"({clause})" for clause in (metadata_filter, system_metadata_filter) if clause)

            query = select(DocumentModel.external_id).where(text(" AND ".join(conditions)))
            logger.debug(f"Final query: {query}")

            rows = (await session.execute(query)).all()
            doc_ids = [external_id for (external_id,) in rows]
            logger.debug(f"Found document IDs: {doc_ids}")
            return doc_ids
    except Exception as e:
        logger.error(f"Error finding authorized documents: {str(e)}")
        return []
async def check_access(self, document_id: str, auth: AuthContext, required_permission: str = "read") -> bool:
    """Return whether ``auth`` holds ``required_permission`` on a document.

    The owner (matching both ``type`` and ``id``) always passes. Otherwise
    ``"read"`` / ``"write"`` / ``"admin"`` require the caller's entity ID to be in
    ``access_control["readers"]`` / ``["writers"]`` / ``["admins"]``; any other
    permission name is denied. Missing documents and errors (logged) give ``False``.

    NOTE(review): unlike the SQL access filter, ``"read"`` here does not accept
    writers/admins, and the owner check also requires a matching type.
    """
    try:
        async with self.async_session() as session:
            lookup = await session.execute(select(DocumentModel).where(DocumentModel.external_id == document_id))
            doc_model = lookup.scalar_one_or_none()
            if not doc_model:
                return False

            owner = doc_model.owner
            is_owner = owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id
            if is_owner:
                return True

            grants = doc_model.access_control
            list_key = {"read": "readers", "write": "writers", "admin": "admins"}.get(required_permission)
            if not list_key:
                return False
            allowed_ids = grants.get(list_key, [])
            return auth.entity_id in allowed_ids
    except Exception as e:
        logger.error(f"Error checking document access: {str(e)}")
        return False
def _build_access_filter(self, auth: AuthContext) -> str:
    """Build a SQL condition selecting the rows ``auth`` may see.

    A row matches when the caller's entity ID equals ``owner->>'id'`` or appears
    in ``access_control`` under ``readers``, ``writers`` or ``admins``. A
    ``DEVELOPER`` with an ``app_id`` also matches ``access_control->'app_access'``,
    and in cloud mode a caller with a ``user_id`` also matches
    ``access_control->>'user_id'``. All alternatives are OR-ed together.

    IDs are embedded as SQL string literals with single quotes doubled, so an ID
    containing a quote cannot change the structure of the query.

    NOTE(review): the owner test ignores the owner's ``type``, unlike check_access.
    """

    def _quote(raw: Any) -> str:
        # f-string formatting matches how these values were rendered before escaping.
        return "'" + f"{raw}".replace("'", "''") + "'"

    entity = _quote(auth.entity_id)
    filters = [
        f"owner->>'id' = {entity}",
        f"access_control->'readers' ? {entity}",
        f"access_control->'writers' ? {entity}",
        f"access_control->'admins' ? {entity}",
    ]

    if auth.entity_type == "DEVELOPER" and auth.app_id:
        # Add app-specific access for developers
        filters.append(f"access_control->'app_access' ? {_quote(auth.app_id)}")

    if auth.user_id:
        from core.config import get_settings

        settings = get_settings()
        if settings.MODE == "cloud":
            # Filter by user_id in access_control
            filters.append(f"access_control->>'user_id' = {_quote(auth.user_id)}")

    return " OR ".join(filters)
def _build_metadata_filter(self, filters: Optional[Dict[str, Any]]) -> str:
    """Build a SQL condition matching ``doc_metadata`` fields against ``filters``.

    Each ``key: value`` pair becomes ``doc_metadata->>'key' = 'value'``. A list
    value becomes an ``IN (...)`` test, and an empty list is skipped. Values are
    compared as text (``->>`` yields text), so booleans render as
    ``'true'``/``'false'`` and other values via ``format``. Conditions are AND-ed.

    Keys and values are emitted as SQL string literals with single quotes
    doubled, so caller-supplied filters cannot break out of the literal.

    NOTE(review): a ``None`` value matches the text ``'None'``, not a missing key.

    Returns:
        str: the condition, or ``""`` when there is nothing to filter on.
    """
    if not filters:
        return ""

    def _literal(raw: Any) -> str:
        # Check bool first: str(True) is "True", but ->> yields "true"/"false".
        rendered = str(raw).lower() if isinstance(raw, bool) else f"{raw}"
        return "'" + rendered.replace("'", "''") + "'"

    filter_conditions = []
    for key, value in filters.items():
        field = f"doc_metadata->>{_literal(key)}"
        if isinstance(value, list):
            if not value:  # "IN ()" is invalid SQL; ignore empty lists
                continue
            values_str = ", ".join(_literal(item) for item in value)
            filter_conditions.append(f"{field} IN ({values_str})")
        else:
            filter_conditions.append(f"{field} = {_literal(value)}")

    return " AND ".join(filter_conditions)
def _build_system_metadata_filter(self, system_filters: Optional[Dict[str, Any]]) -> str:
    """Build a SQL condition matching ``system_metadata`` fields against filters.

    ``None`` values and empty lists are skipped. A list value becomes an
    ``IN (...)`` test; anything else becomes an equality test. Values are
    compared as text (``->>`` yields text), so booleans render as
    ``'true'``/``'false'``. Conditions are AND-ed together.

    Keys and values are emitted as SQL string literals with single quotes
    doubled, so caller-supplied filters cannot break out of the literal.

    Returns:
        str: the condition, or ``""`` when there is nothing to filter on.
    """
    if not system_filters:
        return ""

    def _literal(raw: Any) -> str:
        # Check bool first: str(True) is "True", but ->> yields "true"/"false".
        rendered = str(raw).lower() if isinstance(raw, bool) else f"{raw}"
        return "'" + rendered.replace("'", "''") + "'"

    conditions = []
    for key, value in system_filters.items():
        if value is None:
            continue
        field = f"system_metadata->>{_literal(key)}"
        if isinstance(value, list):
            if not value:  # "IN ()" is invalid SQL; ignore empty lists
                continue
            values_str = ", ".join(_literal(item) for item in value)
            conditions.append(f"{field} IN ({values_str})")
        else:
            conditions.append(f"{field} = {_literal(value)}")

    return " AND ".join(conditions)
async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
    """Insert or replace the metadata row for the cache called ``name``.

    The metadata is JSON-encoded into ``caches.metadata``. An existing row with
    the same name is overwritten and its ``updated_at`` is refreshed.

    Args:
        name: Name of the cache (primary key of ``caches``).
        metadata: Cache metadata including model info and storage location.

    Returns:
        bool: ``True`` once committed, ``False`` if encoding or the upsert failed (logged).
    """
    try:
        upsert_sql = text(
            """
            INSERT INTO caches (name, metadata, updated_at)
            VALUES (:name, :metadata, CURRENT_TIMESTAMP)
            ON CONFLICT (name)
            DO UPDATE SET
            metadata = :metadata,
            updated_at = CURRENT_TIMESTAMP
            """
        )
        params = {"name": name, "metadata": json.dumps(metadata)}
        async with self.async_session() as session:
            await session.execute(upsert_sql, params)
            await session.commit()
    except Exception as e:
        logger.error(f"Failed to store cache metadata: {e}")
        return False
    return True
async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]:
    """Look up the stored metadata for the cache called ``name``.

    Args:
        name: Name of the cache.

    Returns:
        Optional[Dict[str, Any]]: the ``caches.metadata`` value, or ``None`` when
        no row exists or the query fails (logged).
    """
    try:
        async with self.async_session() as session:
            lookup = await session.execute(text("SELECT metadata FROM caches WHERE name = :name"), {"name": name})
            first_row = lookup.first()
            if first_row is None:
                return None
            return first_row[0]
    except Exception as e:
        logger.error(f"Failed to get cache metadata: {e}")
        return None
async def store_graph(self, graph: Graph) -> bool:
    """Insert ``graph`` as a new row in the graphs table.

    Runs ``initialize()`` first if it has not succeeded yet. The model's
    ``metadata`` field is stored in the ``graph_metadata`` column, and datetimes
    are converted to ISO-8601 strings.

    Args:
        graph: Graph to store.

    Returns:
        bool: ``True`` once committed, ``False`` on any error (logged), e.g. a
        violation of the per-owner unique name index.
    """
    if not self._initialized:
        await self.initialize()

    try:
        payload = graph.model_dump()
        # "metadata" is reserved on declarative models; the column is graph_metadata.
        if "metadata" in payload:
            payload["graph_metadata"] = payload.pop("metadata")
        payload = _serialize_datetime(payload)

        async with self.async_session() as session:
            session.add(GraphModel(**payload))
            await session.commit()

        logger.info(
            f"Stored graph '{graph.name}' with {len(graph.entities)} entities "
            f"and {len(graph.relationships)} relationships"
        )
        return True
    except Exception as e:
        logger.error(f"Error storing graph: {str(e)}")
        return False
async def get_graph(
    self, name: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None
) -> Optional[Graph]:
    """Get a graph by name.

    When ``system_filters`` yield a condition, the graph's ``document_ids`` are
    narrowed to the documents matching it. If none match, the graph is treated
    as not found.

    NOTE(review): if the caller can see two graphs with this name (e.g. one owned
    and one shared), ``scalar_one_or_none()`` raises and this returns ``None``.

    Args:
        name: Name of the graph.
        auth: Authentication context.
        system_filters: Optional system metadata filters (e.g. folder_name, end_user_id).

    Returns:
        Optional[Graph]: Graph if found and accessible, None otherwise.
    """
    if not self._initialized:
        await self.initialize()

    try:
        async with self.async_session() as session:
            access_filter = self._build_access_filter(auth)

            query = select(GraphModel).where(GraphModel.name == name).where(text(f"({access_filter})"))
            result = await session.execute(query)
            graph_model = result.scalar_one_or_none()
            if not graph_model:
                return None

            document_ids = graph_model.document_ids
            if system_filters and document_ids:
                system_metadata_filter = self._build_system_metadata_filter(system_filters)
                if system_metadata_filter:
                    # Bind the stored IDs as parameters rather than splicing them into SQL.
                    filter_query = (
                        select(DocumentModel.external_id)
                        .where(DocumentModel.external_id.in_(document_ids))
                        .where(text(f"({system_metadata_filter})"))
                    )
                    filter_result = await session.execute(filter_query)
                    filtered_doc_ids = [row[0] for row in filter_result.all()]
                    if not filtered_doc_ids:
                        return None
                    document_ids = filtered_doc_ids

            return Graph(
                id=graph_model.id,
                name=graph_model.name,
                entities=graph_model.entities,
                relationships=graph_model.relationships,
                metadata=graph_model.graph_metadata,
                system_metadata=graph_model.system_metadata or {},
                document_ids=document_ids,  # possibly narrowed by system_filters
                filters=graph_model.filters,
                created_at=graph_model.created_at,
                updated_at=graph_model.updated_at,
                owner=graph_model.owner,
                access_control=graph_model.access_control,
            )
    except Exception as e:
        logger.error(f"Error retrieving graph: {str(e)}")
        return None
async def list_graphs(self, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> List[Graph]:
    """List all graphs the user has access to.

    When ``system_filters`` yield a condition, each graph's ``document_ids`` is
    narrowed to the matching documents, and graphs left with none are omitted.
    Filters that yield no condition (e.g. every value ``None``) are ignored, in
    the same way as in ``get_graph``.

    Args:
        auth: Authentication context.
        system_filters: Optional system metadata filters (e.g. folder_name, end_user_id).

    Returns:
        List[Graph]: the visible graphs; empty on error (logged).
    """
    if not self._initialized:
        await self.initialize()

    try:
        async with self.async_session() as session:
            access_filter = self._build_access_filter(auth)
            query = select(GraphModel).where(text(f"({access_filter})"))
            result = await session.execute(query)
            graph_models = result.scalars().all()

            system_metadata_filter = self._build_system_metadata_filter(system_filters)

            graphs = []
            for graph_model in graph_models:
                document_ids = graph_model.document_ids
                if system_metadata_filter:
                    if not document_ids:
                        continue
                    # One query per graph; IDs are bound as parameters, not spliced into SQL.
                    filter_query = (
                        select(DocumentModel.external_id)
                        .where(DocumentModel.external_id.in_(document_ids))
                        .where(text(f"({system_metadata_filter})"))
                    )
                    filter_result = await session.execute(filter_query)
                    document_ids = [row[0] for row in filter_result.all()]
                    if not document_ids:
                        continue

                graphs.append(
                    Graph(
                        id=graph_model.id,
                        name=graph_model.name,
                        entities=graph_model.entities,
                        relationships=graph_model.relationships,
                        metadata=graph_model.graph_metadata,
                        system_metadata=graph_model.system_metadata or {},
                        document_ids=document_ids,
                        filters=graph_model.filters,
                        created_at=graph_model.created_at,
                        updated_at=graph_model.updated_at,
                        owner=graph_model.owner,
                        access_control=graph_model.access_control,
                    )
                )
            return graphs
    except Exception as e:
        logger.error(f"Error listing graphs: {str(e)}")
        return []
async def update_graph(self, graph: Graph) -> bool:
    """Overwrite the stored row of an existing graph with the given graph's state.

    The graph is dumped with ``model_dump()``. Its ``metadata`` field is renamed to
    the ``graph_metadata`` column (the column was renamed to avoid a naming
    conflict), and datetimes are turned into ISO strings via ``_serialize_datetime``.
    Every resulting key is then assigned onto the stored ``GraphModel`` before
    committing.

    Args:
        graph: Graph carrying the new state; the row is located by ``graph.id``.

    Returns:
        bool: True once the update has been committed. False when no row has that
        ID or when any error occurs; errors are logged, not raised.
    """
    # Make sure the schema/engine is ready before the first query.
    if not self._initialized:
        await self.initialize()

    try:
        payload = graph.model_dump()
        if "metadata" in payload:
            payload["graph_metadata"] = payload.pop("metadata")
        payload = _serialize_datetime(payload)

        async with self.async_session() as session:
            lookup = await session.execute(select(GraphModel).where(GraphModel.id == graph.id))
            stored_graph = lookup.scalar_one_or_none()

            if stored_graph is None:
                logger.error(f"Graph '{graph.name}' with ID {graph.id} not found for update")
                return False

            # Copy every serialized field onto the ORM instance.
            for column_name, new_value in payload.items():
                setattr(stored_graph, column_name, new_value)

            await session.commit()
            logger.info(
                f"Updated graph '{graph.name}' with {len(graph.entities)} entities "
                f"and {len(graph.relationships)} relationships"
            )
            return True
    except Exception as e:
        logger.error(f"Error updating graph: {str(e)}")
        return False
async def create_folder(self, folder: Folder) -> bool:
    """Insert a folder row unless its owner already has a folder with that name.

    Duplicates are detected by ``name`` together with ``owner['id']`` and
    ``owner['type']``. When a duplicate exists, nothing is inserted. Instead,
    ``folder.id`` is overwritten with the existing row's ID, so the caller ends up
    holding the stored ID.

    Args:
        folder: Folder to persist. Its ``id`` may be mutated as described above.

    Returns:
        bool: True if the folder exists afterwards (newly created or reused).
        False on any error; errors are logged, not raised.
    """
    try:
        async with self.async_session() as session:
            # JSON columns cannot hold datetime objects, so stringify them first.
            serialized = _serialize_datetime(folder.model_dump())

            duplicate_check = text(
                """
                SELECT id FROM folders
                WHERE name = :name
                AND owner->>'id' = :entity_id
                AND owner->>'type' = :entity_type
                """
            ).bindparams(name=folder.name, entity_id=folder.owner["id"], entity_type=folder.owner["type"])
            existing_id = (await session.execute(duplicate_check)).scalar_one_or_none()

            if existing_id:
                logger.info(
                    f"Folder '{folder.name}' already exists with ID {existing_id}, not creating a duplicate"
                )
                # Hand the stored ID back to the caller through the passed-in object.
                folder.id = existing_id
                return True

            acl = serialized.get("access_control", {})
            # Debug aid: record whether a user_id scope accompanies the folder.
            if "user_id" not in acl:
                logger.info("No user_id found in folder access_control")
            else:
                logger.info(f"Storing folder with user_id: {acl['user_id']}")

            session.add(
                FolderModel(
                    id=folder.id,
                    name=folder.name,
                    description=folder.description,
                    owner=serialized["owner"],
                    document_ids=serialized.get("document_ids", []),
                    system_metadata=serialized.get("system_metadata", {}),
                    access_control=acl,
                    rules=serialized.get("rules", []),
                )
            )
            await session.commit()

            logger.info(f"Created new folder '{folder.name}' with ID {folder.id}")
            return True
    except Exception as e:
        logger.error(f"Error creating folder: {e}")
        return False
async def get_folder(self, folder_id: str, auth: AuthContext) -> Optional[Folder]:
    """Fetch a folder by primary key, enforcing read access.

    Args:
        folder_id: ID of the folder row.
        auth: Caller context, checked with ``_check_folder_access(..., "read")``.

    Returns:
        Optional[Folder]: The folder, or None in any of these cases: the row does
        not exist, the caller lacks read access, or an error occurs (errors are
        logged, not raised).
    """
    try:
        async with self.async_session() as session:
            logger.info(f"Getting folder with ID: {folder_id}")
            lookup = await session.execute(select(FolderModel).where(FolderModel.id == folder_id))
            row = lookup.scalar_one_or_none()

            if not row:
                logger.error(f"Folder with ID {folder_id} not found in database")
                return None

            folder = Folder(
                id=row.id,
                name=row.name,
                description=row.description,
                owner=row.owner,
                document_ids=row.document_ids,
                system_metadata=row.system_metadata,
                access_control=row.access_control,
                rules=row.rules,
            )
            # Hide folders the caller is not allowed to read.
            return folder if self._check_folder_access(folder, auth, "read") else None
    except Exception as e:
        logger.error(f"Error getting folder: {e}")
        return None
async def get_folder_by_name(self, name: str, auth: AuthContext) -> Optional[Folder]:
    """Look up a folder by name, preferring one owned by the caller.

    Two queries are tried in order:

    1. Only when ``auth`` has both an entity type and an entity ID: a folder with
       this name whose ``owner`` is exactly that entity.
    2. Otherwise, any folder with this name that matches at least one of:
       - the caller is the owner;
       - the caller's entity ID is in ``access_control`` ``readers``, ``writers``
         or ``admins``;
       - ``access_control['user_id']`` contains ``auth.user_id``.

    Unlike ``get_folder``, matches are not passed through ``_check_folder_access``.

    Returns:
        Optional[Folder]: The first matching folder. None if nothing matches or an
        error occurs; errors are logged, not raised.
    """

    def _as_folder(row) -> Folder:
        # Map a raw ``SELECT * FROM folders`` row onto the Folder model.
        return Folder(
            id=row.id,
            name=row.name,
            description=row.description,
            owner=row.owner,
            document_ids=row.document_ids,
            system_metadata=row.system_metadata,
            access_control=row.access_control,
            rules=row.rules,
        )

    try:
        async with self.async_session() as session:
            if auth.entity_type and auth.entity_id:
                owned_stmt = text(
                    """
                    SELECT * FROM folders
                    WHERE name = :name
                    AND (owner->>'id' = :entity_id)
                    AND (owner->>'type' = :entity_type)
                    """
                ).bindparams(name=name, entity_id=auth.entity_id, entity_type=auth.entity_type.value)
                owned_row = (await session.execute(owned_stmt)).fetchone()
                if owned_row:
                    return _as_folder(owned_row)

            # Fall back to any folder shared with the caller.
            shared_stmt = text(
                """
                SELECT * FROM folders
                WHERE name = :name
                AND (
                    (owner->>'id' = :entity_id AND owner->>'type' = :entity_type)
                    OR (access_control->'readers' ? :entity_id)
                    OR (access_control->'writers' ? :entity_id)
                    OR (access_control->'admins' ? :entity_id)
                    OR (access_control->'user_id' ? :user_id)
                )
                """
            ).bindparams(
                name=name,
                entity_id=auth.entity_id,
                entity_type=auth.entity_type.value,
                user_id=auth.user_id or "",
            )
            shared_row = (await session.execute(shared_stmt)).fetchone()
            return _as_folder(shared_row) if shared_row else None
    except Exception as e:
        logger.error(f"Error getting folder by name: {e}")
        return None
async def list_folders(self, auth: AuthContext) -> List[Folder]:
    """Return every folder the caller may read.

    All folder rows are loaded, converted to ``Folder`` objects and filtered in
    Python with ``_check_folder_access(..., "read")``.

    Returns:
        List[Folder]: The readable folders. An empty list if any error occurs;
        errors are logged, not raised.
    """
    try:
        async with self.async_session() as session:
            rows = (await session.execute(select(FolderModel))).scalars().all()
            # Lazily convert rows so each folder is checked right after it is built.
            candidates = (
                Folder(
                    id=row.id,
                    name=row.name,
                    description=row.description,
                    owner=row.owner,
                    document_ids=row.document_ids,
                    system_metadata=row.system_metadata,
                    access_control=row.access_control,
                    rules=row.rules,
                )
                for row in rows
            )
            return [candidate for candidate in candidates if self._check_folder_access(candidate, auth, "read")]
    except Exception as e:
        logger.error(f"Error listing folders: {e}")
        return []
async def add_document_to_folder(self, folder_id: str, document_id: str, auth: AuthContext) -> bool:
    """Add a document to a folder and tag the document with the folder's name.

    Preconditions:
    - the caller can read the folder (via ``get_folder``);
    - the caller has write access (via ``_check_folder_access(..., "write")``);
    - ``get_document`` returns the document for ``auth``.

    On success, a single commit does both of the following:
    - appends ``document_id`` to the folder's ``document_ids``;
    - sets the document's ``system_metadata['folder_name']`` to the folder name.

    Args:
        folder_id: ID of the target folder.
        document_id: External ID of the document to add.
        auth: Caller context used for the access checks.

    Returns:
        bool: True if the document is (now) in the folder. False if the folder or
        document is missing or inaccessible, write access is lacking, or any error
        occurs; errors are logged, not raised.
    """
    try:
        folder = await self.get_folder(folder_id, auth)
        if not folder:
            logger.error(f"Folder {folder_id} not found or user does not have access")
            return False

        if not self._check_folder_access(folder, auth, "write"):
            logger.error(f"User does not have write access to folder {folder_id}")
            return False

        document = await self.get_document(document_id, auth)
        if not document:
            logger.error(f"Document {document_id} not found or user does not have access")
            return False

        if document_id in folder.document_ids:
            logger.info(f"Document {document_id} is already in folder {folder_id}")
            return True

        async with self.async_session() as session:
            # Lock the folder row and build on its *current* document_ids instead of
            # the copy read by get_folder() in another session. Otherwise a
            # concurrent add/remove committed in between would be overwritten.
            folder_model = await session.get(FolderModel, folder_id, with_for_update=True)
            if not folder_model:
                logger.error(f"Folder {folder_id} not found in database")
                return False

            current_ids = list(folder_model.document_ids or [])
            if document_id not in current_ids:
                # Assign a new list (no in-place append) so SQLAlchemy sees the JSONB change.
                folder_model.document_ids = current_ids + [document_id]

            # Bind the folder name as a parameter. Interpolating it into the SQL broke
            # on names containing quotes or ":word" (which text() parses as a bind
            # parameter), and it allowed SQL injection. COALESCE ensures a NULL
            # system_metadata still receives the tag instead of staying NULL.
            stmt = text(
                """
                UPDATE documents
                SET system_metadata = jsonb_set(
                    COALESCE(system_metadata, '{}'::jsonb),
                    '{folder_name}',
                    CAST(:folder_name_json AS jsonb)
                )
                WHERE external_id = :document_id
                """
            ).bindparams(folder_name_json=json.dumps(folder.name), document_id=document_id)
            await session.execute(stmt)
            await session.commit()

        logger.info(f"Added document {document_id} to folder {folder_id}")
        return True
    except Exception as e:
        logger.error(f"Error adding document to folder: {e}")
        return False
async def remove_document_from_folder(self, folder_id: str, document_id: str, auth: AuthContext) -> bool:
    """Remove a document from a folder and clear the document's folder tag.

    Preconditions: the caller can read the folder (via ``get_folder``) and has
    write access (via ``_check_folder_access(..., "write")``). If the document is
    not listed in the folder, nothing changes and True is returned.

    On success, a single commit does both of the following:
    - drops ``document_id`` from the folder's ``document_ids``;
    - sets the document's ``system_metadata['folder_name']`` to JSON null.

    Args:
        folder_id: ID of the folder.
        document_id: External ID of the document to remove.
        auth: Caller context used for the access checks.

    Returns:
        bool: True if the document is no longer in the folder. False if the folder
        is missing or inaccessible, write access is lacking, or any error occurs;
        errors are logged, not raised.
    """
    try:
        folder = await self.get_folder(folder_id, auth)
        if not folder:
            logger.error(f"Folder {folder_id} not found or user does not have access")
            return False

        if not self._check_folder_access(folder, auth, "write"):
            logger.error(f"User does not have write access to folder {folder_id}")
            return False

        if document_id not in folder.document_ids:
            logger.warning(f"Tried to delete document {document_id} not in folder {folder_id}")
            return True

        async with self.async_session() as session:
            # Lock the folder row and filter its *current* document_ids rather than
            # the copy read by get_folder() in another session. Otherwise a
            # concurrent add/remove committed in between would be lost.
            folder_model = await session.get(FolderModel, folder_id, with_for_update=True)
            if not folder_model:
                logger.error(f"Folder {folder_id} not found in database")
                return False

            folder_model.document_ids = [
                doc_id for doc_id in (folder_model.document_ids or []) if doc_id != document_id
            ]

            stmt = text(
                """
                UPDATE documents
                SET system_metadata = jsonb_set(system_metadata, '{folder_name}', 'null'::jsonb)
                WHERE external_id = :document_id
                """
            ).bindparams(document_id=document_id)
            await session.execute(stmt)
            await session.commit()

        logger.info(f"Removed document {document_id} from folder {folder_id}")
        return True
    except Exception as e:
        logger.error(f"Error removing document from folder: {e}")
        return False
def _check_folder_access(self, folder: Folder, auth: AuthContext, permission: str = "read") -> bool:
    """Decide whether ``auth`` holds ``permission`` on ``folder``.

    Rules, applied in order:

    1. A caller with ``"admin"`` in ``auth.permissions`` always passes.
    2. The owning entity (matching ``owner['type']`` and ``owner['id']``) passes.
       Exception: in cloud mode, a caller that has a ``user_id`` must also be
       listed in ``access_control['user_id']``.
    3. Otherwise the caller's ``"<entity_type>:<entity_id>"`` principal must
       appear in the ACL list for the requested permission:
       - ``readers`` for ``"read"``;
       - ``writers`` for ``"write"``;
       - ``admins`` for ``"admin"``.
       The lists are not cumulative; for example, a writer is not implicitly a
       reader.

    Any other ``permission`` value is denied. A missing (None) ``owner``,
    ``access_control`` or ACL list is treated as empty rather than raising.

    Args:
        folder: Folder being accessed.
        auth: Caller context.
        permission: One of ``"read"``, ``"write"`` or ``"admin"``.

    Returns:
        bool: True if access is granted.
    """
    if "admin" in auth.permissions:
        return True

    # Treat a NULL owner/ACL column as empty. Previously the cloud-mode branch
    # dereferenced folder.access_control without this guard.
    access_control = folder.access_control or {}
    owner = folder.owner or {}

    if (
        auth.entity_type
        and auth.entity_id
        and owner.get("type") == auth.entity_type.value
        and owner.get("id") == auth.entity_id
    ):
        # In cloud mode, the owner must also be scoped to the calling user.
        if auth.user_id:
            from core.config import get_settings

            if get_settings().MODE == "cloud" and auth.user_id not in (access_control.get("user_id") or []):
                return False
        return True

    # Without an entity type there is no "<type>:<id>" principal to match.
    # Previously this raised AttributeError on None.value, which also made
    # list_folders() return an empty list.
    if not auth.entity_type:
        return False

    acl_key = {"read": "readers", "write": "writers", "admin": "admins"}.get(permission)
    if acl_key is None:
        return False

    principal = f"{auth.entity_type.value}:{auth.entity_id}"
    return principal in (access_control.get(acl_key) or [])