Add folders and user scopes (#82)

Adityavardhan Agrawal 2025-04-13 14:52:26 -07:00 committed by GitHub
parent 1f3df392da
commit 75556c924a
25 changed files with 4417 additions and 1583 deletions

View File

@ -1,6 +1,5 @@
JWT_SECRET_KEY="..." # Required in production, optional in dev mode (dev_mode=true in morphik.toml)
POSTGRES_URI="postgresql+asyncpg://postgres:postgres@localhost:5432/morphik" # Required for PostgreSQL database
MONGODB_URI="..." # Optional: Only needed if using MongoDB
UNSTRUCTURED_API_KEY="..." # Optional: Needed for parsing via unstructured API
OPENAI_API_KEY="..." # Optional: Needed for OpenAI embeddings and completions

View File

@ -46,7 +46,7 @@ Built for scale and performance, Morphik can handle millions of documents while
- 🧩 **Extensible Architecture**
- Support for custom parsers and embedding models
- Multiple storage backends (S3, local)
- Vector store integrations (PostgreSQL/pgvector, MongoDB)
- Vector store integration with PostgreSQL/pgvector
## Quick Start
@ -162,7 +162,7 @@ for chunk in chunks:
| **Knowledge Graphs** | ✅ Automated extraction & enhanced retrieval | ❌ | ❌ | ❌ |
| **Rules Engine** | ✅ Natural language rules & schema definition | ❌ | ❌ | Limited |
| **Caching** | ✅ Persistent KV-caching with selective updates | ❌ | ❌ | Limited |
| **Scalability** | ✅ Millions of documents with PostgreSQL/MongoDB | ✅ | ✅ | Limited |
| **Scalability** | ✅ Millions of documents with PostgreSQL | ✅ | ✅ | Limited |
| **Video Content** | ✅ Native video parsing & transcription | ❌ | ❌ | ❌ |
| **Deployment Options** | ✅ Self-hosted, cloud, or hybrid | Varies | Varies | Limited |
| **Open Source** | ✅ MIT License | Varies | Varies | Varies |
@ -176,7 +176,7 @@ for chunk in chunks:
- **Schema-like Rules for Unstructured Data**: Define rules to extract consistent metadata from unstructured content, bringing database-like queryability to any document format.
- **Enterprise-grade Scalability**: Built on proven database technologies (PostgreSQL/MongoDB) that can scale to millions of documents while maintaining sub-second retrieval times.
- **Enterprise-grade Scalability**: Built on proven PostgreSQL database technology that can scale to millions of documents while maintaining sub-second retrieval times.
## Documentation

View File

@ -20,9 +20,7 @@ from core.parser.morphik_parser import MorphikParser
from core.services.document_service import DocumentService
from core.services.telemetry import TelemetryService
from core.config import get_settings
from core.database.mongo_database import MongoDatabase
from core.database.postgres_database import PostgresDatabase
from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from core.vector_store.multi_vector_store import MultiVectorStore
from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
from core.storage.s3_storage import S3Storage
@ -77,21 +75,9 @@ app.add_middleware(
settings = get_settings()
# Initialize database
match settings.DATABASE_PROVIDER:
case "postgres":
if not settings.POSTGRES_URI:
raise ValueError("PostgreSQL URI is required for PostgreSQL database")
database = PostgresDatabase(uri=settings.POSTGRES_URI)
case "mongodb":
if not settings.MONGODB_URI:
raise ValueError("MongoDB URI is required for MongoDB database")
database = MongoDatabase(
uri=settings.MONGODB_URI,
db_name=settings.DATABRIDGE_DB,
collection_name=settings.DOCUMENTS_COLLECTION,
)
case _:
raise ValueError(f"Unsupported database provider: {settings.DATABASE_PROVIDER}")
if not settings.POSTGRES_URI:
raise ValueError("PostgreSQL URI is required for PostgreSQL database")
database = PostgresDatabase(uri=settings.POSTGRES_URI)
@app.on_event("startup")
@ -144,24 +130,13 @@ async def initialize_user_limits_database():
await user_limits_db.initialize()
# Initialize vector store
match settings.VECTOR_STORE_PROVIDER:
case "mongodb":
vector_store = MongoDBAtlasVectorStore(
uri=settings.MONGODB_URI,
database_name=settings.DATABRIDGE_DB,
collection_name=settings.CHUNKS_COLLECTION,
index_name=settings.VECTOR_INDEX_NAME,
)
case "pgvector":
if not settings.POSTGRES_URI:
raise ValueError("PostgreSQL URI is required for pgvector store")
from core.vector_store.pgvector_store import PGVectorStore
if not settings.POSTGRES_URI:
raise ValueError("PostgreSQL URI is required for pgvector store")
from core.vector_store.pgvector_store import PGVectorStore
vector_store = PGVectorStore(
uri=settings.POSTGRES_URI,
)
case _:
raise ValueError(f"Unsupported vector store provider: {settings.VECTOR_STORE_PROVIDER}")
vector_store = PGVectorStore(
uri=settings.POSTGRES_URI,
)
# Initialize storage
match settings.STORAGE_PROVIDER:
@ -310,6 +285,8 @@ async def ingest_text(
- rules: Optional list of rules. Each rule should be either:
- MetadataExtractionRule: {"type": "metadata_extraction", "schema": {...}}
- NaturalLanguageRule: {"type": "natural_language", "prompt": "..."}
- folder_name: Optional folder to scope the document to
- end_user_id: Optional end-user ID to scope the document to
auth: Authentication context
Returns:
@ -324,6 +301,8 @@ async def ingest_text(
"metadata": request.metadata,
"rules": request.rules,
"use_colpali": request.use_colpali,
"folder_name": request.folder_name,
"end_user_id": request.end_user_id,
},
):
return await document_service.ingest_text(
@ -333,6 +312,8 @@ async def ingest_text(
rules=request.rules,
use_colpali=request.use_colpali,
auth=auth,
folder_name=request.folder_name,
end_user_id=request.end_user_id,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
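For illustration, a minimal client-side sketch of the new scoping fields on text ingestion. The route path (/ingest/text), base URL, bearer token, and example values are assumptions for this sketch, not part of the commit:

import requests  # assumes the requests package is installed

BASE_URL = "http://localhost:8000"           # hypothetical deployment URL
HEADERS = {"Authorization": "Bearer <JWT>"}  # token signed with JWT_SECRET_KEY

payload = {
    "content": "Q3 revenue grew 12% quarter over quarter...",
    "metadata": {"department": "finance"},
    "rules": [],
    "use_colpali": False,
    # New in this commit: scope the document to a folder and an end user.
    "folder_name": "q3-reports",
    "end_user_id": "user-123",
}

resp = requests.post(f"{BASE_URL}/ingest/text", json=payload, headers=HEADERS)
resp.raise_for_status()
print(resp.json())  # metadata of the ingested Document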
@ -345,6 +326,8 @@ async def ingest_file(
rules: str = Form("[]"),
auth: AuthContext = Depends(verify_token),
use_colpali: Optional[bool] = None,
folder_name: Optional[str] = Form(None),
end_user_id: Optional[str] = Form(None),
) -> Document:
"""
Ingest a file document.
@ -356,6 +339,9 @@ async def ingest_file(
- MetadataExtractionRule: {"type": "metadata_extraction", "schema": {...}}
- NaturalLanguageRule: {"type": "natural_language", "prompt": "..."}
auth: Authentication context
use_colpali: Whether to use ColPali embedding model
folder_name: Optional folder to scope the document to
end_user_id: Optional end-user ID to scope the document to
Returns:
Document: Metadata of ingested document
@ -374,15 +360,20 @@ async def ingest_file(
"metadata": metadata_dict,
"rules": rules_list,
"use_colpali": use_colpali,
"folder_name": folder_name,
"end_user_id": end_user_id,
},
):
logger.debug(f"API: Ingesting file with use_colpali: {use_colpali}")
return await document_service.ingest_file(
file=file,
metadata=metadata_dict,
auth=auth,
rules=rules_list,
use_colpali=use_colpali,
folder_name=folder_name,
end_user_id=end_user_id,
)
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
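The file variant takes the same scoping values as multipart form fields, next to the JSON-encoded metadata and rules strings. A hedged sketch (the route path /ingest/file, URL, and token are assumptions; the form field names come from the signature above):

import json
import requests

BASE_URL = "http://localhost:8000"
HEADERS = {"Authorization": "Bearer <JWT>"}

with open("report.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/ingest/file",
        headers=HEADERS,
        files={"file": ("report.pdf", f, "application/pdf")},
        data={
            "metadata": json.dumps({"department": "finance"}),  # JSON string form field
            "rules": "[]",                                       # JSON string form field
            "folder_name": "q3-reports",                         # new scoping form field
            "end_user_id": "user-123",                           # new scoping form field
        },
    )
resp.raise_for_status()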
@ -397,6 +388,8 @@ async def batch_ingest_files(
rules: str = Form("[]"),
use_colpali: Optional[bool] = Form(None),
parallel: bool = Form(True),
folder_name: Optional[str] = Form(None),
end_user_id: Optional[str] = Form(None),
auth: AuthContext = Depends(verify_token),
) -> BatchIngestResponse:
"""
@ -410,6 +403,8 @@ async def batch_ingest_files(
- A list of rule lists, one per file
use_colpali: Whether to use ColPali-style embedding
parallel: Whether to process files in parallel
folder_name: Optional folder to scope the documents to
end_user_id: Optional end-user ID to scope the documents to
auth: Authentication context
Returns:
@ -447,6 +442,8 @@ async def batch_ingest_files(
documents = []
errors = []
# We'll pass folder_name and end_user_id directly to the ingest_file functions
async with telemetry.track_operation(
operation_type="batch_ingest",
user_id=auth.entity_id,
@ -454,6 +451,8 @@ async def batch_ingest_files(
"file_count": len(files),
"metadata_type": "list" if isinstance(metadata_value, list) else "single",
"rules_type": "per_file" if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list) else "shared",
"folder_name": folder_name,
"end_user_id": end_user_id,
},
):
if parallel:
@ -466,7 +465,9 @@ async def batch_ingest_files(
metadata=metadata_item,
auth=auth,
rules=file_rules,
use_colpali=use_colpali
use_colpali=use_colpali,
folder_name=folder_name,
end_user_id=end_user_id
)
tasks.append(task)
@ -490,7 +491,9 @@ async def batch_ingest_files(
metadata=metadata_item,
auth=auth,
rules=file_rules,
use_colpali=use_colpali
use_colpali=use_colpali,
folder_name=folder_name,
end_user_id=end_user_id
)
documents.append(doc)
except Exception as e:
@ -504,7 +507,24 @@ async def batch_ingest_files(
@app.post("/retrieve/chunks", response_model=List[ChunkResult])
async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
"""Retrieve relevant chunks."""
"""
Retrieve relevant chunks.
Args:
request: RetrieveRequest containing:
- query: Search query text
- filters: Optional metadata filters
- k: Number of results (default: 4)
- min_score: Minimum similarity threshold (default: 0.0)
- use_reranking: Whether to use reranking
- use_colpali: Whether to use ColPali-style embedding model
- folder_name: Optional folder to scope the search to
- end_user_id: Optional end-user ID to scope the search to
auth: Authentication context
Returns:
List[ChunkResult]: List of relevant chunks
"""
try:
async with telemetry.track_operation(
operation_type="retrieve_chunks",
@ -514,6 +534,8 @@ async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(
"min_score": request.min_score,
"use_reranking": request.use_reranking,
"use_colpali": request.use_colpali,
"folder_name": request.folder_name,
"end_user_id": request.end_user_id,
},
):
return await document_service.retrieve_chunks(
@ -524,6 +546,8 @@ async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(
request.min_score,
request.use_reranking,
request.use_colpali,
request.folder_name,
request.end_user_id,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
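Retrieval accepts the same two fields on RetrieveRequest, so a folder- and user-scoped chunk search might look like the following sketch (base URL, token, and example values are assumptions; the field names match the request model in this commit):

import requests

BASE_URL = "http://localhost:8000"
HEADERS = {"Authorization": "Bearer <JWT>"}

resp = requests.post(
    f"{BASE_URL}/retrieve/chunks",
    headers=HEADERS,
    json={
        "query": "What drove revenue growth in Q3?",
        "filters": {"department": "finance"},
        "k": 4,
        "min_score": 0.0,
        # Restrict the search to documents ingested under this folder and end user,
        # on top of the normal access-control checks.
        "folder_name": "q3-reports",
        "end_user_id": "user-123",
    },
)
resp.raise_for_status()
print(resp.json())  # list of ChunkResult objects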
@ -531,7 +555,24 @@ async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(
@app.post("/retrieve/docs", response_model=List[DocumentResult])
async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
"""Retrieve relevant documents."""
"""
Retrieve relevant documents.
Args:
request: RetrieveRequest containing:
- query: Search query text
- filters: Optional metadata filters
- k: Number of results (default: 4)
- min_score: Minimum similarity threshold (default: 0.0)
- use_reranking: Whether to use reranking
- use_colpali: Whether to use ColPali-style embedding model
- folder_name: Optional folder to scope the search to
- end_user_id: Optional end-user ID to scope the search to
auth: Authentication context
Returns:
List[DocumentResult]: List of relevant documents
"""
try:
async with telemetry.track_operation(
operation_type="retrieve_docs",
@ -541,6 +582,8 @@ async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depen
"min_score": request.min_score,
"use_reranking": request.use_reranking,
"use_colpali": request.use_colpali,
"folder_name": request.folder_name,
"end_user_id": request.end_user_id,
},
):
return await document_service.retrieve_docs(
@ -551,39 +594,99 @@ async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depen
request.min_score,
request.use_reranking,
request.use_colpali,
request.folder_name,
request.end_user_id,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/batch/documents", response_model=List[Document])
async def batch_get_documents(document_ids: List[str], auth: AuthContext = Depends(verify_token)):
"""Retrieve multiple documents by their IDs in a single batch operation."""
async def batch_get_documents(
request: Dict[str, Any],
auth: AuthContext = Depends(verify_token)
):
"""
Retrieve multiple documents by their IDs in a single batch operation.
Args:
request: Dictionary containing:
- document_ids: List of document IDs to retrieve
- folder_name: Optional folder to scope the operation to
- end_user_id: Optional end-user ID to scope the operation to
auth: Authentication context
Returns:
List[Document]: List of documents matching the IDs
"""
try:
# Extract document_ids from request
document_ids = request.get("document_ids", [])
folder_name = request.get("folder_name")
end_user_id = request.get("end_user_id")
if not document_ids:
return []
async with telemetry.track_operation(
operation_type="batch_get_documents",
user_id=auth.entity_id,
metadata={
"document_count": len(document_ids),
"folder_name": folder_name,
"end_user_id": end_user_id,
},
):
return await document_service.batch_retrieve_documents(document_ids, auth)
return await document_service.batch_retrieve_documents(document_ids, auth, folder_name, end_user_id)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
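Note the signature change: /batch/documents now expects a JSON object instead of a bare list of IDs, so existing callers need to wrap their IDs. A sketch of the new request body (values illustrative):

{
    "document_ids": ["doc-1", "doc-2"],
    "folder_name": "q3-reports",    # optional, may be omitted
    "end_user_id": "user-123"       # optional, may be omitted
}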
@app.post("/batch/chunks", response_model=List[ChunkResult])
async def batch_get_chunks(chunk_ids: List[ChunkSource], auth: AuthContext = Depends(verify_token)):
"""Retrieve specific chunks by their document ID and chunk number in a single batch operation."""
async def batch_get_chunks(
request: Dict[str, Any],
auth: AuthContext = Depends(verify_token)
):
"""
Retrieve specific chunks by their document ID and chunk number in a single batch operation.
Args:
request: Dictionary containing:
- sources: List of ChunkSource objects (with document_id and chunk_number)
- folder_name: Optional folder to scope the operation to
- end_user_id: Optional end-user ID to scope the operation to
auth: Authentication context
Returns:
List[ChunkResult]: List of chunk results
"""
try:
# Extract sources from request
sources = request.get("sources", [])
folder_name = request.get("folder_name")
end_user_id = request.get("end_user_id")
if not sources:
return []
async with telemetry.track_operation(
operation_type="batch_get_chunks",
user_id=auth.entity_id,
metadata={
"chunk_count": len(chunk_ids),
"chunk_count": len(sources),
"folder_name": folder_name,
"end_user_id": end_user_id,
},
):
return await document_service.batch_retrieve_chunks(chunk_ids, auth)
# Convert sources to ChunkSource objects if needed
chunk_sources = []
for source in sources:
if isinstance(source, dict):
chunk_sources.append(ChunkSource(**source))
else:
chunk_sources.append(source)
return await document_service.batch_retrieve_chunks(chunk_sources, auth, folder_name, end_user_id)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
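/batch/chunks follows the same pattern; each entry in sources carries a document_id and chunk_number, matching the ChunkSource model, and the two scoping fields remain optional (values illustrative):

{
    "sources": [
        {"document_id": "doc-1", "chunk_number": 0},
        {"document_id": "doc-1", "chunk_number": 3}
    ],
    "folder_name": "q3-reports",
    "end_user_id": "user-123"
}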
@ -592,10 +695,32 @@ async def batch_get_chunks(chunk_ids: List[ChunkSource], auth: AuthContext = Dep
async def query_completion(
request: CompletionQueryRequest, auth: AuthContext = Depends(verify_token)
):
"""Generate completion using relevant chunks as context.
"""
Generate completion using relevant chunks as context.
When graph_name is provided, the query will leverage the knowledge graph
to enhance retrieval by finding relevant entities and their connected documents.
Args:
request: CompletionQueryRequest containing:
- query: Query text
- filters: Optional metadata filters
- k: Number of chunks to use as context (default: 4)
- min_score: Minimum similarity threshold (default: 0.0)
- max_tokens: Maximum tokens in completion
- temperature: Model temperature
- use_reranking: Whether to use reranking
- use_colpali: Whether to use ColPali-style embedding model
- graph_name: Optional name of the graph to use for knowledge graph-enhanced retrieval
- hop_depth: Number of relationship hops to traverse in the graph (1-3)
- include_paths: Whether to include relationship paths in the response
- prompt_overrides: Optional customizations for entity extraction, resolution, and query prompts
- folder_name: Optional folder to scope the operation to
- end_user_id: Optional end-user ID to scope the operation to
auth: Authentication context
Returns:
CompletionResponse: Generated completion
"""
try:
# Validate prompt overrides before proceeding
@ -620,6 +745,8 @@ async def query_completion(
"graph_name": request.graph_name,
"hop_depth": request.hop_depth,
"include_paths": request.include_paths,
"folder_name": request.folder_name,
"end_user_id": request.end_user_id,
},
):
return await document_service.query(
@ -636,6 +763,8 @@ async def query_completion(
request.hop_depth,
request.include_paths,
request.prompt_overrides,
request.folder_name,
request.end_user_id,
)
except ValueError as e:
validate_prompt_overrides_with_http_exception(operation_type="query", error=e)
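The completion endpoint threads the same scoping through graph-enhanced retrieval. A hedged sketch of a request combining a knowledge graph with the new fields (the route path and values are assumptions; the field names come from the docstring above):

import requests

BASE_URL = "http://localhost:8000"
HEADERS = {"Authorization": "Bearer <JWT>"}

resp = requests.post(
    f"{BASE_URL}/query",  # route path assumed; the decorator is not shown in this diff
    headers=HEADERS,
    json={
        "query": "Summarize the main risks mentioned across the Q3 reports.",
        "k": 6,
        "max_tokens": 500,
        "temperature": 0.2,
        "graph_name": "q3-graph",    # optional knowledge-graph-enhanced retrieval
        "hop_depth": 2,
        "folder_name": "q3-reports",
        "end_user_id": "user-123",
    },
)
print(resp.json())  # CompletionResponse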
@ -649,9 +778,31 @@ async def list_documents(
skip: int = 0,
limit: int = 10000,
filters: Optional[Dict[str, Any]] = None,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
):
"""List accessible documents."""
return await document_service.db.get_documents(auth, skip, limit, filters)
"""
List accessible documents.
Args:
auth: Authentication context
skip: Number of documents to skip
limit: Maximum number of documents to return
filters: Optional metadata filters
folder_name: Optional folder to scope the operation to
end_user_id: Optional end-user ID to scope the operation to
Returns:
List[Document]: List of accessible documents
"""
# Create system filters for folder and user scoping
system_filters = {}
if folder_name:
system_filters["folder_name"] = folder_name
if end_user_id:
system_filters["end_user_id"] = end_user_id
return await document_service.db.get_documents(auth, skip, limit, filters, system_filters)
@app.get("/documents/{document_id}", response_model=Document)
@ -700,10 +851,33 @@ async def delete_document(document_id: str, auth: AuthContext = Depends(verify_t
@app.get("/documents/filename/{filename}", response_model=Document)
async def get_document_by_filename(filename: str, auth: AuthContext = Depends(verify_token)):
"""Get document by filename."""
async def get_document_by_filename(
filename: str,
auth: AuthContext = Depends(verify_token),
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
):
"""
Get document by filename.
Args:
filename: Filename of the document to retrieve
auth: Authentication context
folder_name: Optional folder to scope the operation to
end_user_id: Optional end-user ID to scope the operation to
Returns:
Document: Document metadata if found and accessible
"""
try:
doc = await document_service.db.get_document_by_filename(filename, auth)
# Create system filters for folder and user scoping
system_filters = {}
if folder_name:
system_filters["folder_name"] = folder_name
if end_user_id:
system_filters["end_user_id"] = end_user_id
doc = await document_service.db.get_document_by_filename(filename, auth, system_filters)
logger.debug(f"Found document by filename: {doc}")
if not doc:
raise HTTPException(status_code=404, detail=f"Document with filename '{filename}' not found")
@ -1071,6 +1245,9 @@ async def create_graph(
- name: Name of the graph to create
- filters: Optional metadata filters to determine which documents to include
- documents: Optional list of specific document IDs to include
- prompt_overrides: Optional customizations for entity extraction and resolution prompts
- folder_name: Optional folder to scope the operation to
- end_user_id: Optional end-user ID to scope the operation to
auth: Authentication context
Returns:
@ -1093,14 +1270,24 @@ async def create_graph(
"name": request.name,
"filters": request.filters,
"documents": request.documents,
"folder_name": request.folder_name,
"end_user_id": request.end_user_id,
},
):
# Create system filters for folder and user scoping
system_filters = {}
if request.folder_name:
system_filters["folder_name"] = request.folder_name
if request.end_user_id:
system_filters["end_user_id"] = request.end_user_id
return await document_service.create_graph(
name=request.name,
auth=auth,
filters=request.filters,
documents=request.documents,
prompt_overrides=request.prompt_overrides,
system_filters=system_filters,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
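When the scoping fields are set on CreateGraphRequest, the handler above converts them into system_filters, so only documents already ingested under that folder and end user are candidates for the graph. A request-body sketch (values illustrative):

{
    "name": "q3-graph",
    "filters": {"department": "finance"},   # regular metadata filters, optional
    "folder_name": "q3-reports",            # new: restrict source documents to this folder
    "end_user_id": "user-123"               # new: restrict source documents to this end user
}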
@ -1112,6 +1299,8 @@ async def create_graph(
async def get_graph(
name: str,
auth: AuthContext = Depends(verify_token),
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> Graph:
"""
Get a graph by name.
@ -1121,6 +1310,8 @@ async def get_graph(
Args:
name: Name of the graph to retrieve
auth: Authentication context
folder_name: Optional folder to scope the operation to
end_user_id: Optional end-user ID to scope the operation to
Returns:
Graph: The requested graph object
@ -1129,9 +1320,20 @@ async def get_graph(
async with telemetry.track_operation(
operation_type="get_graph",
user_id=auth.entity_id,
metadata={"name": name},
metadata={
"name": name,
"folder_name": folder_name,
"end_user_id": end_user_id
},
):
graph = await document_service.db.get_graph(name, auth)
# Create system filters for folder and user scoping
system_filters = {}
if folder_name:
system_filters["folder_name"] = folder_name
if end_user_id:
system_filters["end_user_id"] = end_user_id
graph = await document_service.db.get_graph(name, auth, system_filters)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph '{name}' not found")
return graph
@ -1144,6 +1346,8 @@ async def get_graph(
@app.get("/graphs", response_model=List[Graph])
async def list_graphs(
auth: AuthContext = Depends(verify_token),
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> List[Graph]:
"""
List all graphs the user has access to.
@ -1152,6 +1356,8 @@ async def list_graphs(
Args:
auth: Authentication context
folder_name: Optional folder to scope the operation to
end_user_id: Optional end-user ID to scope the operation to
Returns:
List[Graph]: List of graph objects
@ -1160,8 +1366,19 @@ async def list_graphs(
async with telemetry.track_operation(
operation_type="list_graphs",
user_id=auth.entity_id,
metadata={
"folder_name": folder_name,
"end_user_id": end_user_id
},
):
return await document_service.db.list_graphs(auth)
# Create system filters for folder and user scoping
system_filters = {}
if folder_name:
system_filters["folder_name"] = folder_name
if end_user_id:
system_filters["end_user_id"] = end_user_id
return await document_service.db.list_graphs(auth, system_filters)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
@ -1186,6 +1403,9 @@ async def update_graph(
request: UpdateGraphRequest containing:
- additional_filters: Optional additional metadata filters to determine which new documents to include
- additional_documents: Optional list of additional document IDs to include
- prompt_overrides: Optional customizations for entity extraction and resolution prompts
- folder_name: Optional folder to scope the operation to
- end_user_id: Optional end-user ID to scope the operation to
auth: Authentication context
Returns:
@ -1203,14 +1423,24 @@ async def update_graph(
"name": name,
"additional_filters": request.additional_filters,
"additional_documents": request.additional_documents,
"folder_name": request.folder_name,
"end_user_id": request.end_user_id,
},
):
# Create system filters for folder and user scoping
system_filters = {}
if request.folder_name:
system_filters["folder_name"] = request.folder_name
if request.end_user_id:
system_filters["end_user_id"] = request.end_user_id
return await document_service.update_graph(
name=name,
auth=auth,
additional_filters=request.additional_filters,
additional_documents=request.additional_documents,
prompt_overrides=request.prompt_overrides,
system_filters=system_filters,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))

View File

@ -13,7 +13,6 @@ class Settings(BaseSettings):
# Environment variables
JWT_SECRET_KEY: str
POSTGRES_URI: Optional[str] = None
MONGODB_URI: Optional[str] = None
UNSTRUCTURED_API_KEY: Optional[str] = None
AWS_ACCESS_KEY: Optional[str] = None
AWS_SECRET_ACCESS_KEY: Optional[str] = None
@ -42,9 +41,8 @@ class Settings(BaseSettings):
# Database configuration
DATABASE_PROVIDER: Literal["postgres", "mongodb"]
DATABASE_PROVIDER: Literal["postgres"]
DATABASE_NAME: Optional[str] = None
DOCUMENTS_COLLECTION: Optional[str] = None
# Embedding configuration
EMBEDDING_PROVIDER: Literal["litellm"] = "litellm"
@ -85,9 +83,8 @@ class Settings(BaseSettings):
S3_BUCKET: Optional[str] = None
# Vector store configuration
VECTOR_STORE_PROVIDER: Literal["pgvector", "mongodb"]
VECTOR_STORE_PROVIDER: Literal["pgvector"]
VECTOR_STORE_DATABASE_NAME: Optional[str] = None
VECTOR_STORE_COLLECTION_NAME: Optional[str] = None
# Colpali configuration
ENABLE_COLPALI: bool
@ -164,24 +161,17 @@ def get_settings() -> Settings:
# load database config
database_config = {"DATABASE_PROVIDER": config["database"]["provider"]}
match database_config["DATABASE_PROVIDER"]:
case "mongodb":
database_config.update(
{
"DATABASE_NAME": config["database"]["database_name"],
"COLLECTION_NAME": config["database"]["collection_name"],
}
)
case "postgres" if "POSTGRES_URI" in os.environ:
database_config.update({"POSTGRES_URI": os.environ["POSTGRES_URI"]})
case "postgres":
msg = em.format(
missing_value="POSTGRES_URI", field="database.provider", value="postgres"
)
raise ValueError(msg)
case _:
prov = database_config["DATABASE_PROVIDER"]
raise ValueError(f"Unknown database provider selected: '{prov}'")
if database_config["DATABASE_PROVIDER"] != "postgres":
prov = database_config["DATABASE_PROVIDER"]
raise ValueError(f"Unknown database provider selected: '{prov}'")
if "POSTGRES_URI" in os.environ:
database_config.update({"POSTGRES_URI": os.environ["POSTGRES_URI"]})
else:
msg = em.format(
missing_value="POSTGRES_URI", field="database.provider", value="postgres"
)
raise ValueError(msg)
# load embedding config
embedding_config = {
@ -251,23 +241,15 @@ def get_settings() -> Settings:
# load vector store config
vector_store_config = {"VECTOR_STORE_PROVIDER": config["vector_store"]["provider"]}
match vector_store_config["VECTOR_STORE_PROVIDER"]:
case "mongodb":
vector_store_config.update(
{
"VECTOR_STORE_DATABASE_NAME": config["vector_store"]["database_name"],
"VECTOR_STORE_COLLECTION_NAME": config["vector_store"]["collection_name"],
}
)
case "pgvector":
if "POSTGRES_URI" not in os.environ:
msg = em.format(
missing_value="POSTGRES_URI", field="vector_store.provider", value="pgvector"
)
raise ValueError(msg)
case _:
prov = vector_store_config["VECTOR_STORE_PROVIDER"]
raise ValueError(f"Unknown vector store provider selected: '{prov}'")
if vector_store_config["VECTOR_STORE_PROVIDER"] != "pgvector":
prov = vector_store_config["VECTOR_STORE_PROVIDER"]
raise ValueError(f"Unknown vector store provider selected: '{prov}'")
if "POSTGRES_URI" not in os.environ:
msg = em.format(
missing_value="POSTGRES_URI", field="vector_store.provider", value="pgvector"
)
raise ValueError(msg)
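With MongoDB support removed, only the following provider values pass validation. A minimal sketch of the corresponding morphik.toml sections and required environment variable (the section and key names come from the config[...] lookups above; everything else is illustrative):

# morphik.toml
[database]
provider = "postgres"

[vector_store]
provider = "pgvector"

# and in the environment (both providers require it):
# POSTGRES_URI="postgresql+asyncpg://postgres:postgres@localhost:5432/morphik"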
# load rules config
rules_config = {

View File

@ -26,7 +26,7 @@ class BaseDatabase(ABC):
pass
@abstractmethod
async def get_document_by_filename(self, filename: str, auth: AuthContext) -> Optional[Document]:
async def get_document_by_filename(self, filename: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> Optional[Document]:
"""
Retrieve document metadata by filename if user has access.
If multiple documents have the same filename, returns the most recently updated one.
@ -34,6 +34,7 @@ class BaseDatabase(ABC):
Args:
filename: The filename to search for
auth: Authentication context
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
Document if found and accessible, None otherwise
@ -41,14 +42,16 @@ class BaseDatabase(ABC):
pass
@abstractmethod
async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext) -> List[Document]:
async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> List[Document]:
"""
Retrieve multiple documents by their IDs in a single batch operation.
Only returns documents the user has access to.
Can filter by system metadata fields like folder_name and end_user_id.
Args:
document_ids: List of document IDs to retrieve
auth: Authentication context
system_filters: Optional filters for system metadata fields
Returns:
List of Document objects that were found and user has access to
@ -62,10 +65,21 @@ class BaseDatabase(ABC):
skip: int = 0,
limit: int = 100,
filters: Optional[Dict[str, Any]] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> List[Document]:
"""
List documents the user has access to.
Supports pagination and filtering.
Args:
auth: Authentication context
skip: Number of documents to skip (for pagination)
limit: Maximum number of documents to return
filters: Optional metadata filters
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
List of documents matching the criteria
"""
pass
@ -89,9 +103,18 @@ class BaseDatabase(ABC):
@abstractmethod
async def find_authorized_and_filtered_documents(
self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None
self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None, system_filters: Optional[Dict[str, Any]] = None
) -> List[str]:
"""Find document IDs matching filters that user has access to."""
"""Find document IDs matching filters that user has access to.
Args:
auth: Authentication context
filters: Optional metadata filters
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
List of document IDs matching the criteria
"""
pass
@abstractmethod
@ -142,12 +165,13 @@ class BaseDatabase(ABC):
pass
@abstractmethod
async def get_graph(self, name: str, auth: AuthContext) -> Optional[Graph]:
async def get_graph(self, name: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> Optional[Graph]:
"""Get a graph by name.
Args:
name: Name of the graph
auth: Authentication context
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
Optional[Graph]: Graph if found and accessible, None otherwise
@ -155,11 +179,12 @@ class BaseDatabase(ABC):
pass
@abstractmethod
async def list_graphs(self, auth: AuthContext) -> List[Graph]:
async def list_graphs(self, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> List[Graph]:
"""List all graphs the user has access to.
Args:
auth: Authentication context
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
List[Graph]: List of graphs

View File

@ -1,329 +0,0 @@
from datetime import UTC, datetime
import logging
from typing import Dict, List, Optional, Any
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import ReturnDocument
from pymongo.errors import PyMongoError
from .base_database import BaseDatabase
from ..models.documents import Document
from ..models.auth import AuthContext, EntityType
logger = logging.getLogger(__name__)
class MongoDatabase(BaseDatabase):
"""MongoDB implementation for document metadata storage."""
def __init__(
self,
uri: str,
db_name: str,
collection_name: str,
):
"""Initialize MongoDB connection for document storage."""
self.client = AsyncIOMotorClient(uri)
self.db = self.client[db_name]
self.collection = self.db[collection_name]
self.caches = self.db["caches"] # Collection for cache metadata
async def initialize(self):
"""Initialize database indexes."""
try:
# Create indexes for common queries
await self.collection.create_index("external_id", unique=True)
await self.collection.create_index("owner.id")
await self.collection.create_index("access_control.readers")
await self.collection.create_index("access_control.writers")
await self.collection.create_index("access_control.admins")
await self.collection.create_index("system_metadata.created_at")
logger.info("MongoDB indexes created successfully")
return True
except PyMongoError as e:
logger.error(f"Error creating MongoDB indexes: {str(e)}")
return False
async def store_document(self, document: Document) -> bool:
"""Store document metadata."""
try:
doc_dict = document.model_dump()
# Ensure system metadata
doc_dict["system_metadata"]["created_at"] = datetime.now(UTC)
doc_dict["system_metadata"]["updated_at"] = datetime.now(UTC)
doc_dict["metadata"]["external_id"] = doc_dict["external_id"]
result = await self.collection.insert_one(doc_dict)
return bool(result.inserted_id)
except PyMongoError as e:
logger.error(f"Error storing document metadata: {str(e)}")
return False
async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]:
"""Retrieve document metadata by ID if user has access."""
try:
# Build access filter
access_filter = self._build_access_filter(auth)
# Query document
query = {"$and": [{"external_id": document_id}, access_filter]}
logger.debug(f"Querying document with query: {query}")
doc_dict = await self.collection.find_one(query)
logger.debug(f"Found document: {doc_dict}")
return Document(**doc_dict) if doc_dict else None
except PyMongoError as e:
logger.error(f"Error retrieving document metadata: {str(e)}")
raise e
async def get_document_by_filename(self, filename: str, auth: AuthContext) -> Optional[Document]:
"""Retrieve document metadata by filename if user has access.
If multiple documents have the same filename, returns the most recently updated one.
"""
try:
# Build access filter
access_filter = self._build_access_filter(auth)
# Query document
query = {"$and": [{"filename": filename}, access_filter]}
logger.debug(f"Querying document by filename with query: {query}")
# Sort by updated_at in descending order to get the most recent one
sort_criteria = [("system_metadata.updated_at", -1)]
doc_dict = await self.collection.find_one(query, sort=sort_criteria)
logger.debug(f"Found document by filename: {doc_dict}")
return Document(**doc_dict) if doc_dict else None
except PyMongoError as e:
logger.error(f"Error retrieving document metadata by filename: {str(e)}")
raise e
async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext) -> List[Document]:
"""
Retrieve multiple documents by their IDs in a single batch operation.
Only returns documents the user has access to.
Args:
document_ids: List of document IDs to retrieve
auth: Authentication context
Returns:
List of Document objects that were found and user has access to
"""
try:
if not document_ids:
return []
# Build access filter
access_filter = self._build_access_filter(auth)
# Query documents with both document IDs and access check in a single query
query = {
"$and": [
{"external_id": {"$in": document_ids}},
access_filter
]
}
logger.info(f"Batch retrieving {len(document_ids)} documents with a single query")
# Execute batch query
cursor = self.collection.find(query)
documents = []
async for doc_dict in cursor:
documents.append(Document(**doc_dict))
logger.info(f"Found {len(documents)} documents in batch retrieval")
return documents
except PyMongoError as e:
logger.error(f"Error batch retrieving documents: {str(e)}")
return []
async def get_documents(
self,
auth: AuthContext,
skip: int = 0,
limit: int = 100,
filters: Optional[Dict[str, Any]] = None,
) -> List[Document]:
"""List accessible documents with pagination and filtering."""
try:
# Build query
auth_filter = self._build_access_filter(auth)
metadata_filter = self._build_metadata_filter(filters)
query = {"$and": [auth_filter, metadata_filter]} if metadata_filter else auth_filter
# Execute paginated query
cursor = self.collection.find(query).skip(skip).limit(limit)
documents = []
async for doc_dict in cursor:
documents.append(Document(**doc_dict))
return documents
except PyMongoError as e:
logger.error(f"Error listing documents: {str(e)}")
return []
async def update_document(
self, document_id: str, updates: Dict[str, Any], auth: AuthContext
) -> bool:
"""Update document metadata if user has write access."""
try:
# Verify write access
if not await self.check_access(document_id, auth, "write"):
return False
# Update system metadata
updates.setdefault("system_metadata", {})
updates["system_metadata"]["updated_at"] = datetime.now(UTC)
result = await self.collection.find_one_and_update(
{"external_id": document_id},
{"$set": updates},
return_document=ReturnDocument.AFTER,
)
return bool(result)
except PyMongoError as e:
logger.error(f"Error updating document metadata: {str(e)}")
return False
async def delete_document(self, document_id: str, auth: AuthContext) -> bool:
"""Delete document if user has admin access."""
try:
# Verify admin access
if not await self.check_access(document_id, auth, "admin"):
return False
result = await self.collection.delete_one({"external_id": document_id})
return bool(result.deleted_count)
except PyMongoError as e:
logger.error(f"Error deleting document: {str(e)}")
return False
async def find_authorized_and_filtered_documents(
self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None
) -> List[str]:
"""Find document IDs matching filters and access permissions."""
# Build query
auth_filter = self._build_access_filter(auth)
metadata_filter = self._build_metadata_filter(filters)
query = {"$and": [auth_filter, metadata_filter]} if metadata_filter else auth_filter
# Get matching document IDs
cursor = self.collection.find(query, {"external_id": 1})
document_ids = []
async for doc in cursor:
document_ids.append(doc["external_id"])
return document_ids
async def check_access(
self, document_id: str, auth: AuthContext, required_permission: str = "read"
) -> bool:
"""Check if user has required permission for document."""
try:
doc = await self.collection.find_one({"external_id": document_id})
if not doc:
return False
access_control = doc.get("access_control", {})
# Check owner access
owner = doc.get("owner", {})
if owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id:
return True
# Check permission-specific access
permission_map = {"read": "readers", "write": "writers", "admin": "admins"}
permission_set = permission_map.get(required_permission)
if not permission_set:
return False
return auth.entity_id in access_control.get(permission_set, set())
except PyMongoError as e:
logger.error(f"Error checking document access: {str(e)}")
return False
def _build_access_filter(self, auth: AuthContext) -> Dict[str, Any]:
"""Build MongoDB filter for access control."""
base_filter = {
"$or": [
{"owner.id": auth.entity_id},
{"access_control.readers": auth.entity_id},
{"access_control.writers": auth.entity_id},
{"access_control.admins": auth.entity_id},
]
}
if auth.entity_type == EntityType.DEVELOPER:
# Add app-specific access for developers
base_filter["$or"].append({"access_control.app_access": auth.app_id})
return base_filter
def _build_metadata_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]:
"""Build MongoDB filter for metadata."""
if not filters:
return {}
filter_dict = {}
for key, value in filters.items():
filter_dict[f"metadata.{key}"] = value
return filter_dict
async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
"""Store metadata for a cache in MongoDB.
Args:
name: Name of the cache
metadata: Cache metadata including model info and storage location
Returns:
bool: Whether the operation was successful
"""
try:
# Add timestamp and ensure name is included
doc = {
"name": name,
"metadata": metadata,
"created_at": datetime.now(UTC),
"updated_at": datetime.now(UTC),
}
# Upsert the document
result = await self.caches.update_one({"name": name}, {"$set": doc}, upsert=True)
return bool(result.modified_count or result.upserted_id)
except Exception as e:
logger.error(f"Failed to store cache metadata: {e}")
return False
async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]:
"""Get metadata for a cache from MongoDB.
Args:
name: Name of the cache
Returns:
Optional[Dict[str, Any]]: Cache metadata if found, None otherwise
"""
try:
doc = await self.caches.find_one({"name": name})
return doc["metadata"] if doc else None
except Exception as e:
logger.error(f"Failed to get cache metadata: {e}")
return None

View File

@ -51,6 +51,7 @@ class GraphModel(Base):
entities = Column(JSONB, default=list)
relationships = Column(JSONB, default=list)
graph_metadata = Column(JSONB, default=dict) # Renamed from 'metadata' to avoid conflict
system_metadata = Column(JSONB, default=dict) # For folder_name and end_user_id
document_ids = Column(JSONB, default=list)
filters = Column(JSONB, nullable=True)
created_at = Column(String) # ISO format string
@ -63,6 +64,7 @@ class GraphModel(Base):
Index("idx_graph_name", "name"),
Index("idx_graph_owner", "owner", postgresql_using="gin"),
Index("idx_graph_access_control", "access_control", postgresql_using="gin"),
Index("idx_graph_system_metadata", "system_metadata", postgresql_using="gin"),
)
@ -139,6 +141,68 @@ class PostgresDatabase(BaseDatabase):
)
)
logger.info("Added storage_files column to documents table")
# Create indexes for folder_name and end_user_id in system_metadata for documents
await conn.execute(
text(
"""
CREATE INDEX IF NOT EXISTS idx_system_metadata_folder_name
ON documents ((system_metadata->>'folder_name'));
"""
)
)
await conn.execute(
text(
"""
CREATE INDEX IF NOT EXISTS idx_system_metadata_end_user_id
ON documents ((system_metadata->>'end_user_id'));
"""
)
)
# Check if system_metadata column exists in graphs table
result = await conn.execute(
text(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'graphs' AND column_name = 'system_metadata'
"""
)
)
if not result.first():
# Add system_metadata column to graphs table
await conn.execute(
text(
"""
ALTER TABLE graphs
ADD COLUMN IF NOT EXISTS system_metadata JSONB DEFAULT '{}'::jsonb
"""
)
)
logger.info("Added system_metadata column to graphs table")
# Create indexes for folder_name and end_user_id in system_metadata for graphs
await conn.execute(
text(
"""
CREATE INDEX IF NOT EXISTS idx_graph_system_metadata_folder_name
ON graphs ((system_metadata->>'folder_name'));
"""
)
)
await conn.execute(
text(
"""
CREATE INDEX IF NOT EXISTS idx_graph_system_metadata_end_user_id
ON graphs ((system_metadata->>'end_user_id'));
"""
)
)
logger.info("Created indexes for folder_name and end_user_id in system_metadata")
logger.info("PostgreSQL tables and indexes created successfully")
self._initialized = True
@ -221,24 +285,42 @@ class PostgresDatabase(BaseDatabase):
logger.error(f"Error retrieving document metadata: {str(e)}")
return None
async def get_document_by_filename(self, filename: str, auth: AuthContext) -> Optional[Document]:
async def get_document_by_filename(self, filename: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> Optional[Document]:
"""Retrieve document metadata by filename if user has access.
If multiple documents have the same filename, returns the most recently updated one.
Args:
filename: The filename to search for
auth: Authentication context
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
"""
try:
async with self.async_session() as session:
# Build access filter
access_filter = self._build_access_filter(auth)
system_metadata_filter = self._build_system_metadata_filter(system_filters)
# Query document
# Construct where clauses
where_clauses = [
f"({access_filter})",
f"filename = '{filename.replace('\'', '\'\'')}'" # Escape single quotes
]
if system_metadata_filter:
where_clauses.append(f"({system_metadata_filter})")
final_where_clause = " AND ".join(where_clauses)
# Query document with system filters
query = (
select(DocumentModel)
.where(DocumentModel.filename == filename)
.where(text(f"({access_filter})"))
.where(text(final_where_clause))
# Order by updated_at in system_metadata to get the most recent document
.order_by(text("system_metadata->>'updated_at' DESC"))
)
logger.debug(f"Querying document by filename with system filters: {system_filters}")
result = await session.execute(query)
doc_model = result.scalar_one_or_none()
@ -264,14 +346,16 @@ class PostgresDatabase(BaseDatabase):
logger.error(f"Error retrieving document metadata by filename: {str(e)}")
return None
async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext) -> List[Document]:
async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> List[Document]:
"""
Retrieve multiple documents by their IDs in a single batch operation.
Only returns documents the user has access to.
Can filter by system metadata fields like folder_name and end_user_id.
Args:
document_ids: List of document IDs to retrieve
auth: Authentication context
system_filters: Optional filters for system metadata fields
Returns:
List of Document objects that were found and user has access to
@ -283,13 +367,21 @@ class PostgresDatabase(BaseDatabase):
async with self.async_session() as session:
# Build access filter
access_filter = self._build_access_filter(auth)
system_metadata_filter = self._build_system_metadata_filter(system_filters)
# Query documents with both document IDs and access check in a single query
query = (
select(DocumentModel)
.where(DocumentModel.external_id.in_(document_ids))
.where(text(f"({access_filter})"))
)
# Construct where clauses
where_clauses = [
f"({access_filter})",
f"external_id IN ({', '.join([f'\'{doc_id}\'' for doc_id in document_ids])})"
]
if system_metadata_filter:
where_clauses.append(f"({system_metadata_filter})")
final_where_clause = " AND ".join(where_clauses)
# Query documents with document IDs, access check, and system filters in a single query
query = select(DocumentModel).where(text(final_where_clause))
logger.info(f"Batch retrieving {len(document_ids)} documents with a single query")
@ -328,6 +420,7 @@ class PostgresDatabase(BaseDatabase):
skip: int = 0,
limit: int = 10000,
filters: Optional[Dict[str, Any]] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> List[Document]:
"""List documents the user has access to."""
try:
@ -335,10 +428,18 @@ class PostgresDatabase(BaseDatabase):
# Build query
access_filter = self._build_access_filter(auth)
metadata_filter = self._build_metadata_filter(filters)
system_metadata_filter = self._build_system_metadata_filter(system_filters)
query = select(DocumentModel).where(text(f"({access_filter})"))
where_clauses = [f"({access_filter})"]
if metadata_filter:
query = query.where(text(metadata_filter))
where_clauses.append(f"({metadata_filter})")
if system_metadata_filter:
where_clauses.append(f"({system_metadata_filter})")
final_where_clause = " AND ".join(where_clauses)
query = select(DocumentModel).where(text(final_where_clause))
query = query.offset(skip).limit(limit)
@ -373,9 +474,23 @@ class PostgresDatabase(BaseDatabase):
try:
if not await self.check_access(document_id, auth, "write"):
return False
# Get existing document to preserve system_metadata
existing_doc = await self.get_document(document_id, auth)
if not existing_doc:
return False
# Update system metadata
updates.setdefault("system_metadata", {})
# Preserve folder_name and end_user_id if not explicitly overridden
if existing_doc.system_metadata:
if "folder_name" in existing_doc.system_metadata and "folder_name" not in updates["system_metadata"]:
updates["system_metadata"]["folder_name"] = existing_doc.system_metadata["folder_name"]
if "end_user_id" in existing_doc.system_metadata and "end_user_id" not in updates["system_metadata"]:
updates["system_metadata"]["end_user_id"] = existing_doc.system_metadata["end_user_id"]
updates["system_metadata"]["updated_at"] = datetime.now(UTC)
# Serialize datetime objects to ISO format strings
@ -421,7 +536,7 @@ class PostgresDatabase(BaseDatabase):
return False
async def find_authorized_and_filtered_documents(
self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None
self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None, system_filters: Optional[Dict[str, Any]] = None
) -> List[str]:
"""Find document IDs matching filters and access permissions."""
try:
@ -429,14 +544,24 @@ class PostgresDatabase(BaseDatabase):
# Build query
access_filter = self._build_access_filter(auth)
metadata_filter = self._build_metadata_filter(filters)
system_metadata_filter = self._build_system_metadata_filter(system_filters)
logger.debug(f"Access filter: {access_filter}")
logger.debug(f"Metadata filter: {metadata_filter}")
logger.debug(f"System metadata filter: {system_metadata_filter}")
logger.debug(f"Original filters: {filters}")
logger.debug(f"System filters: {system_filters}")
query = select(DocumentModel.external_id).where(text(f"({access_filter})"))
where_clauses = [f"({access_filter})"]
if metadata_filter:
query = query.where(text(metadata_filter))
where_clauses.append(f"({metadata_filter})")
if system_metadata_filter:
where_clauses.append(f"({system_metadata_filter})")
final_where_clause = " AND ".join(where_clauses)
query = select(DocumentModel.external_id).where(text(final_where_clause))
logger.debug(f"Final query: {query}")
@ -525,6 +650,25 @@ class PostgresDatabase(BaseDatabase):
filter_conditions.append(f"doc_metadata->>'{key}' = '{value}'")
return " AND ".join(filter_conditions)
def _build_system_metadata_filter(self, system_filters: Optional[Dict[str, Any]]) -> str:
"""Build PostgreSQL filter for system metadata."""
if not system_filters:
return ""
conditions = []
for key, value in system_filters.items():
if value is None:
continue
if isinstance(value, str):
# Replace single quotes with double single quotes to escape them
escaped_value = value.replace("'", "''")
conditions.append(f"system_metadata->>'{key}' = '{escaped_value}'")
else:
conditions.append(f"system_metadata->>'{key}' = '{value}'")
return " AND ".join(conditions)
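For reference, the helper above maps a system_filters dict onto the JSONB system_metadata column, producing a fragment like the following (single quotes in values are doubled for escaping; callers AND it together with the access and metadata filters):

# _build_system_metadata_filter({"folder_name": "q3-reports", "end_user_id": "user-123"})
# returns:
#   system_metadata->>'folder_name' = 'q3-reports' AND system_metadata->>'end_user_id' = 'user-123'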
async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
"""Store metadata for a cache in PostgreSQL.
@ -618,12 +762,13 @@ class PostgresDatabase(BaseDatabase):
logger.error(f"Error storing graph: {str(e)}")
return False
async def get_graph(self, name: str, auth: AuthContext) -> Optional[Graph]:
async def get_graph(self, name: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> Optional[Graph]:
"""Get a graph by name.
Args:
name: Name of the graph
auth: Authentication context
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
Optional[Graph]: Graph if found and accessible, None otherwise
@ -637,7 +782,8 @@ class PostgresDatabase(BaseDatabase):
# Build access filter
access_filter = self._build_access_filter(auth)
# Query graph
# We need to check if the documents in the graph match the system filters
# First get the graph without system filters
query = (
select(GraphModel)
.where(GraphModel.name == name)
@ -648,6 +794,32 @@ class PostgresDatabase(BaseDatabase):
graph_model = result.scalar_one_or_none()
if graph_model:
# If system filters are provided, we need to filter the document_ids
document_ids = graph_model.document_ids
if system_filters and document_ids:
# Apply system_filters to document_ids
system_metadata_filter = self._build_system_metadata_filter(system_filters)
if system_metadata_filter:
# Get document IDs with system filters
doc_id_placeholders = ", ".join([f"'{doc_id}'" for doc_id in document_ids])
filter_query = f"""
SELECT external_id FROM documents
WHERE external_id IN ({doc_id_placeholders})
AND ({system_metadata_filter})
"""
filter_result = await session.execute(text(filter_query))
filtered_doc_ids = [row[0] for row in filter_result.all()]
# If no documents match system filters, return None
if not filtered_doc_ids:
return None
# Update document_ids with filtered results
document_ids = filtered_doc_ids
# Convert to Graph model
graph_dict = {
"id": graph_model.id,
@ -655,7 +827,8 @@ class PostgresDatabase(BaseDatabase):
"entities": graph_model.entities,
"relationships": graph_model.relationships,
"metadata": graph_model.graph_metadata, # Reference the renamed column
"document_ids": graph_model.document_ids,
"system_metadata": graph_model.system_metadata or {}, # Include system_metadata
"document_ids": document_ids, # Use possibly filtered document_ids
"filters": graph_model.filters,
"created_at": graph_model.created_at,
"updated_at": graph_model.updated_at,
@ -670,11 +843,12 @@ class PostgresDatabase(BaseDatabase):
logger.error(f"Error retrieving graph: {str(e)}")
return None
async def list_graphs(self, auth: AuthContext) -> List[Graph]:
async def list_graphs(self, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> List[Graph]:
"""List all graphs the user has access to.
Args:
auth: Authentication context
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
Returns:
List[Graph]: List of graphs
@ -693,23 +867,66 @@ class PostgresDatabase(BaseDatabase):
result = await session.execute(query)
graph_models = result.scalars().all()
return [
Graph(
id=graph.id,
name=graph.name,
entities=graph.entities,
relationships=graph.relationships,
metadata=graph.graph_metadata, # Reference the renamed column
document_ids=graph.document_ids,
filters=graph.filters,
created_at=graph.created_at,
updated_at=graph.updated_at,
owner=graph.owner,
access_control=graph.access_control,
)
for graph in graph_models
]
graphs = []
# If system filters are provided, we need to filter each graph's document_ids
if system_filters:
system_metadata_filter = self._build_system_metadata_filter(system_filters)
for graph_model in graph_models:
document_ids = graph_model.document_ids
if document_ids and system_metadata_filter:
# Get document IDs with system filters
doc_id_placeholders = ", ".join([f"'{doc_id}'" for doc_id in document_ids])
filter_query = f"""
SELECT external_id FROM documents
WHERE external_id IN ({doc_id_placeholders})
AND ({system_metadata_filter})
"""
filter_result = await session.execute(text(filter_query))
filtered_doc_ids = [row[0] for row in filter_result.all()]
# Only include graphs that have documents matching the system filters
if filtered_doc_ids:
graph = Graph(
id=graph_model.id,
name=graph_model.name,
entities=graph_model.entities,
relationships=graph_model.relationships,
metadata=graph_model.graph_metadata, # Reference the renamed column
system_metadata=graph_model.system_metadata or {}, # Include system_metadata
document_ids=filtered_doc_ids, # Use filtered document_ids
filters=graph_model.filters,
created_at=graph_model.created_at,
updated_at=graph_model.updated_at,
owner=graph_model.owner,
access_control=graph_model.access_control,
)
graphs.append(graph)
else:
# No system filters, include all graphs
graphs = [
Graph(
id=graph.id,
name=graph.name,
entities=graph.entities,
relationships=graph.relationships,
metadata=graph.graph_metadata, # Reference the renamed column
system_metadata=graph.system_metadata or {}, # Include system_metadata
document_ids=graph.document_ids,
filters=graph.filters,
created_at=graph.created_at,
updated_at=graph.updated_at,
owner=graph.owner,
access_control=graph.access_control,
)
for graph in graph_models
]
return graphs
except Exception as e:
logger.error(f"Error listing graphs: {str(e)}")

View File

@ -28,3 +28,5 @@ class CompletionRequest(BaseModel):
max_tokens: Optional[int] = 1000
temperature: Optional[float] = 0.7
prompt_template: Optional[str] = None
folder_name: Optional[str] = None
end_user_id: Optional[str] = None

View File

@ -27,7 +27,7 @@ class StorageFileInfo(BaseModel):
class Document(BaseModel):
"""Represents a document stored in MongoDB documents collection"""
"""Represents a document stored in the database documents collection"""
external_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
owner: Dict[str, str]
@ -44,6 +44,8 @@ class Document(BaseModel):
"created_at": datetime.now(UTC),
"updated_at": datetime.now(UTC),
"version": 1,
"folder_name": None,
"end_user_id": None,
}
)
"""metadata such as creation date etc."""

View File

@ -50,6 +50,14 @@ class Graph(BaseModel):
entities: List[Entity] = Field(default_factory=list)
relationships: List[Relationship] = Field(default_factory=list)
metadata: Dict[str, Any] = Field(default_factory=dict)
system_metadata: Dict[str, Any] = Field(
default_factory=lambda: {
"created_at": datetime.now(UTC),
"updated_at": datetime.now(UTC),
"folder_name": None,
"end_user_id": None,
}
)
document_ids: List[str] = Field(default_factory=list)
filters: Optional[Dict[str, Any]] = None
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@ -23,6 +23,8 @@ class RetrieveRequest(BaseModel):
include_paths: Optional[bool] = Field(
False, description="Whether to include relationship paths in the response"
)
folder_name: Optional[str] = Field(None, description="Optional folder scope for the operation")
end_user_id: Optional[str] = Field(None, description="Optional end-user scope for the operation")
class CompletionQueryRequest(RetrieveRequest):
@ -44,6 +46,8 @@ class IngestTextRequest(BaseModel):
metadata: Dict[str, Any] = Field(default_factory=dict)
rules: List[Dict[str, Any]] = Field(default_factory=list)
use_colpali: Optional[bool] = None
folder_name: Optional[str] = Field(None, description="Optional folder scope for the operation")
end_user_id: Optional[str] = Field(None, description="Optional end-user scope for the operation")
class CreateGraphRequest(BaseModel):
@ -66,6 +70,8 @@ class CreateGraphRequest(BaseModel):
}
}}
)
folder_name: Optional[str] = Field(None, description="Optional folder scope for the operation")
end_user_id: Optional[str] = Field(None, description="Optional end-user scope for the operation")
class UpdateGraphRequest(BaseModel):
@ -81,6 +87,8 @@ class UpdateGraphRequest(BaseModel):
None,
description="Optional customizations for entity extraction and resolution prompts"
)
folder_name: Optional[str] = Field(None, description="Optional folder scope for the operation")
end_user_id: Optional[str] = Field(None, description="Optional end-user scope for the operation")
class BatchIngestResponse(BaseModel):
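Since the scoping fields above are plain optional strings, a scoped retrieval call only adds two keys to the request body. A minimal sketch, matching the integration tests later in this commit (endpoint path and values are illustrative, not new API surface):

import httpx

# Illustrative body for POST /retrieve/chunks with folder and end-user scope.
scoped_retrieve_request = {
    "query": "What issues have been reported?",
    "k": 4,
    "min_score": 0.0,
    "folder_name": "customer-support",     # optional folder scope
    "end_user_id": "support@example.com",  # optional end-user scope
}

async def retrieve_scoped_chunks(client: httpx.AsyncClient, headers: dict) -> list:
    # Same call shape as the tests below; only documents whose
    # system_metadata matches both scopes are searched.
    response = await client.post("/retrieve/chunks", json=scoped_retrieve_request, headers=headers)
    response.raise_for_status()
    return response.json()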

View File

@ -95,6 +95,8 @@ class DocumentService:
min_score: float = 0.0,
use_reranking: Optional[bool] = None,
use_colpali: Optional[bool] = None,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> List[ChunkResult]:
"""Retrieve relevant chunks."""
settings = get_settings()
@ -106,7 +108,14 @@ class DocumentService:
logger.info("Generated query embedding")
# Find authorized documents
doc_ids = await self.db.find_authorized_and_filtered_documents(auth, filters)
# Build system filters for folder_name and end_user_id
system_filters = {}
if folder_name:
system_filters["folder_name"] = folder_name
if end_user_id:
system_filters["end_user_id"] = end_user_id
doc_ids = await self.db.find_authorized_and_filtered_documents(auth, filters, system_filters)
if not doc_ids:
logger.info("No authorized documents found")
return []
@ -194,11 +203,13 @@ class DocumentService:
min_score: float = 0.0,
use_reranking: Optional[bool] = None,
use_colpali: Optional[bool] = None,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> List[DocumentResult]:
"""Retrieve relevant documents."""
# Get chunks first
chunks = await self.retrieve_chunks(
query, auth, filters, k, min_score, use_reranking, use_colpali
query, auth, filters, k, min_score, use_reranking, use_colpali, folder_name, end_user_id
)
# Convert to document results
results = await self._create_document_results(auth, chunks)
@ -209,7 +220,9 @@ class DocumentService:
async def batch_retrieve_documents(
self,
document_ids: List[str],
auth: AuthContext
auth: AuthContext,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None
) -> List[Document]:
"""
Retrieve multiple documents by their IDs in a single batch operation.
@ -224,15 +237,24 @@ class DocumentService:
if not document_ids:
return []
# Build system filters for folder_name and end_user_id
system_filters = {}
if folder_name:
system_filters["folder_name"] = folder_name
if end_user_id:
system_filters["end_user_id"] = end_user_id
# Use the database's batch retrieval method
documents = await self.db.get_documents_by_id(document_ids, auth)
documents = await self.db.get_documents_by_id(document_ids, auth, system_filters)
logger.info(f"Batch retrieved {len(documents)} documents out of {len(document_ids)} requested")
return documents
async def batch_retrieve_chunks(
self,
chunk_ids: List[ChunkSource],
auth: AuthContext
auth: AuthContext,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None
) -> List[ChunkResult]:
"""
Retrieve specific chunks by their document ID and chunk number in a single batch operation.
@ -251,7 +273,7 @@ class DocumentService:
doc_ids = list({source.document_id for source in chunk_ids})
# Find authorized documents in a single query
authorized_docs = await self.batch_retrieve_documents(doc_ids, auth)
authorized_docs = await self.batch_retrieve_documents(doc_ids, auth, folder_name, end_user_id)
authorized_doc_ids = {doc.external_id for doc in authorized_docs}
# Filter sources to only include authorized documents
@ -292,6 +314,8 @@ class DocumentService:
hop_depth: int = 1,
include_paths: bool = False,
prompt_overrides: Optional["QueryPromptOverrides"] = None,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> CompletionResponse:
"""Generate completion using relevant chunks as context.
@ -329,11 +353,13 @@ class DocumentService:
hop_depth=hop_depth,
include_paths=include_paths,
prompt_overrides=prompt_overrides,
folder_name=folder_name,
end_user_id=end_user_id
)
# Standard retrieval without graph
chunks = await self.retrieve_chunks(
query, auth, filters, k, min_score, use_reranking, use_colpali
query, auth, filters, k, min_score, use_reranking, use_colpali, folder_name, end_user_id
)
documents = await self._create_document_results(auth, chunks)
@ -374,6 +400,8 @@ class DocumentService:
auth: AuthContext = None,
rules: Optional[List[str]] = None,
use_colpali: Optional[bool] = None,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> Document:
"""Ingest a text document."""
if "write" not in auth.permissions:
@ -396,6 +424,12 @@ class DocumentService:
"user_id": [auth.user_id] if auth.user_id else [], # Add user_id to access control for filtering (as a list)
},
)
# Add folder_name and end_user_id to system_metadata if provided
if folder_name:
doc.system_metadata["folder_name"] = folder_name
if end_user_id:
doc.system_metadata["end_user_id"] = end_user_id
logger.debug(f"Created text document record with ID {doc.external_id}")
if settings.MODE == "cloud" and auth.user_id:
@ -459,6 +493,8 @@ class DocumentService:
auth: AuthContext,
rules: Optional[List[str]] = None,
use_colpali: Optional[bool] = None,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> Document:
"""Ingest a file document."""
if "write" not in auth.permissions:
@ -527,6 +563,12 @@ class DocumentService:
},
additional_metadata=additional_metadata,
)
# Add folder_name and end_user_id to system_metadata if provided
if folder_name:
doc.system_metadata["folder_name"] = folder_name
if end_user_id:
doc.system_metadata["end_user_id"] = end_user_id
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding with parsing
@ -730,7 +772,13 @@ class DocumentService:
chunks: List[Chunk],
embeddings: List[List[float]],
) -> List[DocumentChunk]:
"""Helper to create chunk objects"""
"""Helper to create chunk objects
Note: folder_name and end_user_id are not needed in chunk metadata because:
1. Filtering by these values happens at the document level in find_authorized_and_filtered_documents
2. Vector search is only performed on already authorized and filtered documents
3. This approach is more efficient as it reduces the size of chunk metadata
"""
return [
c.to_document_chunk(chunk_number=i, embedding=embedding, document_id=doc_id)
for i, (embedding, c) in enumerate(zip(embeddings, chunks))
@ -1341,6 +1389,7 @@ class DocumentService:
filters: Optional[Dict[str, Any]] = None,
documents: Optional[List[str]] = None,
prompt_overrides: Optional[GraphPromptOverrides] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> Graph:
"""Create a graph from documents.
@ -1352,6 +1401,8 @@ class DocumentService:
auth: Authentication context
filters: Optional metadata filters to determine which documents to include
documents: Optional list of specific document IDs to include
prompt_overrides: Optional customizations for entity extraction and resolution prompts
system_filters: Optional system filters like folder_name and end_user_id for scoping
Returns:
Graph: The created graph
@ -1364,6 +1415,7 @@ class DocumentService:
filters=filters,
documents=documents,
prompt_overrides=prompt_overrides,
system_filters=system_filters,
)
async def update_graph(
@ -1373,6 +1425,7 @@ class DocumentService:
additional_filters: Optional[Dict[str, Any]] = None,
additional_documents: Optional[List[str]] = None,
prompt_overrides: Optional[GraphPromptOverrides] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> Graph:
"""Update an existing graph with new documents.
@ -1384,6 +1437,8 @@ class DocumentService:
auth: Authentication context
additional_filters: Optional additional metadata filters to determine which new documents to include
additional_documents: Optional list of additional document IDs to include
prompt_overrides: Optional customizations for entity extraction and resolution prompts
system_filters: Optional system filters like folder_name and end_user_id for scoping
Returns:
Graph: The updated graph
@ -1396,6 +1451,7 @@ class DocumentService:
additional_filters=additional_filters,
additional_documents=additional_documents,
prompt_overrides=prompt_overrides,
system_filters=system_filters,
)
async def delete_document(self, document_id: str, auth: AuthContext) -> bool:
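As the _create_chunk_objects docstring above explains, folder and user scope never reach chunk metadata; DocumentService folds the two optional values into a system_filters dict and applies it once, when selecting authorized documents. A condensed sketch of that pattern (the helper name is hypothetical; the inline version appears in retrieve_chunks, batch_retrieve_documents and query above):

from typing import Any, Dict, Optional

def build_system_filters(folder_name: Optional[str] = None,
                         end_user_id: Optional[str] = None) -> Dict[str, Any]:
    # Only the scopes that were actually provided become filters.
    system_filters: Dict[str, Any] = {}
    if folder_name:
        system_filters["folder_name"] = folder_name
    if end_user_id:
        system_filters["end_user_id"] = end_user_id
    return system_filters

# Usage mirroring retrieve_chunks: filtering happens at the document level, e.g.
#   doc_ids = await self.db.find_authorized_and_filtered_documents(
#       auth, filters, build_system_filters(folder_name, end_user_id)
#   )
# so the chunks created afterwards need no folder/user fields of their own.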

View File

@ -68,6 +68,7 @@ class GraphService:
additional_filters: Optional[Dict[str, Any]] = None,
additional_documents: Optional[List[str]] = None,
prompt_overrides: Optional[GraphPromptOverrides] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> Graph:
"""Update an existing graph with new documents.
@ -81,10 +82,15 @@ class GraphService:
additional_filters: Optional additional metadata filters to determine which new documents to include
additional_documents: Optional list of specific additional document IDs to include
prompt_overrides: Optional GraphPromptOverrides with customizations for prompts
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id) to determine which documents to include
Returns:
Graph: The updated graph
"""
# Initialize system_filters if None
if system_filters is None:
system_filters = {}
if "write" not in auth.permissions:
raise PermissionError("User does not have write permission")
@ -99,7 +105,7 @@ class GraphService:
# Find new documents to process
document_ids = await self._get_new_document_ids(
auth, existing_graph, additional_filters, additional_documents
auth, existing_graph, additional_filters, additional_documents, system_filters
)
if not document_ids and not explicit_doc_ids:
@ -123,7 +129,7 @@ class GraphService:
# Batch retrieve all documents in a single call
document_objects = await document_service.batch_retrieve_documents(
all_ids_to_retrieve, auth
all_ids_to_retrieve, auth, system_filters.get("folder_name", None), system_filters.get("end_user_id", None)
)
# Process explicit documents if needed
@ -150,6 +156,8 @@ class GraphService:
if doc_id not in {d.external_id for d in document_objects}
],
auth,
system_filters.get("folder_name", None),
system_filters.get("end_user_id", None)
)
logger.info(f"Additional filtered documents to include: {len(filtered_docs)}")
document_objects.extend(filtered_docs)
@ -190,23 +198,28 @@ class GraphService:
existing_graph: Graph,
additional_filters: Optional[Dict[str, Any]] = None,
additional_documents: Optional[List[str]] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> Set[str]:
"""Get IDs of new documents to add to the graph."""
# Initialize system_filters if None
if system_filters is None:
system_filters = {}
# Initialize with explicitly specified documents, ensuring it's a set
document_ids = set(additional_documents or [])
# Process documents matching additional filters
if additional_filters:
filtered_docs = await self.db.get_documents(auth, filters=additional_filters)
if additional_filters or system_filters:
filtered_docs = await self.db.get_documents(auth, filters=additional_filters, system_filters=system_filters)
filter_doc_ids = {doc.external_id for doc in filtered_docs}
logger.info(f"Found {len(filter_doc_ids)} documents matching additional filters")
logger.info(f"Found {len(filter_doc_ids)} documents matching additional filters and system filters")
document_ids.update(filter_doc_ids)
# Process documents matching the original filters
if existing_graph.filters:
filtered_docs = await self.db.get_documents(auth, filters=existing_graph.filters)
# Original filters shouldn't include system filters, as we're applying them separately
filtered_docs = await self.db.get_documents(auth, filters=existing_graph.filters, system_filters=system_filters)
orig_filter_doc_ids = {doc.external_id for doc in filtered_docs}
logger.info(f"Found {len(orig_filter_doc_ids)} documents matching original filters")
logger.info(f"Found {len(orig_filter_doc_ids)} documents matching original filters and system filters")
document_ids.update(orig_filter_doc_ids)
# Get only the document IDs that are not already in the graph
@ -384,6 +397,7 @@ class GraphService:
filters: Optional[Dict[str, Any]] = None,
documents: Optional[List[str]] = None,
prompt_overrides: Optional[GraphPromptOverrides] = None,
system_filters: Optional[Dict[str, Any]] = None,
) -> Graph:
"""Create a graph from documents.
@ -397,10 +411,15 @@ class GraphService:
filters: Optional metadata filters to determine which documents to include
documents: Optional list of specific document IDs to include
prompt_overrides: Optional GraphPromptOverrides with customizations for prompts
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id) to determine which documents to include
Returns:
Graph: The created graph
"""
# Initialize system_filters if None
if system_filters is None:
system_filters = {}
if "write" not in auth.permissions:
raise PermissionError("User does not have write permission")
@ -408,15 +427,28 @@ class GraphService:
document_ids = set(documents or [])
# If filters were provided, get matching documents
if filters:
filtered_docs = await self.db.get_documents(auth, filters=filters)
if filters or system_filters:
filtered_docs = await self.db.get_documents(auth, filters=filters, system_filters=system_filters)
document_ids.update(doc.external_id for doc in filtered_docs)
if not document_ids:
raise ValueError("No documents found matching criteria")
# Convert system_filters for document retrieval
folder_name = system_filters.get("folder_name") if system_filters else None
end_user_id = system_filters.get("end_user_id") if system_filters else None
# Batch retrieve documents for authorization check
document_objects = await document_service.batch_retrieve_documents(list(document_ids), auth)
document_objects = await document_service.batch_retrieve_documents(
list(document_ids),
auth,
folder_name,
end_user_id
)
# Log for debugging
logger.info(f"Graph creation with folder_name={folder_name}, end_user_id={end_user_id}")
logger.info(f"Documents retrieved: {len(document_objects)} out of {len(document_ids)} requested")
if not document_objects:
raise ValueError("No authorized documents found matching criteria")
@ -434,6 +466,13 @@ class GraphService:
"admins": [auth.entity_id],
},
)
# Add folder_name and end_user_id to system_metadata if provided
if system_filters:
if "folder_name" in system_filters:
graph.system_metadata["folder_name"] = system_filters["folder_name"]
if "end_user_id" in system_filters:
graph.system_metadata["end_user_id"] = system_filters["end_user_id"]
# Extract entities and relationships
entities, relationships = await self._process_documents_for_entities(
@ -868,6 +907,8 @@ class GraphService:
hop_depth: int = 1,
include_paths: bool = False,
prompt_overrides: Optional[QueryPromptOverrides] = None,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> CompletionResponse:
"""Generate completion using knowledge graph-enhanced retrieval.
@ -899,8 +940,15 @@ class GraphService:
# Validation is now handled by type annotations
# Get the knowledge graph
graph = await self.db.get_graph(graph_name, auth)
# Build system filters for scoping
system_filters = {}
if folder_name:
system_filters["folder_name"] = folder_name
if end_user_id:
system_filters["end_user_id"] = end_user_id
logger.info(f"Querying graph with system_filters: {system_filters}")
graph = await self.db.get_graph(graph_name, auth, system_filters=system_filters)
if not graph:
logger.warning(f"Graph '{graph_name}' not found or not accessible")
# Fall back to standard retrieval if graph not found
@ -915,12 +963,14 @@ class GraphService:
use_reranking=use_reranking,
use_colpali=use_colpali,
graph_name=None,
folder_name=folder_name,
end_user_id=end_user_id,
)
# Parallel approach
# 1. Standard vector search
vector_chunks = await document_service.retrieve_chunks(
query, auth, filters, k, min_score, use_reranking, use_colpali
query, auth, filters, k, min_score, use_reranking, use_colpali, folder_name, end_user_id
)
logger.info(f"Vector search retrieved {len(vector_chunks)} chunks")
@ -990,7 +1040,7 @@ class GraphService:
# Get specific chunks containing these entities
graph_chunks = await self._retrieve_entity_chunks(
expanded_entities, auth, filters, document_service
expanded_entities, auth, filters, document_service, folder_name, end_user_id
)
logger.info(f"Retrieved {len(graph_chunks)} chunks containing relevant entities")
@ -1015,6 +1065,8 @@ class GraphService:
auth,
graph_name,
prompt_overrides,
folder_name=folder_name,
end_user_id=end_user_id,
)
return completion_response
@ -1143,8 +1195,13 @@ class GraphService:
auth: AuthContext,
filters: Optional[Dict[str, Any]],
document_service,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> List[ChunkResult]:
"""Retrieve chunks containing the specified entities."""
# Initialize filters if None
if filters is None:
filters = {}
if not entities:
return []
@ -1158,9 +1215,9 @@ class GraphService:
# Get unique document IDs for authorization check
doc_ids = {doc_id for doc_id, _ in entity_chunk_sources}
# Check document authorization
documents = await document_service.batch_retrieve_documents(list(doc_ids), auth)
# Check document authorization with system filters
documents = await document_service.batch_retrieve_documents(list(doc_ids), auth, folder_name, end_user_id)
# Apply filters if needed
authorized_doc_ids = {
@ -1178,7 +1235,7 @@ class GraphService:
# Retrieve and return chunks if we have any valid sources
return (
await document_service.batch_retrieve_chunks(chunk_sources, auth)
await document_service.batch_retrieve_chunks(chunk_sources, auth, folder_name=folder_name, end_user_id=end_user_id)
if chunk_sources
else []
)
@ -1198,7 +1255,7 @@ class GraphService:
chunk.score = min(1.0, (getattr(chunk, "score", 0.7) or 0.7) * 1.05)
# Keep the higher-scored version
if chunk_key not in all_chunks or chunk.score > all_chunks[chunk_key].score:
if chunk_key not in all_chunks or chunk.score > (getattr(all_chunks.get(chunk_key), "score", 0) or 0):
all_chunks[chunk_key] = chunk
# Convert to list, sort by score, and return top k
@ -1330,6 +1387,8 @@ class GraphService:
auth: Optional[AuthContext] = None,
graph_name: Optional[str] = None,
prompt_overrides: Optional[QueryPromptOverrides] = None,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> CompletionResponse:
"""Generate completion using the retrieved chunks and optional path information."""
if not chunks:
@ -1370,6 +1429,8 @@ class GraphService:
max_tokens=max_tokens,
temperature=temperature,
prompt_template=custom_prompt_template,
folder_name=folder_name,
end_user_id=end_user_id,
)
# Get completion from model
@ -1387,6 +1448,7 @@ class GraphService:
# Include graph metadata if paths were requested
if include_paths:
# Initialize metadata if it doesn't exist
if not hasattr(response, "metadata") or response.metadata is None:
response.metadata = {}
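Taken together, the GraphService changes let callers scope graphs end to end: the create request carries folder_name and end_user_id, the service folds them into system_filters when selecting documents, and the resulting graph records them in its system_metadata. A minimal illustrative payload, matching test_graph_with_folder_and_user_scope later in this commit (values are examples only):

# Example body for POST /graph/create with folder and end-user scope.
create_graph_payload = {
    "name": "test_scoped_graph",
    "folder_name": "test_graph_folder",
    "end_user_id": "graph_test_user@example.com",
}
# Only documents whose system_metadata matches both scopes are pulled into the
# graph, and later graph queries with the same scope stay inside that subset.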

View File

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
TEST_DATA_DIR = Path(__file__).parent / "test_data"
JWT_SECRET = "your-secret-key-for-signing-tokens"
TEST_USER_ID = "test_user"
TEST_POSTGRES_URI = "postgresql+asyncpg://postgres:postgres@localhost:5432/morphik_test"
TEST_POSTGRES_URI = "postgresql+asyncpg://morphik@localhost:5432/morphik_test"
@pytest.fixture(scope="session")
@ -261,6 +261,41 @@ async def test_ingest_text_document_with_metadata(client: AsyncClient, content:
return data["external_id"]
@pytest.mark.asyncio
async def test_ingest_text_document_folder_user(
client: AsyncClient,
content: str = "Test content for document ingestion with folder and user scoping",
metadata: dict = None,
folder_name: str = "test_folder",
end_user_id: str = "test_user@example.com"
):
"""Test ingesting a text document with folder and user scoping"""
headers = create_auth_header()
response = await client.post(
"/ingest/text",
json={
"content": content,
"metadata": metadata or {},
"folder_name": folder_name,
"end_user_id": end_user_id
},
headers=headers,
)
assert response.status_code == 200
data = response.json()
assert "external_id" in data
assert data["content_type"] == "text/plain"
assert data["system_metadata"]["folder_name"] == folder_name
assert data["system_metadata"]["end_user_id"] == end_user_id
for key, value in (metadata or {}).items():
assert data["metadata"][key] == value
return data["external_id"]
@pytest.mark.asyncio
async def test_ingest_pdf(client: AsyncClient):
"""Test ingesting a pdf"""
@ -1507,6 +1542,196 @@ async def test_query_with_graph(client: AsyncClient):
assert response_no_graph.status_code == 200
@pytest.mark.asyncio
async def test_graph_with_folder_and_user_scope(client: AsyncClient):
"""Test knowledge graph with folder and user scoping."""
headers = create_auth_header()
# Test folder
folder_name = "test_graph_folder"
# Test user
user_id = "graph_test_user@example.com"
# Ingest documents into folder with user scope using our helper function
doc_id1 = await test_ingest_text_document_folder_user(
client,
content="Tesla is an electric vehicle manufacturer. Elon Musk is the CEO of Tesla.",
metadata={"graph_scope_test": True},
folder_name=folder_name,
end_user_id=user_id
)
doc_id2 = await test_ingest_text_document_folder_user(
client,
content="SpaceX develops spacecraft and rockets. Elon Musk is also the CEO of SpaceX.",
metadata={"graph_scope_test": True},
folder_name=folder_name,
end_user_id=user_id
)
# Also ingest a document outside the folder/user scope
_ = await test_ingest_text_document_with_metadata(
client,
content="Elon Musk also founded Neuralink, a neurotechnology company.",
metadata={"graph_scope_test": True}
)
# Create a graph with folder and user scope
graph_name = "test_scoped_graph"
response = await client.post(
"/graph/create",
json={
"name": graph_name,
"folder_name": folder_name,
"end_user_id": user_id
},
headers=headers,
)
assert response.status_code == 200
graph = response.json()
# Verify graph was created with proper scoping
assert graph["name"] == graph_name
assert len(graph["document_ids"]) == 2
assert all(doc_id in graph["document_ids"] for doc_id in [doc_id1, doc_id2])
# Verify we have the expected entities
entity_labels = [entity["label"].lower() for entity in graph["entities"]]
assert any("tesla" in label for label in entity_labels)
assert any("spacex" in label for label in entity_labels)
assert any("elon musk" in label for label in entity_labels)
# First, let's check the retrieved chunks directly to verify scope is working
retrieve_response = await client.post(
"/retrieve/chunks",
json={
"query": "What companies does Elon Musk lead?",
"folder_name": folder_name,
"end_user_id": user_id
},
headers=headers,
)
assert retrieve_response.status_code == 200
retrieved_chunks = retrieve_response.json()
# Verify that none of the retrieved chunks contain "Neuralink"
for chunk in retrieved_chunks:
assert "neuralink" not in chunk["content"].lower()
# First try querying without a graph to see if RAG works with just folder/user scope
response_no_graph = await client.post(
"/query",
json={
"query": "What companies does Elon Musk lead?",
"folder_name": folder_name,
"end_user_id": user_id
},
headers=headers,
)
assert response_no_graph.status_code == 200
result_no_graph = response_no_graph.json()
# Verify the completion has the expected content
completion_no_graph = result_no_graph["completion"].lower()
print("Completion without graph:")
print(completion_no_graph)
assert "tesla" in completion_no_graph
assert "spacex" in completion_no_graph
assert "neuralink" not in completion_no_graph
# Now test querying with graph and folder/user scope
response = await client.post(
"/query",
json={
"query": "What companies does Elon Musk lead?",
"graph_name": graph_name,
"folder_name": folder_name,
"end_user_id": user_id
},
headers=headers,
)
assert response.status_code == 200
result = response.json()
# Log source chunks and graph information used for completion
print("\nSource chunks for graph-based completion:")
for source in result["sources"]:
print(f"Document ID: {source['document_id']}, Chunk: {source['chunk_number']}")
# Check if there's graph metadata in the response
if result.get("metadata") and "graph" in result.get("metadata", {}):
print("\nGraph metadata used:")
print(result["metadata"]["graph"])
# Verify the completion has the expected content
completion = result["completion"].lower()
print("\nCompletion with graph:")
print(completion)
assert "tesla" in completion
assert "spacex" in completion
# Verify Neuralink isn't included (it was outside folder/user scope)
assert "neuralink" not in completion
# Test updating the graph with folder and user scope
doc_id3 = await test_ingest_text_document_folder_user(
client,
content="The Boring Company was founded by Elon Musk in 2016.",
metadata={"graph_scope_test": True},
folder_name=folder_name,
end_user_id=user_id
)
# Update the graph
update_response = await client.post(
f"/graph/{graph_name}/update",
json={
"additional_documents": [doc_id3],
"folder_name": folder_name,
"end_user_id": user_id
},
headers=headers,
)
assert update_response.status_code == 200
updated_graph = update_response.json()
# Verify graph was updated
assert updated_graph["name"] == graph_name
assert len(updated_graph["document_ids"]) == 3
assert all(doc_id in updated_graph["document_ids"] for doc_id in [doc_id1, doc_id2, doc_id3])
# Verify new entity was added
updated_entity_labels = [entity["label"].lower() for entity in updated_graph["entities"]]
assert any("boring company" in label for label in updated_entity_labels)
# Test querying with updated graph
response = await client.post(
"/query",
json={
"query": "List all companies founded or led by Elon Musk",
"graph_name": graph_name,
"folder_name": folder_name,
"end_user_id": user_id
},
headers=headers,
)
assert response.status_code == 200
updated_result = response.json()
# Verify the completion includes the new company
updated_completion = updated_result["completion"].lower()
assert "tesla" in updated_completion
assert "spacex" in updated_completion
assert "boring company" in updated_completion
@pytest.mark.asyncio
async def test_batch_ingest_with_shared_metadata(
client: AsyncClient
@ -1804,6 +2029,429 @@ async def test_batch_ingest_sequential_vs_parallel(
assert len(result["errors"]) == 0
@pytest.mark.asyncio
async def test_folder_scoping(client: AsyncClient):
"""Test document operations with folder scoping."""
headers = create_auth_header()
# Test folder 1
folder1_name = "test_folder_1"
folder1_content = "This is content in test folder 1."
# Test folder 2
folder2_name = "test_folder_2"
folder2_content = "This is different content in test folder 2."
# Ingest document into folder 1 using our helper function
doc1_id = await test_ingest_text_document_folder_user(
client,
content=folder1_content,
metadata={"folder_test": True},
folder_name=folder1_name,
end_user_id=None
)
# Get the document to verify
response = await client.get(f"/documents/{doc1_id}", headers=headers)
assert response.status_code == 200
doc1 = response.json()
assert doc1["system_metadata"]["folder_name"] == folder1_name
# Ingest document into folder 2 using our helper function
doc2_id = await test_ingest_text_document_folder_user(
client,
content=folder2_content,
metadata={"folder_test": True},
folder_name=folder2_name,
end_user_id=None
)
# Get the document to verify
response = await client.get(f"/documents/{doc2_id}", headers=headers)
assert response.status_code == 200
doc2 = response.json()
assert doc2["system_metadata"]["folder_name"] == folder2_name
# Verify we can get documents by folder
response = await client.post(
"/documents",
json={"folder_test": True},
headers=headers,
params={"folder_name": folder1_name}
)
assert response.status_code == 200
folder1_docs = response.json()
assert len(folder1_docs) == 1
assert folder1_docs[0]["external_id"] == doc1_id
# Verify other folder's document isn't in results
assert not any(doc["external_id"] == doc2_id for doc in folder1_docs)
# Test querying with folder scope
response = await client.post(
"/query",
json={
"query": "What folder is this content in?",
"folder_name": folder1_name
},
headers=headers,
)
assert response.status_code == 200
result = response.json()
assert "completion" in result
# Test folder-specific chunk retrieval
response = await client.post(
"/retrieve/chunks",
json={
"query": "folder content",
"folder_name": folder2_name
},
headers=headers,
)
assert response.status_code == 200
chunks = response.json()
assert len(chunks) > 0
assert folder2_content in chunks[0]["content"]
# Test document update with folder preservation
updated_content = "This is updated content in test folder 1."
response = await client.post(
f"/documents/{doc1_id}/update_text",
json={
"content": updated_content,
"metadata": {"updated": True},
"folder_name": folder1_name # This should match original folder
},
headers=headers,
)
assert response.status_code == 200
updated_doc = response.json()
assert updated_doc["system_metadata"]["folder_name"] == folder1_name
@pytest.mark.asyncio
async def test_user_scoping(client: AsyncClient):
"""Test document operations with end-user scoping."""
headers = create_auth_header()
# Test user 1
user1_id = "test_user_1@example.com"
user1_content = "This is content created by test user 1."
# Test user 2
user2_id = "test_user_2@example.com"
user2_content = "This is different content created by test user 2."
# Ingest document for user 1 using our helper function
doc1_id = await test_ingest_text_document_folder_user(
client,
content=user1_content,
metadata={"user_test": True},
folder_name=None,
end_user_id=user1_id
)
# Get the document to verify
response = await client.get(f"/documents/{doc1_id}", headers=headers)
assert response.status_code == 200
doc1 = response.json()
assert doc1["system_metadata"]["end_user_id"] == user1_id
# Ingest document for user 2 using our helper function
doc2_id = await test_ingest_text_document_folder_user(
client,
content=user2_content,
metadata={"user_test": True},
folder_name=None,
end_user_id=user2_id
)
# Get the document to verify
response = await client.get(f"/documents/{doc2_id}", headers=headers)
assert response.status_code == 200
doc2 = response.json()
assert doc2["system_metadata"]["end_user_id"] == user2_id
# Verify we can get documents by user
response = await client.post(
"/documents",
json={"user_test": True},
headers=headers,
params={"end_user_id": user1_id}
)
assert response.status_code == 200
user1_docs = response.json()
assert len(user1_docs) == 1
assert user1_docs[0]["external_id"] == doc1_id
# Verify other user's document isn't in results
assert not any(doc["external_id"] == doc2_id for doc in user1_docs)
# Test querying with user scope
response = await client.post(
"/query",
json={
"query": "What is my content?",
"end_user_id": user1_id
},
headers=headers,
)
assert response.status_code == 200
result = response.json()
assert "completion" in result
# Test updating document with user preservation
updated_content = "This is updated content by test user 1."
response = await client.post(
f"/documents/{doc1_id}/update_text",
json={
"content": updated_content,
"metadata": {"updated": True},
"end_user_id": user1_id # Should preserve the user
},
headers=headers,
)
assert response.status_code == 200
updated_doc = response.json()
assert updated_doc["system_metadata"]["end_user_id"] == user1_id
@pytest.mark.asyncio
async def test_combined_folder_and_user_scoping(client: AsyncClient):
"""Test document operations with combined folder and user scoping."""
headers = create_auth_header()
# Test folder
folder_name = "test_combined_folder"
# Test users
user1_id = "combined_test_user_1@example.com"
user2_id = "combined_test_user_2@example.com"
# Ingest document for user 1 in folder using our new helper function
user1_content = "This is content by user 1 in the combined test folder."
doc1_id = await test_ingest_text_document_folder_user(
client,
content=user1_content,
metadata={"combined_test": True},
folder_name=folder_name,
end_user_id=user1_id
)
# Get the document to verify
response = await client.get(f"/documents/{doc1_id}", headers=headers)
assert response.status_code == 200
doc1 = response.json()
assert doc1["system_metadata"]["folder_name"] == folder_name
assert doc1["system_metadata"]["end_user_id"] == user1_id
# Ingest document for user 2 in folder using our new helper function
user2_content = "This is content by user 2 in the combined test folder."
doc2_id = await test_ingest_text_document_folder_user(
client,
content=user2_content,
metadata={"combined_test": True},
folder_name=folder_name,
end_user_id=user2_id
)
# Get the document to verify
response = await client.get(f"/documents/{doc2_id}", headers=headers)
assert response.status_code == 200
doc2 = response.json()
assert doc2["system_metadata"]["folder_name"] == folder_name
assert doc2["system_metadata"]["end_user_id"] == user2_id
# Get all documents in folder
response = await client.post(
"/documents",
json={"combined_test": True},
headers=headers,
params={"folder_name": folder_name}
)
assert response.status_code == 200
folder_docs = response.json()
assert len(folder_docs) == 2
# Get user 1's documents in the folder
response = await client.post(
"/documents",
json={"combined_test": True},
headers=headers,
params={"folder_name": folder_name, "end_user_id": user1_id}
)
assert response.status_code == 200
user1_folder_docs = response.json()
assert len(user1_folder_docs) == 1
assert user1_folder_docs[0]["external_id"] == doc1_id
# Test querying with combined scope
response = await client.post(
"/query",
json={
"query": "What is in this folder for this user?",
"folder_name": folder_name,
"end_user_id": user2_id
},
headers=headers,
)
assert response.status_code == 200
result = response.json()
assert "completion" in result
# Test retrieving chunks with combined scope
response = await client.post(
"/retrieve/chunks",
json={
"query": "combined test folder",
"folder_name": folder_name,
"end_user_id": user1_id
},
headers=headers,
)
assert response.status_code == 200
chunks = response.json()
assert len(chunks) > 0
# Should only have user 1's content
assert any(user1_content in chunk["content"] for chunk in chunks)
assert not any(user2_content in chunk["content"] for chunk in chunks)
@pytest.mark.asyncio
async def test_system_metadata_filter_behavior(client: AsyncClient):
"""Test detailed behavior of system_metadata filtering."""
headers = create_auth_header()
# Create documents with different system metadata combinations
# Document with folder only
folder_only_content = "This document has only folder in system metadata."
folder_only_id = await test_ingest_text_document_folder_user(
client,
content=folder_only_content,
metadata={"filter_test": True},
folder_name="test_filter_folder",
end_user_id=None # Only folder, no user
)
# Get the document to verify
response = await client.get(f"/documents/{folder_only_id}", headers=headers)
assert response.status_code == 200
folder_only_doc = response.json()
# Document with user only
user_only_content = "This document has only user in system metadata."
user_only_id = await test_ingest_text_document_folder_user(
client,
content=user_only_content,
metadata={"filter_test": True},
folder_name=None, # No folder, only user
end_user_id="test_filter_user@example.com"
)
# Get the document to verify
response = await client.get(f"/documents/{user_only_id}", headers=headers)
assert response.status_code == 200
user_only_doc = response.json()
# Document with both folder and user
combined_content = "This document has both folder and user in system metadata."
combined_id = await test_ingest_text_document_folder_user(
client,
content=combined_content,
metadata={"filter_test": True},
folder_name="test_filter_folder",
end_user_id="test_filter_user@example.com"
)
# Get the document to verify
response = await client.get(f"/documents/{combined_id}", headers=headers)
assert response.status_code == 200
combined_doc = response.json()
# Test queries with different filter combinations
# Filter by folder only
response = await client.post(
"/documents",
json={"filter_test": True},
headers=headers,
params={"folder_name": "test_filter_folder"}
)
assert response.status_code == 200
folder_filtered_docs = response.json()
folder_doc_ids = [doc["external_id"] for doc in folder_filtered_docs]
assert folder_only_id in folder_doc_ids
assert combined_id in folder_doc_ids
assert user_only_id not in folder_doc_ids
# Filter by user only
response = await client.post(
"/documents",
json={"filter_test": True},
headers=headers,
params={"end_user_id": "test_filter_user@example.com"}
)
assert response.status_code == 200
user_filtered_docs = response.json()
user_doc_ids = [doc["external_id"] for doc in user_filtered_docs]
assert user_only_id in user_doc_ids
assert combined_id in user_doc_ids
assert folder_only_id not in user_doc_ids
# Filter by both folder and user
response = await client.post(
"/documents",
json={"filter_test": True},
headers=headers,
params={
"folder_name": "test_filter_folder",
"end_user_id": "test_filter_user@example.com"
}
)
assert response.status_code == 200
combined_filtered_docs = response.json()
combined_doc_ids = [doc["external_id"] for doc in combined_filtered_docs]
assert len(combined_filtered_docs) == 1
assert combined_id in combined_doc_ids
assert folder_only_id not in combined_doc_ids
assert user_only_id not in combined_doc_ids
# Test with chunk retrieval
response = await client.post(
"/retrieve/chunks",
json={
"query": "system metadata",
"folder_name": "test_filter_folder",
"end_user_id": "test_filter_user@example.com"
},
headers=headers,
)
assert response.status_code == 200
chunks = response.json()
assert len(chunks) > 0
# Should only have the combined document content
assert any(combined_content in chunk["content"] for chunk in chunks)
assert not any(folder_only_content in chunk["content"] for chunk in chunks)
assert not any(user_only_content in chunk["content"] for chunk in chunks)
@pytest.mark.asyncio
async def test_delete_document(client: AsyncClient):
"""Test deleting a document and verifying it's gone."""

View File

@ -1,183 +0,0 @@
from typing import List, Optional, Tuple
import logging
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo.errors import PyMongoError
from .base_vector_store import BaseVectorStore
from core.models.chunk import DocumentChunk
logger = logging.getLogger(__name__)
class MongoDBAtlasVectorStore(BaseVectorStore):
"""MongoDB Atlas Vector Search implementation."""
def __init__(
self,
uri: str,
database_name: str,
collection_name: str = "document_chunks",
index_name: str = "vector_index",
):
"""Initialize MongoDB connection for vector storage."""
self.client = AsyncIOMotorClient(uri)
self.db = self.client[database_name]
self.collection = self.db[collection_name]
self.index_name = index_name
async def initialize(self):
"""Initialize vector search index if needed."""
try:
# Create basic indexes
await self.collection.create_index("document_id")
await self.collection.create_index("chunk_number")
# Note: Vector search index must be created via Atlas UI or API
# as it requires specific configuration
logger.info("MongoDB vector store indexes initialized")
return True
except PyMongoError as e:
logger.error(f"Error initializing vector store indexes: {str(e)}")
return False
async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
"""Store document chunks with their embeddings."""
try:
if not chunks:
return True, []
# Convert chunks to dicts
documents = []
for chunk in chunks:
doc = chunk.model_dump()
# Ensure we have required fields
if not doc.get("embedding"):
logger.error(
f"Missing embedding for chunk " f"{chunk.document_id}-{chunk.chunk_number}"
)
continue
documents.append(doc)
if documents:
# Use ordered=False to continue even if some inserts fail
result = await self.collection.insert_many(documents, ordered=False)
return len(result.inserted_ids) > 0, [str(id) for id in result.inserted_ids]
else:
logger.error(f"No documents to store - here is the input: {chunks}")
return False, []
except PyMongoError as e:
logger.error(f"Error storing embeddings: {str(e)}")
return False, []
async def query_similar(
self,
query_embedding: List[float],
k: int,
doc_ids: Optional[List[str]] = None,
) -> List[DocumentChunk]:
"""Find similar chunks using MongoDB Atlas Vector Search."""
try:
logger.debug(
f"Searching in database {self.db.name} " f"collection {self.collection.name}"
)
logger.debug(f"Query vector looks like: {query_embedding}")
logger.debug(f"Doc IDs: {doc_ids}")
logger.debug(f"K is: {k}")
logger.debug(f"Index is: {self.index_name}")
# Vector search pipeline
pipeline = [
{
"$vectorSearch": {
"index": self.index_name,
"path": "embedding",
"queryVector": query_embedding,
"numCandidates": k * 40, # Get more candidates
"limit": k,
"filter": {"document_id": {"$in": doc_ids}} if doc_ids else {},
}
},
{
"$project": {
"score": {"$meta": "vectorSearchScore"},
"document_id": 1,
"chunk_number": 1,
"content": 1,
"metadata": 1,
"_id": 0,
}
},
]
# Execute search
cursor = self.collection.aggregate(pipeline)
chunks = []
async for result in cursor:
chunk = DocumentChunk(
document_id=result["document_id"],
chunk_number=result["chunk_number"],
content=result["content"],
embedding=[], # Don't send embeddings back
metadata=result.get("metadata", {}),
score=result.get("score", 0.0),
)
chunks.append(chunk)
return chunks
except PyMongoError as e:
logger.error(f"MongoDB error: {e._message}")
logger.error(f"Error querying similar chunks: {str(e)}")
raise e
async def get_chunks_by_id(
self,
chunk_identifiers: List[Tuple[str, int]],
) -> List[DocumentChunk]:
"""
Retrieve specific chunks by document ID and chunk number in a single database query.
Args:
chunk_identifiers: List of (document_id, chunk_number) tuples
Returns:
List of DocumentChunk objects
"""
try:
if not chunk_identifiers:
return []
# Create a query with $or to find multiple chunks in a single query
query = {"$or": []}
for doc_id, chunk_num in chunk_identifiers:
query["$or"].append({
"document_id": doc_id,
"chunk_number": chunk_num
})
logger.info(f"Batch retrieving {len(chunk_identifiers)} chunks with a single query")
# Find all matching chunks in a single database query
cursor = self.collection.find(query)
chunks = []
async for result in cursor:
chunk = DocumentChunk(
document_id=result["document_id"],
chunk_number=result["chunk_number"],
content=result["content"],
embedding=[], # Don't send embeddings back
metadata=result.get("metadata", {}),
score=0.0, # No relevance score for direct retrieval
)
chunks.append(chunk)
logger.info(f"Found {len(chunks)} chunks in batch retrieval")
return chunks
except PyMongoError as e:
logger.error(f"Error retrieving chunks by ID: {str(e)}")
return []

View File

@ -0,0 +1,82 @@
import os
from dotenv import load_dotenv
from morphik import Morphik
# Load environment variables
load_dotenv()
# Connect to Morphik
db = Morphik(os.getenv("MORPHIK_URI"), timeout=10000, is_local=True)
print("========== Customer Support Example ==========")
# Create a folder for application data
app_folder = db.create_folder("customer-support")
print(f"Created folder: {app_folder.name}")
# Ingest documents into the folder
folder_doc = app_folder.ingest_text(
"Customer reported an issue with login functionality. Steps to reproduce: "
"1. Go to login page, 2. Enter credentials, 3. Click login button.",
filename="ticket-001.txt",
metadata={"category": "bug", "priority": "high", "status": "open"}
)
print(f"Ingested document into folder: {folder_doc.external_id}")
# Perform a query in the folder context
folder_response = app_folder.query(
"What issues have been reported?",
k=2
)
print("\nFolder Query Results:")
print(folder_response.completion)
# Get statistics for the folder
folder_docs = app_folder.list_documents()
print(f"\nFolder Statistics: {len(folder_docs)} documents in '{app_folder.name}'")
print("\n========== User Scoping Example ==========")
# Create a user scope
user_email = "support@example.com"
user = db.signin(user_email)
print(f"Created user scope for: {user.end_user_id}")
# Ingest a document as this user
user_doc = user.ingest_text(
"User requested information about premium features. They are interested in the collaboration tools.",
filename="inquiry-001.txt",
metadata={"category": "inquiry", "priority": "medium", "status": "open"}
)
print(f"Ingested document as user: {user_doc.external_id}")
# Query as this user
user_response = user.query(
"What customer inquiries do we have?",
k=2
)
print("\nUser Query Results:")
print(user_response.completion)
# Get documents for this user
user_docs = user.list_documents()
print(f"\nUser Statistics: {len(user_docs)} documents for user '{user.end_user_id}'")
print("\n========== Combined Folder and User Scoping ==========")
# Create a user scoped to a specific folder
folder_user = app_folder.signin(user_email)
print(f"Created user scope for {folder_user.end_user_id} in folder {folder_user.folder_name}")
# Ingest a document as this user in the folder context
folder_user_doc = folder_user.ingest_text(
"Customer called to follow up on ticket-001. They are still experiencing the login issue on Chrome.",
filename="ticket-002.txt",
metadata={"category": "follow-up", "priority": "high", "status": "open"}
)
print(f"Ingested document as user in folder: {folder_user_doc.external_id}")
# Query as this user in the folder context
folder_user_response = folder_user.query(
"What high priority issues require attention?",
k=2
)
print("\nFolder User Query Results:")
print(folder_user_response.completion)

View File

@ -9,15 +9,12 @@ import boto3
import botocore
import tomli # for reading toml files
from dotenv import find_dotenv, load_dotenv
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure, OperationFailure
from pymongo.operations import SearchIndexModel
# Force reload of environment variables
load_dotenv(find_dotenv(), override=True)
# Set up argument parser
parser = argparse.ArgumentParser(description="Setup S3 bucket and MongoDB collections")
parser = argparse.ArgumentParser(description="Setup S3 bucket")
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
parser.add_argument("--quiet", action="store_true", help="Only show warning and error logs")
args = parser.parse_args()
@ -48,16 +45,6 @@ with open(config_path, "rb") as f:
STORAGE_PROVIDER = CONFIG["storage"]["provider"]
DATABASE_PROVIDER = CONFIG["database"]["provider"]
# MongoDB specific config
if "mongodb" in CONFIG["database"]:
DATABASE_NAME = CONFIG["database"]["mongodb"]["database_name"]
DOCUMENTS_COLLECTION = "documents"
CHUNKS_COLLECTION = "document_chunks"
if "mongodb" in CONFIG["vector_store"]:
VECTOR_DIMENSIONS = CONFIG["embedding"]["dimensions"]
VECTOR_INDEX_NAME = "vector_index"
SIMILARITY_METRIC = CONFIG["embedding"]["similarity_metric"]
# Extract storage-specific configuration
if STORAGE_PROVIDER == "aws-s3":
DEFAULT_REGION = CONFIG["storage"]["region"]
@ -117,69 +104,6 @@ def bucket_exists(s3_client, bucket_name):
# raise e
def setup_mongodb():
"""
Set up MongoDB database, documents collection, and vector index on documents_chunk collection.
"""
# Load MongoDB URI from .env file
mongo_uri = os.getenv("MONGODB_URI")
if not mongo_uri:
raise ValueError("MONGODB_URI not found in .env file.")
try:
# Connect to MongoDB
client = MongoClient(mongo_uri)
client.admin.command("ping") # Check connection
LOGGER.info("Connected to MongoDB successfully.")
# Create or access the database
db = client[DATABASE_NAME]
LOGGER.info(f"Database '{DATABASE_NAME}' ready.")
# Create 'documents' collection
if DOCUMENTS_COLLECTION not in db.list_collection_names():
db.create_collection(DOCUMENTS_COLLECTION)
LOGGER.info(f"Collection '{DOCUMENTS_COLLECTION}' created.")
else:
LOGGER.info(f"Collection '{DOCUMENTS_COLLECTION}' already exists.")
# Create 'documents_chunk' collection with vector index
if CHUNKS_COLLECTION not in db.list_collection_names():
db.create_collection(CHUNKS_COLLECTION)
LOGGER.info(f"Collection '{CHUNKS_COLLECTION}' created.")
else:
LOGGER.info(f"Collection '{CHUNKS_COLLECTION}' already exists.")
vector_index_definition = {
"fields": [
{
"numDimensions": VECTOR_DIMENSIONS,
"path": "embedding",
"similarity": SIMILARITY_METRIC,
"type": "vector",
},
{"path": "document_id", "type": "filter"},
]
}
vector_index = SearchIndexModel(
name=VECTOR_INDEX_NAME,
definition=vector_index_definition,
type="vectorSearch",
)
db[CHUNKS_COLLECTION].create_search_index(model=vector_index)
LOGGER.info("Vector index 'vector_index' created on 'documents_chunk' collection.")
except ConnectionFailure:
LOGGER.error("Failed to connect to MongoDB. Check your MongoDB URI and network connection.")
except OperationFailure as e:
LOGGER.error(f"MongoDB operation failed: {e}")
except Exception as e:
LOGGER.error(f"Unexpected error: {e}")
finally:
client.close()
LOGGER.info("MongoDB connection closed.")
def setup():
# Setup S3 if configured
if STORAGE_PROVIDER == "aws-s3":
@ -188,16 +112,11 @@ def setup():
LOGGER.info("S3 bucket setup completed.")
# Setup database based on provider
match DATABASE_PROVIDER:
case "mongodb":
LOGGER.info("Setting up MongoDB...")
setup_mongodb()
LOGGER.info("MongoDB setup completed.")
case "postgres":
LOGGER.info("Postgres is setup on database intialization - nothing to do here!")
case _:
LOGGER.error(f"Unsupported database provider: {DATABASE_PROVIDER}")
raise ValueError(f"Unsupported database provider: {DATABASE_PROVIDER}")
if DATABASE_PROVIDER != "postgres":
LOGGER.error(f"Unsupported database provider: {DATABASE_PROVIDER}")
raise ValueError(f"Unsupported database provider: {DATABASE_PROVIDER}")
LOGGER.info("Postgres is setup on database initialization - nothing to do here!")
LOGGER.info("Setup completed successfully. Feel free to start the server now!")

View File

@ -151,7 +151,6 @@ matplotlib-inline==0.1.7
mdurl==0.1.2
monotonic==1.6
more-itertools==10.5.0
motor==3.4.0
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
@ -237,7 +236,6 @@ pydeck==0.9.1
pyee==12.1.1
Pygments==2.18.0
PyJWT==2.9.0
pymongo==4.7.1
pypandoc==1.13
pyparsing==3.1.2
pypdf==4.3.1

View File

@ -1,72 +0,0 @@
from pymongo import MongoClient
from dotenv import load_dotenv
import os
import datetime
def test_mongo_operations():
# Load environment variables
load_dotenv()
# Get MongoDB URI from environment variable
mongo_uri = os.getenv("MONGODB_URI")
if not mongo_uri:
raise ValueError("MONGODB_URI environment variable not set")
try:
# Connect to MongoDB
client = MongoClient(mongo_uri)
# Test connection
client.admin.command("ping")
print("✅ Connected successfully to MongoDB")
# Get database and collection
db = client.brandsyncaidb # Using a test database
collection = db.kb_chunked_embeddings
# Insert a single document
test_doc = {
"name": "Test Document",
"timestamp": datetime.datetime.now(),
"value": 42,
}
result = collection.insert_one(test_doc)
print(f"✅ Inserted document with ID: {result.inserted_id}")
# Insert multiple documents
test_docs = [
{"name": "Doc 1", "value": 1},
{"name": "Doc 2", "value": 2},
{"name": "Doc 3", "value": 3},
]
result = collection.insert_many(test_docs)
print(f"✅ Inserted {len(result.inserted_ids)} documents")
# Retrieve documents
print("\nRetrieving documents:")
for doc in collection.find():
print(f"Found document: {doc}")
# Find specific documents
print("\nFinding documents with value >= 2:")
query = {"value": {"$gte": 2}}
for doc in collection.find(query):
print(f"Found document: {doc}")
# Clean up - delete all test documents
# DON'T DELETE IF It'S BRANDSYNCAI
# result = collection.delete_many({})
print(f"\n✅ Cleaned up {result.deleted_count} test documents")
except Exception as e:
print(f"❌ Error: {str(e)}")
finally:
client.close()
print("\n✅ Connection closed")
if __name__ == "__main__":
test_mongo_operations()

View File

@ -12,4 +12,4 @@ __all__ = [
"Document",
]
__version__ = "0.1.0"
__version__ = "0.1.2"

View File

@ -0,0 +1,507 @@
import base64
import io
import json
from io import BytesIO, IOBase
from PIL import Image
from PIL.Image import Image as PILImage
from pathlib import Path
from typing import Dict, Any, List, Optional, Union, Tuple, BinaryIO
from urllib.parse import urlparse
import jwt
from pydantic import BaseModel, Field
from .models import (
Document,
ChunkResult,
DocumentResult,
CompletionResponse,
IngestTextRequest,
ChunkSource,
Graph,
# Prompt override models
GraphPromptOverrides,
)
from .rules import Rule
# Type alias for rules
RuleOrDict = Union[Rule, Dict[str, Any]]
class FinalChunkResult(BaseModel):
content: str | PILImage = Field(..., description="Chunk content")
score: float = Field(..., description="Relevance score")
document_id: str = Field(..., description="Parent document ID")
chunk_number: int = Field(..., description="Chunk sequence number")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata")
content_type: str = Field(..., description="Content type")
filename: Optional[str] = Field(None, description="Original filename")
download_url: Optional[str] = Field(None, description="URL to download full document")
class Config:
arbitrary_types_allowed = True
class _MorphikClientLogic:
"""
Internal shared logic for Morphik clients.
This class contains the shared logic between synchronous and asynchronous clients.
It handles URL generation, request preparation, and response parsing.
"""
def __init__(self, uri: Optional[str] = None, timeout: int = 30, is_local: bool = False):
"""Initialize shared client logic"""
self._timeout = timeout
self._is_local = is_local
if uri:
self._setup_auth(uri)
else:
self._base_url = "http://localhost:8000"
self._auth_token = None
def _setup_auth(self, uri: str) -> None:
"""Setup authentication from URI"""
parsed = urlparse(uri)
if not parsed.netloc:
raise ValueError("Invalid URI format")
# Split host and auth parts
auth, host = parsed.netloc.split("@")
_, self._auth_token = auth.split(":")
# Set base URL
self._base_url = f"{'http' if self._is_local else 'https'}://{host}"
# Basic token validation
jwt.decode(self._auth_token, options={"verify_signature": False})
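For reference, the parsing above expects a URI of the rough shape scheme://&lt;name&gt;:&lt;token&gt;@&lt;host&gt;, with a JWT as the token. The concrete values below are placeholders, not real credentials or a confirmed scheme:

from urllib.parse import urlparse

example_uri = "morphik://my_app:eyJhbGciOiJIUzI1NiJ9.e30.sig@api.example.com"  # placeholder values
parsed = urlparse(example_uri)
auth_part, host = parsed.netloc.split("@")   # -> "my_app:eyJ...", "api.example.com"
_, token = auth_part.split(":")              # token becomes _auth_token
base_url = f"https://{host}"                 # or http:// when is_local=True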
def _convert_rule(self, rule: RuleOrDict) -> Dict[str, Any]:
"""Convert a rule to a dictionary format"""
if hasattr(rule, "to_dict"):
return rule.to_dict()
return rule
def _get_url(self, endpoint: str) -> str:
"""Get the full URL for an API endpoint"""
return f"{self._base_url}/{endpoint.lstrip('/')}"
def _get_headers(self) -> Dict[str, str]:
"""Get base headers for API requests"""
headers = {"Content-Type": "application/json"}
return headers
# Request preparation methods
def _prepare_ingest_text_request(
self,
content: str,
filename: Optional[str],
metadata: Optional[Dict[str, Any]],
rules: Optional[List[RuleOrDict]],
use_colpali: bool,
folder_name: Optional[str],
end_user_id: Optional[str],
) -> Dict[str, Any]:
"""Prepare request for ingest_text endpoint"""
rules_dict = [self._convert_rule(r) for r in (rules or [])]
payload = {
"content": content,
"filename": filename,
"metadata": metadata or {},
"rules": rules_dict,
"use_colpali": use_colpali,
}
if folder_name:
payload["folder_name"] = folder_name
if end_user_id:
payload["end_user_id"] = end_user_id
return payload
def _prepare_file_for_upload(
self,
file: Union[str, bytes, BinaryIO, Path],
filename: Optional[str] = None,
) -> Tuple[BinaryIO, str]:
"""
Process file input and return file object and filename.
Handles different file input types (str, Path, bytes, file-like object).
"""
if isinstance(file, (str, Path)):
file_path = Path(file)
if not file_path.exists():
raise ValueError(f"File not found: {file}")
filename = file_path.name if filename is None else filename
with open(file_path, "rb") as f:
content = f.read()
file_obj = BytesIO(content)
elif isinstance(file, bytes):
if filename is None:
raise ValueError("filename is required when ingesting bytes")
file_obj = BytesIO(file)
else:
if filename is None:
raise ValueError("filename is required when ingesting file object")
file_obj = file
return file_obj, filename
def _prepare_files_for_upload(
self,
files: List[Union[str, bytes, BinaryIO, Path]],
) -> List[Tuple[str, Tuple[str, BinaryIO]]]:
"""
Process multiple files and return a list of file objects in the format
expected by the API: [("files", (filename, file_obj)), ...]
"""
file_objects = []
for file in files:
if isinstance(file, (str, Path)):
path = Path(file)
file_objects.append(("files", (path.name, open(path, "rb"))))
elif isinstance(file, bytes):
file_objects.append(("files", ("file.bin", BytesIO(file))))
else:
file_objects.append(("files", (getattr(file, "name", "file.bin"), file)))
return file_objects
def _prepare_ingest_file_form_data(
self,
metadata: Optional[Dict[str, Any]],
rules: Optional[List[RuleOrDict]],
folder_name: Optional[str],
end_user_id: Optional[str],
) -> Dict[str, Any]:
"""Prepare form data for ingest_file endpoint"""
form_data = {
"metadata": json.dumps(metadata or {}),
"rules": json.dumps([self._convert_rule(r) for r in (rules or [])]),
}
if folder_name:
form_data["folder_name"] = folder_name
if end_user_id:
form_data["end_user_id"] = end_user_id
return form_data
def _prepare_ingest_files_form_data(
self,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]],
rules: Optional[List[RuleOrDict]],
use_colpali: bool,
parallel: bool,
folder_name: Optional[str],
end_user_id: Optional[str],
) -> Dict[str, Any]:
"""Prepare form data for ingest_files endpoint"""
# Convert rules appropriately based on whether it's a flat list or list of lists
if rules:
if all(isinstance(r, list) for r in rules):
# List of lists - per-file rules
converted_rules = [
[self._convert_rule(r) for r in rule_list] for rule_list in rules
]
else:
# Flat list - shared rules for all files
converted_rules = [self._convert_rule(r) for r in rules]
else:
converted_rules = []
data = {
"metadata": json.dumps(metadata or {}),
"rules": json.dumps(converted_rules),
"use_colpali": str(use_colpali).lower() if use_colpali is not None else None,
"parallel": str(parallel).lower(),
}
if folder_name:
data["folder_name"] = folder_name
if end_user_id:
data["end_user_id"] = end_user_id
return data
def _prepare_query_request(
self,
query: str,
filters: Optional[Dict[str, Any]],
k: int,
min_score: float,
max_tokens: Optional[int],
temperature: Optional[float],
use_colpali: bool,
graph_name: Optional[str],
hop_depth: int,
include_paths: bool,
prompt_overrides: Optional[Dict],
folder_name: Optional[str],
end_user_id: Optional[str],
) -> Dict[str, Any]:
"""Prepare request for query endpoint"""
payload = {
"query": query,
"filters": filters,
"k": k,
"min_score": min_score,
"max_tokens": max_tokens,
"temperature": temperature,
"use_colpali": use_colpali,
"graph_name": graph_name,
"hop_depth": hop_depth,
"include_paths": include_paths,
"prompt_overrides": prompt_overrides,
}
if folder_name:
payload["folder_name"] = folder_name
if end_user_id:
payload["end_user_id"] = end_user_id
# Filter out None values before sending
return {k_p: v_p for k_p, v_p in payload.items() if v_p is not None}
def _prepare_retrieve_chunks_request(
self,
query: str,
filters: Optional[Dict[str, Any]],
k: int,
min_score: float,
use_colpali: bool,
folder_name: Optional[str],
end_user_id: Optional[str],
) -> Dict[str, Any]:
"""Prepare request for retrieve_chunks endpoint"""
request = {
"query": query,
"filters": filters,
"k": k,
"min_score": min_score,
"use_colpali": use_colpali,
}
if folder_name:
request["folder_name"] = folder_name
if end_user_id:
request["end_user_id"] = end_user_id
return request
def _prepare_retrieve_docs_request(
self,
query: str,
filters: Optional[Dict[str, Any]],
k: int,
min_score: float,
use_colpali: bool,
folder_name: Optional[str],
end_user_id: Optional[str],
) -> Dict[str, Any]:
"""Prepare request for retrieve_docs endpoint"""
request = {
"query": query,
"filters": filters,
"k": k,
"min_score": min_score,
"use_colpali": use_colpali,
}
if folder_name:
request["folder_name"] = folder_name
if end_user_id:
request["end_user_id"] = end_user_id
return request
def _prepare_list_documents_request(
self,
skip: int,
limit: int,
filters: Optional[Dict[str, Any]],
folder_name: Optional[str],
end_user_id: Optional[str],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Prepare request for list_documents endpoint"""
params = {
"skip": skip,
"limit": limit,
}
if folder_name:
params["folder_name"] = folder_name
if end_user_id:
params["end_user_id"] = end_user_id
data = filters or {}
return params, data
def _prepare_batch_get_documents_request(
self, document_ids: List[str], folder_name: Optional[str], end_user_id: Optional[str]
) -> Union[Dict[str, Any], List[str]]:
"""Prepare request for batch_get_documents endpoint"""
if folder_name or end_user_id:
request = {"document_ids": document_ids}
if folder_name:
request["folder_name"] = folder_name
if end_user_id:
request["end_user_id"] = end_user_id
return request
return document_ids # Return just IDs list if no scoping is needed
def _prepare_batch_get_chunks_request(
self,
sources: List[Union[ChunkSource, Dict[str, Any]]],
folder_name: Optional[str],
end_user_id: Optional[str],
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""Prepare request for batch_get_chunks endpoint"""
source_dicts = []
for source in sources:
if isinstance(source, dict):
source_dicts.append(source)
else:
source_dicts.append(source.model_dump())
if folder_name or end_user_id:
request = {"sources": source_dicts}
if folder_name:
request["folder_name"] = folder_name
if end_user_id:
request["end_user_id"] = end_user_id
return request
return source_dicts # Return just sources list if no scoping is needed
def _prepare_create_graph_request(
self,
name: str,
filters: Optional[Dict[str, Any]],
documents: Optional[List[str]],
prompt_overrides: Optional[Union[GraphPromptOverrides, Dict[str, Any]]],
folder_name: Optional[str],
end_user_id: Optional[str],
) -> Dict[str, Any]:
"""Prepare request for create_graph endpoint"""
# Convert prompt_overrides to dict if it's a model
if prompt_overrides and isinstance(prompt_overrides, GraphPromptOverrides):
prompt_overrides = prompt_overrides.model_dump(exclude_none=True)
request = {
"name": name,
"filters": filters,
"documents": documents,
"prompt_overrides": prompt_overrides,
}
if folder_name:
request["folder_name"] = folder_name
if end_user_id:
request["end_user_id"] = end_user_id
return request
def _prepare_update_graph_request(
self,
name: str,
additional_filters: Optional[Dict[str, Any]],
additional_documents: Optional[List[str]],
prompt_overrides: Optional[Union[GraphPromptOverrides, Dict[str, Any]]],
folder_name: Optional[str],
end_user_id: Optional[str],
) -> Dict[str, Any]:
"""Prepare request for update_graph endpoint"""
# Convert prompt_overrides to dict if it's a model
if prompt_overrides and isinstance(prompt_overrides, GraphPromptOverrides):
prompt_overrides = prompt_overrides.model_dump(exclude_none=True)
request = {
"additional_filters": additional_filters,
"additional_documents": additional_documents,
"prompt_overrides": prompt_overrides,
}
if folder_name:
request["folder_name"] = folder_name
if end_user_id:
request["end_user_id"] = end_user_id
return request
def _prepare_update_document_with_text_request(
self,
document_id: str,
content: str,
filename: Optional[str],
metadata: Optional[Dict[str, Any]],
rules: Optional[List],
update_strategy: str,
use_colpali: Optional[bool],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Prepare request for update_document_with_text endpoint"""
request = IngestTextRequest(
content=content,
filename=filename,
metadata=metadata or {},
rules=[self._convert_rule(r) for r in (rules or [])],
use_colpali=use_colpali if use_colpali is not None else True,
)
params = {}
if update_strategy != "add":
params["update_strategy"] = update_strategy
return params, request.model_dump()
# Response parsing methods
def _parse_document_response(self, response_json: Dict[str, Any]) -> Document:
"""Parse document response"""
return Document(**response_json)
def _parse_completion_response(self, response_json: Dict[str, Any]) -> CompletionResponse:
"""Parse completion response"""
return CompletionResponse(**response_json)
def _parse_document_list_response(self, response_json: List[Dict[str, Any]]) -> List[Document]:
"""Parse document list response"""
docs = [Document(**doc) for doc in response_json]
return docs
def _parse_document_result_list_response(
self, response_json: List[Dict[str, Any]]
) -> List[DocumentResult]:
"""Parse document result list response"""
return [DocumentResult(**r) for r in response_json]
def _parse_chunk_result_list_response(
self, response_json: List[Dict[str, Any]]
) -> List[FinalChunkResult]:
"""Parse chunk result list response"""
chunks = [ChunkResult(**r) for r in response_json]
final_chunks = []
for chunk in chunks:
content = chunk.content
if chunk.metadata.get("is_image"):
try:
# Handle data URI format "data:image/png;base64,..."
if content.startswith("data:"):
# Extract the base64 part after the comma
content = content.split(",", 1)[1]
# Now decode the base64 string
image_bytes = base64.b64decode(content)
content = Image.open(io.BytesIO(image_bytes))
except Exception:
# Fall back to using the content as text
content = chunk.content
final_chunks.append(
FinalChunkResult(
content=content,
score=chunk.score,
document_id=chunk.document_id,
chunk_number=chunk.chunk_number,
metadata=chunk.metadata,
content_type=chunk.content_type,
filename=chunk.filename,
download_url=chunk.download_url,
)
)
return final_chunks
def _parse_graph_response(self, response_json: Dict[str, Any]) -> Graph:
"""Parse graph response"""
return Graph(**response_json)
def _parse_graph_list_response(self, response_json: List[Dict[str, Any]]) -> List[Graph]:
"""Parse graph list response"""
return [Graph(**graph) for graph in response_json]
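# Minimal usage sketch (assumed wiring; the concrete sync/async clients that
# wrap this logic live elsewhere, and the endpoint path is illustrative):
#
#   logic = _MorphikClientLogic(uri="morphik://owner:<jwt>@api.example.com")
#   payload = logic._prepare_ingest_text_request(
#       "hello world", None, None, None, True, None, None
#   )
#   url = logic._get_url("ingest/text")
#   # the client would then POST `payload` to `url` with logic._get_headers()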

File diff suppressed because it is too large

View File

@ -21,10 +21,10 @@ class Document(BaseModel):
default_factory=dict, description="Access control information"
)
chunk_ids: List[str] = Field(default_factory=list, description="IDs of document chunks")
# Client reference for update methods
_client = None
def update_with_text(
self,
content: str,
@ -36,7 +36,7 @@ class Document(BaseModel):
) -> "Document":
"""
Update this document with new text content using the specified strategy.
Args:
content: The new content to add
filename: Optional new filename for the document
@ -44,13 +44,15 @@ class Document(BaseModel):
rules: Optional list of rules to apply to the content
update_strategy: Strategy for updating the document (currently only 'add' is supported)
use_colpali: Whether to use multi-vector embedding
Returns:
Document: Updated document metadata
"""
if self._client is None:
raise ValueError("Document instance not connected to a client. Use a document returned from a Morphik client method.")
raise ValueError(
"Document instance not connected to a client. Use a document returned from a Morphik client method."
)
return self._client.update_document_with_text(
document_id=self.external_id,
content=content,
@ -58,9 +60,9 @@ class Document(BaseModel):
metadata=metadata,
rules=rules,
update_strategy=update_strategy,
use_colpali=use_colpali
use_colpali=use_colpali,
)
def update_with_file(
self,
file: "Union[str, bytes, BinaryIO, Path]",
@ -72,7 +74,7 @@ class Document(BaseModel):
) -> "Document":
"""
Update this document with content from a file using the specified strategy.
Args:
file: File to add (path string, bytes, file object, or Path)
filename: Name of the file
@ -80,13 +82,15 @@ class Document(BaseModel):
rules: Optional list of rules to apply to the content
update_strategy: Strategy for updating the document (currently only 'add' is supported)
use_colpali: Whether to use multi-vector embedding
Returns:
Document: Updated document metadata
"""
if self._client is None:
raise ValueError("Document instance not connected to a client. Use a document returned from a Morphik client method.")
raise ValueError(
"Document instance not connected to a client. Use a document returned from a Morphik client method."
)
return self._client.update_document_with_file(
document_id=self.external_id,
file=file,
@ -94,28 +98,29 @@ class Document(BaseModel):
metadata=metadata,
rules=rules,
update_strategy=update_strategy,
use_colpali=use_colpali
use_colpali=use_colpali,
)
def update_metadata(
self,
metadata: Dict[str, Any],
) -> "Document":
"""
Update this document's metadata only.
Args:
metadata: Metadata to update
Returns:
Document: Updated document metadata
"""
if self._client is None:
raise ValueError("Document instance not connected to a client. Use a document returned from a Morphik client method.")
raise ValueError(
"Document instance not connected to a client. Use a document returned from a Morphik client method."
)
return self._client.update_document_metadata(
document_id=self.external_id,
metadata=metadata
document_id=self.external_id, metadata=metadata
)
@ -159,7 +164,7 @@ class DocumentResult(BaseModel):
class ChunkSource(BaseModel):
"""Source information for a chunk used in completion"""
document_id: str = Field(..., description="ID of the source document")
chunk_number: int = Field(..., description="Chunk number within the document")
score: Optional[float] = Field(None, description="Relevance score")
@ -194,7 +199,9 @@ class Entity(BaseModel):
type: str = Field(..., description="Entity type")
properties: Dict[str, Any] = Field(default_factory=dict, description="Entity properties")
document_ids: List[str] = Field(default_factory=list, description="Source document IDs")
chunk_sources: Dict[str, List[int]] = Field(default_factory=dict, description="Source chunk numbers by document ID")
chunk_sources: Dict[str, List[int]] = Field(
default_factory=dict, description="Source chunk numbers by document ID"
)
def __hash__(self):
return hash(self.id)
@ -213,7 +220,9 @@ class Relationship(BaseModel):
target_id: str = Field(..., description="Target entity ID")
type: str = Field(..., description="Relationship type")
document_ids: List[str] = Field(default_factory=list, description="Source document IDs")
chunk_sources: Dict[str, List[int]] = Field(default_factory=dict, description="Source chunk numbers by document ID")
chunk_sources: Dict[str, List[int]] = Field(
default_factory=dict, description="Source chunk numbers by document ID"
)
def __hash__(self):
return hash(self.id)
@ -230,10 +239,14 @@ class Graph(BaseModel):
id: str = Field(..., description="Unique graph identifier")
name: str = Field(..., description="Graph name")
entities: List[Entity] = Field(default_factory=list, description="Entities in the graph")
relationships: List[Relationship] = Field(default_factory=list, description="Relationships in the graph")
relationships: List[Relationship] = Field(
default_factory=list, description="Relationships in the graph"
)
metadata: Dict[str, Any] = Field(default_factory=dict, description="Graph metadata")
document_ids: List[str] = Field(default_factory=list, description="Source document IDs")
filters: Optional[Dict[str, Any]] = Field(None, description="Document filters used to create the graph")
filters: Optional[Dict[str, Any]] = Field(
None, description="Document filters used to create the graph"
)
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
owner: Dict[str, str] = Field(default_factory=dict, description="Graph owner information")

File diff suppressed because it is too large Load Diff

View File

@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "morphik"
version = "0.1.0"
version = "0.1.2"
authors = [
{ name = "Morphik", email = "founders@morphik.ai" },
]