mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
Add folders and user scopes (#82)
This commit is contained in:
parent
1f3df392da
commit
75556c924a
@@ -1,6 +1,5 @@
JWT_SECRET_KEY="..." # Required in production, optional in dev mode (dev_mode=true in morphik.toml)
POSTGRES_URI="postgresql+asyncpg://postgres:postgres@localhost:5432/morphik" # Required for PostgreSQL database
MONGODB_URI="..." # Optional: Only needed if using MongoDB

UNSTRUCTURED_API_KEY="..." # Optional: Needed for parsing via unstructured API
OPENAI_API_KEY="..." # Optional: Needed for OpenAI embeddings and completions
@@ -46,7 +46,7 @@ Built for scale and performance, Morphik can handle millions of documents while
- 🧩 **Extensible Architecture**
  - Support for custom parsers and embedding models
  - Multiple storage backends (S3, local)
  - Vector store integrations (PostgreSQL/pgvector, MongoDB)
  - Vector store integration with PostgreSQL/pgvector

## Quick Start

@@ -162,7 +162,7 @@ for chunk in chunks:
| **Knowledge Graphs** | ✅ Automated extraction & enhanced retrieval | ❌ | ❌ | ❌ |
| **Rules Engine** | ✅ Natural language rules & schema definition | ❌ | ❌ | Limited |
| **Caching** | ✅ Persistent KV-caching with selective updates | ❌ | ❌ | Limited |
| **Scalability** | ✅ Millions of documents with PostgreSQL/MongoDB | ✅ | ✅ | Limited |
| **Scalability** | ✅ Millions of documents with PostgreSQL | ✅ | ✅ | Limited |
| **Video Content** | ✅ Native video parsing & transcription | ❌ | ❌ | ❌ |
| **Deployment Options** | ✅ Self-hosted, cloud, or hybrid | Varies | Varies | Limited |
| **Open Source** | ✅ MIT License | Varies | Varies | Varies |
@@ -176,7 +176,7 @@ for chunk in chunks:

- **Schema-like Rules for Unstructured Data**: Define rules to extract consistent metadata from unstructured content, bringing database-like queryability to any document format.

- **Enterprise-grade Scalability**: Built on proven database technologies (PostgreSQL/MongoDB) that can scale to millions of documents while maintaining sub-second retrieval times.
- **Enterprise-grade Scalability**: Built on proven PostgreSQL database technology that can scale to millions of documents while maintaining sub-second retrieval times.

## Documentation
core/api.py (338 lines changed)
@@ -20,9 +20,7 @@ from core.parser.morphik_parser import MorphikParser
from core.services.document_service import DocumentService
from core.services.telemetry import TelemetryService
from core.config import get_settings
from core.database.mongo_database import MongoDatabase
from core.database.postgres_database import PostgresDatabase
from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from core.vector_store.multi_vector_store import MultiVectorStore
from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
from core.storage.s3_storage import S3Storage
@ -77,21 +75,9 @@ app.add_middleware(
|
||||
settings = get_settings()
|
||||
|
||||
# Initialize database
|
||||
match settings.DATABASE_PROVIDER:
|
||||
case "postgres":
|
||||
if not settings.POSTGRES_URI:
|
||||
raise ValueError("PostgreSQL URI is required for PostgreSQL database")
|
||||
database = PostgresDatabase(uri=settings.POSTGRES_URI)
|
||||
case "mongodb":
|
||||
if not settings.MONGODB_URI:
|
||||
raise ValueError("MongoDB URI is required for MongoDB database")
|
||||
database = MongoDatabase(
|
||||
uri=settings.MONGODB_URI,
|
||||
db_name=settings.DATABRIDGE_DB,
|
||||
collection_name=settings.DOCUMENTS_COLLECTION,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported database provider: {settings.DATABASE_PROVIDER}")
|
||||
if not settings.POSTGRES_URI:
|
||||
raise ValueError("PostgreSQL URI is required for PostgreSQL database")
|
||||
database = PostgresDatabase(uri=settings.POSTGRES_URI)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@ -144,24 +130,13 @@ async def initialize_user_limits_database():
|
||||
await user_limits_db.initialize()
|
||||
|
||||
# Initialize vector store
|
||||
match settings.VECTOR_STORE_PROVIDER:
|
||||
case "mongodb":
|
||||
vector_store = MongoDBAtlasVectorStore(
|
||||
uri=settings.MONGODB_URI,
|
||||
database_name=settings.DATABRIDGE_DB,
|
||||
collection_name=settings.CHUNKS_COLLECTION,
|
||||
index_name=settings.VECTOR_INDEX_NAME,
|
||||
)
|
||||
case "pgvector":
|
||||
if not settings.POSTGRES_URI:
|
||||
raise ValueError("PostgreSQL URI is required for pgvector store")
|
||||
from core.vector_store.pgvector_store import PGVectorStore
|
||||
if not settings.POSTGRES_URI:
|
||||
raise ValueError("PostgreSQL URI is required for pgvector store")
|
||||
from core.vector_store.pgvector_store import PGVectorStore
|
||||
|
||||
vector_store = PGVectorStore(
|
||||
uri=settings.POSTGRES_URI,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported vector store provider: {settings.VECTOR_STORE_PROVIDER}")
|
||||
vector_store = PGVectorStore(
|
||||
uri=settings.POSTGRES_URI,
|
||||
)
|
||||
|
||||
# Initialize storage
|
||||
match settings.STORAGE_PROVIDER:
|
||||
@ -310,6 +285,8 @@ async def ingest_text(
|
||||
- rules: Optional list of rules. Each rule should be either:
|
||||
- MetadataExtractionRule: {"type": "metadata_extraction", "schema": {...}}
|
||||
- NaturalLanguageRule: {"type": "natural_language", "prompt": "..."}
|
||||
- folder_name: Optional folder to scope the document to
|
||||
- end_user_id: Optional end-user ID to scope the document to
|
||||
auth: Authentication context
|
||||
|
||||
Returns:
|
||||
@ -324,6 +301,8 @@ async def ingest_text(
|
||||
"metadata": request.metadata,
|
||||
"rules": request.rules,
|
||||
"use_colpali": request.use_colpali,
|
||||
"folder_name": request.folder_name,
|
||||
"end_user_id": request.end_user_id,
|
||||
},
|
||||
):
|
||||
return await document_service.ingest_text(
|
||||
@ -333,6 +312,8 @@ async def ingest_text(
|
||||
rules=request.rules,
|
||||
use_colpali=request.use_colpali,
|
||||
auth=auth,
|
||||
folder_name=request.folder_name,
|
||||
end_user_id=request.end_user_id,
|
||||
)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
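For illustration, a minimal client-side sketch of the new scoping fields on text ingestion (httpx, the localhost URL, the /ingest/text path, and the bearer token are assumptions, not part of this commit):

```python
import httpx

# Hypothetical client call; field names mirror the request fields used above.
payload = {
    "content": "Quarterly report text ...",
    "metadata": {"department": "finance"},
    "rules": [],
    "use_colpali": False,
    "folder_name": "q3-reports",   # scope the document to a folder
    "end_user_id": "user-123",     # scope the document to an end user
}

resp = httpx.post(
    "http://localhost:8000/ingest/text",          # path assumed
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # auth scheme assumed
)
resp.raise_for_status()
print(resp.json())
```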
|
||||
@ -345,6 +326,8 @@ async def ingest_file(
|
||||
rules: str = Form("[]"),
|
||||
auth: AuthContext = Depends(verify_token),
|
||||
use_colpali: Optional[bool] = None,
|
||||
folder_name: Optional[str] = Form(None),
|
||||
end_user_id: Optional[str] = Form(None),
|
||||
) -> Document:
|
||||
"""
|
||||
Ingest a file document.
|
||||
@ -356,6 +339,9 @@ async def ingest_file(
|
||||
- MetadataExtractionRule: {"type": "metadata_extraction", "schema": {...}}
|
||||
- NaturalLanguageRule: {"type": "natural_language", "prompt": "..."}
|
||||
auth: Authentication context
|
||||
use_colpali: Whether to use ColPali embedding model
|
||||
folder_name: Optional folder to scope the document to
|
||||
end_user_id: Optional end-user ID to scope the document to
|
||||
|
||||
Returns:
|
||||
Document: Metadata of ingested document
|
||||
@ -374,15 +360,20 @@ async def ingest_file(
|
||||
"metadata": metadata_dict,
|
||||
"rules": rules_list,
|
||||
"use_colpali": use_colpali,
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": end_user_id,
|
||||
},
|
||||
):
|
||||
logger.debug(f"API: Ingesting file with use_colpali: {use_colpali}")
|
||||
|
||||
return await document_service.ingest_file(
|
||||
file=file,
|
||||
metadata=metadata_dict,
|
||||
auth=auth,
|
||||
rules=rules_list,
|
||||
use_colpali=use_colpali,
|
||||
folder_name=folder_name,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
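A corresponding sketch for file ingestion, where the scoping fields travel as multipart form fields alongside the JSON-encoded metadata and rules (client library, URL, and auth header are assumptions):

```python
import json
import httpx

# Hypothetical multipart upload; folder_name / end_user_id mirror the
# Form(None) parameters added to ingest_file above.
data = {
    "metadata": json.dumps({"department": "finance"}),
    "rules": json.dumps([]),
    "folder_name": "q3-reports",
    "end_user_id": "user-123",
}

with open("report.pdf", "rb") as fh:
    resp = httpx.post(
        "http://localhost:8000/ingest/file",          # path assumed
        files={"file": ("report.pdf", fh, "application/pdf")},
        data=data,
        headers={"Authorization": "Bearer <token>"},  # auth scheme assumed
    )
print(resp.json())
```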
|
||||
@ -397,6 +388,8 @@ async def batch_ingest_files(
|
||||
rules: str = Form("[]"),
|
||||
use_colpali: Optional[bool] = Form(None),
|
||||
parallel: bool = Form(True),
|
||||
folder_name: Optional[str] = Form(None),
|
||||
end_user_id: Optional[str] = Form(None),
|
||||
auth: AuthContext = Depends(verify_token),
|
||||
) -> BatchIngestResponse:
|
||||
"""
|
||||
@ -410,6 +403,8 @@ async def batch_ingest_files(
|
||||
- A list of rule lists, one per file
|
||||
use_colpali: Whether to use ColPali-style embedding
|
||||
parallel: Whether to process files in parallel
|
||||
folder_name: Optional folder to scope the documents to
|
||||
end_user_id: Optional end-user ID to scope the documents to
|
||||
auth: Authentication context
|
||||
|
||||
Returns:
|
||||
@ -447,6 +442,8 @@ async def batch_ingest_files(
|
||||
documents = []
|
||||
errors = []
|
||||
|
||||
# We'll pass folder_name and end_user_id directly to the ingest_file functions
|
||||
|
||||
async with telemetry.track_operation(
|
||||
operation_type="batch_ingest",
|
||||
user_id=auth.entity_id,
|
||||
@ -454,6 +451,8 @@ async def batch_ingest_files(
|
||||
"file_count": len(files),
|
||||
"metadata_type": "list" if isinstance(metadata_value, list) else "single",
|
||||
"rules_type": "per_file" if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list) else "shared",
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": end_user_id,
|
||||
},
|
||||
):
|
||||
if parallel:
|
||||
@ -466,7 +465,9 @@ async def batch_ingest_files(
|
||||
metadata=metadata_item,
|
||||
auth=auth,
|
||||
rules=file_rules,
|
||||
use_colpali=use_colpali
|
||||
use_colpali=use_colpali,
|
||||
folder_name=folder_name,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
@ -490,7 +491,9 @@ async def batch_ingest_files(
|
||||
metadata=metadata_item,
|
||||
auth=auth,
|
||||
rules=file_rules,
|
||||
use_colpali=use_colpali
|
||||
use_colpali=use_colpali,
|
||||
folder_name=folder_name,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
documents.append(doc)
|
||||
except Exception as e:
|
||||
@ -504,7 +507,24 @@ async def batch_ingest_files(
|
||||
|
||||
@app.post("/retrieve/chunks", response_model=List[ChunkResult])
|
||||
async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
|
||||
"""Retrieve relevant chunks."""
|
||||
"""
|
||||
Retrieve relevant chunks.
|
||||
|
||||
Args:
|
||||
request: RetrieveRequest containing:
|
||||
- query: Search query text
|
||||
- filters: Optional metadata filters
|
||||
- k: Number of results (default: 4)
|
||||
- min_score: Minimum similarity threshold (default: 0.0)
|
||||
- use_reranking: Whether to use reranking
|
||||
- use_colpali: Whether to use ColPali-style embedding model
|
||||
- folder_name: Optional folder to scope the search to
|
||||
- end_user_id: Optional end-user ID to scope the search to
|
||||
auth: Authentication context
|
||||
|
||||
Returns:
|
||||
List[ChunkResult]: List of relevant chunks
|
||||
"""
|
||||
try:
|
||||
async with telemetry.track_operation(
|
||||
operation_type="retrieve_chunks",
|
||||
@ -514,6 +534,8 @@ async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(
|
||||
"min_score": request.min_score,
|
||||
"use_reranking": request.use_reranking,
|
||||
"use_colpali": request.use_colpali,
|
||||
"folder_name": request.folder_name,
|
||||
"end_user_id": request.end_user_id,
|
||||
},
|
||||
):
|
||||
return await document_service.retrieve_chunks(
|
||||
@ -524,6 +546,8 @@ async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(
|
||||
request.min_score,
|
||||
request.use_reranking,
|
||||
request.use_colpali,
|
||||
request.folder_name,
|
||||
request.end_user_id,
|
||||
)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
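A sketch of a scoped chunk retrieval against the /retrieve/chunks route shown above (client, URL, and token are assumptions; field names follow the RetrieveRequest docstring):

```python
import httpx

# Hypothetical scoped retrieval; folder_name / end_user_id narrow the search.
payload = {
    "query": "termination clauses",
    "filters": {"department": "legal"},
    "k": 4,
    "min_score": 0.0,
    "use_reranking": False,
    "folder_name": "contracts-2024",
    "end_user_id": "user-123",
}

resp = httpx.post(
    "http://localhost:8000/retrieve/chunks",
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # auth scheme assumed
)
for chunk in resp.json():
    print(chunk)
```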
|
||||
@ -531,7 +555,24 @@ async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(
|
||||
|
||||
@app.post("/retrieve/docs", response_model=List[DocumentResult])
|
||||
async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
|
||||
"""Retrieve relevant documents."""
|
||||
"""
|
||||
Retrieve relevant documents.
|
||||
|
||||
Args:
|
||||
request: RetrieveRequest containing:
|
||||
- query: Search query text
|
||||
- filters: Optional metadata filters
|
||||
- k: Number of results (default: 4)
|
||||
- min_score: Minimum similarity threshold (default: 0.0)
|
||||
- use_reranking: Whether to use reranking
|
||||
- use_colpali: Whether to use ColPali-style embedding model
|
||||
- folder_name: Optional folder to scope the search to
|
||||
- end_user_id: Optional end-user ID to scope the search to
|
||||
auth: Authentication context
|
||||
|
||||
Returns:
|
||||
List[DocumentResult]: List of relevant documents
|
||||
"""
|
||||
try:
|
||||
async with telemetry.track_operation(
|
||||
operation_type="retrieve_docs",
|
||||
@ -541,6 +582,8 @@ async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depen
|
||||
"min_score": request.min_score,
|
||||
"use_reranking": request.use_reranking,
|
||||
"use_colpali": request.use_colpali,
|
||||
"folder_name": request.folder_name,
|
||||
"end_user_id": request.end_user_id,
|
||||
},
|
||||
):
|
||||
return await document_service.retrieve_docs(
|
||||
@ -551,39 +594,99 @@ async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depen
|
||||
request.min_score,
|
||||
request.use_reranking,
|
||||
request.use_colpali,
|
||||
request.folder_name,
|
||||
request.end_user_id,
|
||||
)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/batch/documents", response_model=List[Document])
|
||||
async def batch_get_documents(document_ids: List[str], auth: AuthContext = Depends(verify_token)):
|
||||
"""Retrieve multiple documents by their IDs in a single batch operation."""
|
||||
async def batch_get_documents(
|
||||
request: Dict[str, Any],
|
||||
auth: AuthContext = Depends(verify_token)
|
||||
):
|
||||
"""
|
||||
Retrieve multiple documents by their IDs in a single batch operation.
|
||||
|
||||
Args:
|
||||
request: Dictionary containing:
|
||||
- document_ids: List of document IDs to retrieve
|
||||
- folder_name: Optional folder to scope the operation to
|
||||
- end_user_id: Optional end-user ID to scope the operation to
|
||||
auth: Authentication context
|
||||
|
||||
Returns:
|
||||
List[Document]: List of documents matching the IDs
|
||||
"""
|
||||
try:
|
||||
# Extract document_ids from request
|
||||
document_ids = request.get("document_ids", [])
|
||||
folder_name = request.get("folder_name")
|
||||
end_user_id = request.get("end_user_id")
|
||||
|
||||
if not document_ids:
|
||||
return []
|
||||
|
||||
async with telemetry.track_operation(
|
||||
operation_type="batch_get_documents",
|
||||
user_id=auth.entity_id,
|
||||
metadata={
|
||||
"document_count": len(document_ids),
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": end_user_id,
|
||||
},
|
||||
):
|
||||
return await document_service.batch_retrieve_documents(document_ids, auth)
|
||||
return await document_service.batch_retrieve_documents(document_ids, auth, folder_name, end_user_id)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
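A sketch of the reworked /batch/documents call, where the body is now a dict so the scoping fields can accompany document_ids (client, URL, and token are assumptions):

```python
import httpx

# Hypothetical batch lookup; the document IDs are placeholders.
payload = {
    "document_ids": ["doc_abc123", "doc_def456"],
    "folder_name": "q3-reports",
    "end_user_id": "user-123",
}

resp = httpx.post(
    "http://localhost:8000/batch/documents",
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # auth scheme assumed
)
print(len(resp.json()), "documents returned")
```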
|
||||
|
||||
|
||||
@app.post("/batch/chunks", response_model=List[ChunkResult])
|
||||
async def batch_get_chunks(chunk_ids: List[ChunkSource], auth: AuthContext = Depends(verify_token)):
|
||||
"""Retrieve specific chunks by their document ID and chunk number in a single batch operation."""
|
||||
async def batch_get_chunks(
|
||||
request: Dict[str, Any],
|
||||
auth: AuthContext = Depends(verify_token)
|
||||
):
|
||||
"""
|
||||
Retrieve specific chunks by their document ID and chunk number in a single batch operation.
|
||||
|
||||
Args:
|
||||
request: Dictionary containing:
|
||||
- sources: List of ChunkSource objects (with document_id and chunk_number)
|
||||
- folder_name: Optional folder to scope the operation to
|
||||
- end_user_id: Optional end-user ID to scope the operation to
|
||||
auth: Authentication context
|
||||
|
||||
Returns:
|
||||
List[ChunkResult]: List of chunk results
|
||||
"""
|
||||
try:
|
||||
# Extract sources from request
|
||||
sources = request.get("sources", [])
|
||||
folder_name = request.get("folder_name")
|
||||
end_user_id = request.get("end_user_id")
|
||||
|
||||
if not sources:
|
||||
return []
|
||||
|
||||
async with telemetry.track_operation(
|
||||
operation_type="batch_get_chunks",
|
||||
user_id=auth.entity_id,
|
||||
metadata={
|
||||
"chunk_count": len(chunk_ids),
|
||||
"chunk_count": len(sources),
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": end_user_id,
|
||||
},
|
||||
):
|
||||
return await document_service.batch_retrieve_chunks(chunk_ids, auth)
|
||||
# Convert sources to ChunkSource objects if needed
|
||||
chunk_sources = []
|
||||
for source in sources:
|
||||
if isinstance(source, dict):
|
||||
chunk_sources.append(ChunkSource(**source))
|
||||
else:
|
||||
chunk_sources.append(source)
|
||||
|
||||
return await document_service.batch_retrieve_chunks(chunk_sources, auth, folder_name, end_user_id)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
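A sketch of the reworked /batch/chunks call, where each source is a ChunkSource-shaped dict that the endpoint converts as shown above (client, URL, and token are assumptions):

```python
import httpx

# Hypothetical batch chunk lookup; document IDs and chunk numbers are placeholders.
payload = {
    "sources": [
        {"document_id": "doc_abc123", "chunk_number": 0},
        {"document_id": "doc_abc123", "chunk_number": 3},
    ],
    "folder_name": "q3-reports",
    "end_user_id": "user-123",
}

resp = httpx.post(
    "http://localhost:8000/batch/chunks",
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # auth scheme assumed
)
print(resp.json())
```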
|
||||
|
||||
@ -592,10 +695,32 @@ async def batch_get_chunks(chunk_ids: List[ChunkSource], auth: AuthContext = Dep
|
||||
async def query_completion(
|
||||
request: CompletionQueryRequest, auth: AuthContext = Depends(verify_token)
|
||||
):
|
||||
"""Generate completion using relevant chunks as context.
|
||||
"""
|
||||
Generate completion using relevant chunks as context.
|
||||
|
||||
When graph_name is provided, the query will leverage the knowledge graph
|
||||
to enhance retrieval by finding relevant entities and their connected documents.
|
||||
|
||||
Args:
|
||||
request: CompletionQueryRequest containing:
|
||||
- query: Query text
|
||||
- filters: Optional metadata filters
|
||||
- k: Number of chunks to use as context (default: 4)
|
||||
- min_score: Minimum similarity threshold (default: 0.0)
|
||||
- max_tokens: Maximum tokens in completion
|
||||
- temperature: Model temperature
|
||||
- use_reranking: Whether to use reranking
|
||||
- use_colpali: Whether to use ColPali-style embedding model
|
||||
- graph_name: Optional name of the graph to use for knowledge graph-enhanced retrieval
|
||||
- hop_depth: Number of relationship hops to traverse in the graph (1-3)
|
||||
- include_paths: Whether to include relationship paths in the response
|
||||
- prompt_overrides: Optional customizations for entity extraction, resolution, and query prompts
|
||||
- folder_name: Optional folder to scope the operation to
|
||||
- end_user_id: Optional end-user ID to scope the operation to
|
||||
auth: Authentication context
|
||||
|
||||
Returns:
|
||||
CompletionResponse: Generated completion
|
||||
"""
|
||||
try:
|
||||
# Validate prompt overrides before proceeding
|
||||
@ -620,6 +745,8 @@ async def query_completion(
|
||||
"graph_name": request.graph_name,
|
||||
"hop_depth": request.hop_depth,
|
||||
"include_paths": request.include_paths,
|
||||
"folder_name": request.folder_name,
|
||||
"end_user_id": request.end_user_id,
|
||||
},
|
||||
):
|
||||
return await document_service.query(
|
||||
@ -636,6 +763,8 @@ async def query_completion(
|
||||
request.hop_depth,
|
||||
request.include_paths,
|
||||
request.prompt_overrides,
|
||||
request.folder_name,
|
||||
request.end_user_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
validate_prompt_overrides_with_http_exception(operation_type="query", error=e)
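A sketch of a scoped, graph-enhanced completion query (the /query path, client, URL, and token are assumptions; field names follow the docstring above):

```python
import httpx

# Hypothetical request; graph_name enables knowledge-graph-enhanced retrieval,
# while folder_name / end_user_id are the new scoping fields.
payload = {
    "query": "What were the key findings in the Q3 audit?",
    "k": 4,
    "min_score": 0.0,
    "use_reranking": True,
    "graph_name": "finance-graph",
    "hop_depth": 2,          # traverse up to 2 relationship hops
    "include_paths": True,
    "folder_name": "q3-reports",
    "end_user_id": "user-123",
}

resp = httpx.post(
    "http://localhost:8000/query",                # path assumed
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # auth scheme assumed
)
print(resp.json())
```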
|
||||
@ -649,9 +778,31 @@ async def list_documents(
|
||||
skip: int = 0,
|
||||
limit: int = 10000,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
folder_name: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
):
|
||||
"""List accessible documents."""
|
||||
return await document_service.db.get_documents(auth, skip, limit, filters)
|
||||
"""
|
||||
List accessible documents.
|
||||
|
||||
Args:
|
||||
auth: Authentication context
|
||||
skip: Number of documents to skip
|
||||
limit: Maximum number of documents to return
|
||||
filters: Optional metadata filters
|
||||
folder_name: Optional folder to scope the operation to
|
||||
end_user_id: Optional end-user ID to scope the operation to
|
||||
|
||||
Returns:
|
||||
List[Document]: List of accessible documents
|
||||
"""
|
||||
    # Create system filters for folder and user scoping
    system_filters = {}
    if folder_name:
        system_filters["folder_name"] = folder_name
    if end_user_id:
        system_filters["end_user_id"] = end_user_id

    return await document_service.db.get_documents(auth, skip, limit, filters, system_filters)
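The three-line scoping pattern above recurs in several handlers in this diff; a hypothetical helper of the following shape would capture it (not part of the commit):

```python
from typing import Any, Dict, Optional


def build_system_filters(
    folder_name: Optional[str] = None,
    end_user_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Collect the optional scoping fields into the system_filters dict
    that the database layer expects (hypothetical helper)."""
    system_filters: Dict[str, Any] = {}
    if folder_name:
        system_filters["folder_name"] = folder_name
    if end_user_id:
        system_filters["end_user_id"] = end_user_id
    return system_filters


# e.g. inside list_documents:
# system_filters = build_system_filters(folder_name, end_user_id)
# docs = await document_service.db.get_documents(auth, skip, limit, filters, system_filters)
```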
|
||||
|
||||
|
||||
@app.get("/documents/{document_id}", response_model=Document)
|
||||
@ -700,10 +851,33 @@ async def delete_document(document_id: str, auth: AuthContext = Depends(verify_t
|
||||
|
||||
|
||||
@app.get("/documents/filename/{filename}", response_model=Document)
|
||||
async def get_document_by_filename(filename: str, auth: AuthContext = Depends(verify_token)):
|
||||
"""Get document by filename."""
|
||||
async def get_document_by_filename(
|
||||
filename: str,
|
||||
auth: AuthContext = Depends(verify_token),
|
||||
folder_name: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Get document by filename.
|
||||
|
||||
Args:
|
||||
filename: Filename of the document to retrieve
|
||||
auth: Authentication context
|
||||
folder_name: Optional folder to scope the operation to
|
||||
end_user_id: Optional end-user ID to scope the operation to
|
||||
|
||||
Returns:
|
||||
Document: Document metadata if found and accessible
|
||||
"""
|
||||
try:
|
||||
doc = await document_service.db.get_document_by_filename(filename, auth)
|
||||
# Create system filters for folder and user scoping
|
||||
system_filters = {}
|
||||
if folder_name:
|
||||
system_filters["folder_name"] = folder_name
|
||||
if end_user_id:
|
||||
system_filters["end_user_id"] = end_user_id
|
||||
|
||||
doc = await document_service.db.get_document_by_filename(filename, auth, system_filters)
|
||||
logger.debug(f"Found document by filename: {doc}")
|
||||
if not doc:
|
||||
raise HTTPException(status_code=404, detail=f"Document with filename '{filename}' not found")
|
||||
@ -1071,6 +1245,9 @@ async def create_graph(
|
||||
- name: Name of the graph to create
|
||||
- filters: Optional metadata filters to determine which documents to include
|
||||
- documents: Optional list of specific document IDs to include
|
||||
- prompt_overrides: Optional customizations for entity extraction and resolution prompts
|
||||
- folder_name: Optional folder to scope the operation to
|
||||
- end_user_id: Optional end-user ID to scope the operation to
|
||||
auth: Authentication context
|
||||
|
||||
Returns:
|
||||
@ -1093,14 +1270,24 @@ async def create_graph(
|
||||
"name": request.name,
|
||||
"filters": request.filters,
|
||||
"documents": request.documents,
|
||||
"folder_name": request.folder_name,
|
||||
"end_user_id": request.end_user_id,
|
||||
},
|
||||
):
|
||||
# Create system filters for folder and user scoping
|
||||
system_filters = {}
|
||||
if request.folder_name:
|
||||
system_filters["folder_name"] = request.folder_name
|
||||
if request.end_user_id:
|
||||
system_filters["end_user_id"] = request.end_user_id
|
||||
|
||||
return await document_service.create_graph(
|
||||
name=request.name,
|
||||
auth=auth,
|
||||
filters=request.filters,
|
||||
documents=request.documents,
|
||||
prompt_overrides=request.prompt_overrides,
|
||||
system_filters=system_filters,
|
||||
)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
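A sketch of a scoped graph creation request (the /graph/create path, client, URL, and token are assumptions; field names follow the request docstring above):

```python
import httpx

# Hypothetical request; the scoping fields land in system_filters, so the
# resulting graph is visible only within that folder / end-user scope.
payload = {
    "name": "finance-graph",
    "filters": {"department": "finance"},   # documents selected by metadata
    "prompt_overrides": None,
    "folder_name": "q3-reports",
    "end_user_id": "user-123",
}

resp = httpx.post(
    "http://localhost:8000/graph/create",         # path assumed
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # auth scheme assumed
)
print(resp.json())
```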
|
||||
@ -1112,6 +1299,8 @@ async def create_graph(
|
||||
async def get_graph(
|
||||
name: str,
|
||||
auth: AuthContext = Depends(verify_token),
|
||||
folder_name: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
) -> Graph:
|
||||
"""
|
||||
Get a graph by name.
|
||||
@ -1121,6 +1310,8 @@ async def get_graph(
|
||||
Args:
|
||||
name: Name of the graph to retrieve
|
||||
auth: Authentication context
|
||||
folder_name: Optional folder to scope the operation to
|
||||
end_user_id: Optional end-user ID to scope the operation to
|
||||
|
||||
Returns:
|
||||
Graph: The requested graph object
|
||||
@ -1129,9 +1320,20 @@ async def get_graph(
|
||||
async with telemetry.track_operation(
|
||||
operation_type="get_graph",
|
||||
user_id=auth.entity_id,
|
||||
metadata={"name": name},
|
||||
metadata={
|
||||
"name": name,
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": end_user_id
|
||||
},
|
||||
):
|
||||
graph = await document_service.db.get_graph(name, auth)
|
||||
# Create system filters for folder and user scoping
|
||||
system_filters = {}
|
||||
if folder_name:
|
||||
system_filters["folder_name"] = folder_name
|
||||
if end_user_id:
|
||||
system_filters["end_user_id"] = end_user_id
|
||||
|
||||
graph = await document_service.db.get_graph(name, auth, system_filters)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph '{name}' not found")
|
||||
return graph
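Because folder_name and end_user_id are plain function parameters here, FastAPI exposes them as optional query parameters; a sketch of a scoped lookup (path, client, URL, and token are assumptions):

```python
import httpx

# Hypothetical scoped graph fetch; the graph name and scope values are placeholders.
resp = httpx.get(
    "http://localhost:8000/graph/finance-graph",  # path assumed
    params={"folder_name": "q3-reports", "end_user_id": "user-123"},
    headers={"Authorization": "Bearer <token>"},  # auth scheme assumed
)
graph = resp.json()
print(graph["name"], len(graph.get("entities", [])), "entities")
```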
|
||||
@ -1144,6 +1346,8 @@ async def get_graph(
|
||||
@app.get("/graphs", response_model=List[Graph])
|
||||
async def list_graphs(
|
||||
auth: AuthContext = Depends(verify_token),
|
||||
folder_name: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
) -> List[Graph]:
|
||||
"""
|
||||
List all graphs the user has access to.
|
||||
@ -1152,6 +1356,8 @@ async def list_graphs(
|
||||
|
||||
Args:
|
||||
auth: Authentication context
|
||||
folder_name: Optional folder to scope the operation to
|
||||
end_user_id: Optional end-user ID to scope the operation to
|
||||
|
||||
Returns:
|
||||
List[Graph]: List of graph objects
|
||||
@ -1160,8 +1366,19 @@ async def list_graphs(
|
||||
async with telemetry.track_operation(
|
||||
operation_type="list_graphs",
|
||||
user_id=auth.entity_id,
|
||||
metadata={
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": end_user_id
|
||||
},
|
||||
):
|
||||
return await document_service.db.list_graphs(auth)
|
||||
# Create system filters for folder and user scoping
|
||||
system_filters = {}
|
||||
if folder_name:
|
||||
system_filters["folder_name"] = folder_name
|
||||
if end_user_id:
|
||||
system_filters["end_user_id"] = end_user_id
|
||||
|
||||
return await document_service.db.list_graphs(auth, system_filters)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
@ -1186,6 +1403,9 @@ async def update_graph(
|
||||
request: UpdateGraphRequest containing:
|
||||
- additional_filters: Optional additional metadata filters to determine which new documents to include
|
||||
- additional_documents: Optional list of additional document IDs to include
|
||||
- prompt_overrides: Optional customizations for entity extraction and resolution prompts
|
||||
- folder_name: Optional folder to scope the operation to
|
||||
- end_user_id: Optional end-user ID to scope the operation to
|
||||
auth: Authentication context
|
||||
|
||||
Returns:
|
||||
@ -1203,14 +1423,24 @@ async def update_graph(
|
||||
"name": name,
|
||||
"additional_filters": request.additional_filters,
|
||||
"additional_documents": request.additional_documents,
|
||||
"folder_name": request.folder_name,
|
||||
"end_user_id": request.end_user_id,
|
||||
},
|
||||
):
|
||||
# Create system filters for folder and user scoping
|
||||
system_filters = {}
|
||||
if request.folder_name:
|
||||
system_filters["folder_name"] = request.folder_name
|
||||
if request.end_user_id:
|
||||
system_filters["end_user_id"] = request.end_user_id
|
||||
|
||||
return await document_service.update_graph(
|
||||
name=name,
|
||||
auth=auth,
|
||||
additional_filters=request.additional_filters,
|
||||
additional_documents=request.additional_documents,
|
||||
prompt_overrides=request.prompt_overrides,
|
||||
system_filters=system_filters,
|
||||
)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
core/config.py
@@ -13,7 +13,6 @@ class Settings(BaseSettings):
|
||||
# Environment variables
|
||||
JWT_SECRET_KEY: str
|
||||
POSTGRES_URI: Optional[str] = None
|
||||
MONGODB_URI: Optional[str] = None
|
||||
UNSTRUCTURED_API_KEY: Optional[str] = None
|
||||
AWS_ACCESS_KEY: Optional[str] = None
|
||||
AWS_SECRET_ACCESS_KEY: Optional[str] = None
|
||||
@ -42,9 +41,8 @@ class Settings(BaseSettings):
|
||||
|
||||
|
||||
# Database configuration
|
||||
DATABASE_PROVIDER: Literal["postgres", "mongodb"]
|
||||
DATABASE_PROVIDER: Literal["postgres"]
|
||||
DATABASE_NAME: Optional[str] = None
|
||||
DOCUMENTS_COLLECTION: Optional[str] = None
|
||||
|
||||
# Embedding configuration
|
||||
EMBEDDING_PROVIDER: Literal["litellm"] = "litellm"
|
||||
@ -85,9 +83,8 @@ class Settings(BaseSettings):
|
||||
S3_BUCKET: Optional[str] = None
|
||||
|
||||
# Vector store configuration
|
||||
VECTOR_STORE_PROVIDER: Literal["pgvector", "mongodb"]
|
||||
VECTOR_STORE_PROVIDER: Literal["pgvector"]
|
||||
VECTOR_STORE_DATABASE_NAME: Optional[str] = None
|
||||
VECTOR_STORE_COLLECTION_NAME: Optional[str] = None
|
||||
|
||||
# Colpali configuration
|
||||
ENABLE_COLPALI: bool
|
||||
@ -164,24 +161,17 @@ def get_settings() -> Settings:
|
||||
|
||||
# load database config
|
||||
database_config = {"DATABASE_PROVIDER": config["database"]["provider"]}
|
||||
match database_config["DATABASE_PROVIDER"]:
|
||||
case "mongodb":
|
||||
database_config.update(
|
||||
{
|
||||
"DATABASE_NAME": config["database"]["database_name"],
|
||||
"COLLECTION_NAME": config["database"]["collection_name"],
|
||||
}
|
||||
)
|
||||
case "postgres" if "POSTGRES_URI" in os.environ:
|
||||
database_config.update({"POSTGRES_URI": os.environ["POSTGRES_URI"]})
|
||||
case "postgres":
|
||||
msg = em.format(
|
||||
missing_value="POSTGRES_URI", field="database.provider", value="postgres"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
case _:
|
||||
prov = database_config["DATABASE_PROVIDER"]
|
||||
raise ValueError(f"Unknown database provider selected: '{prov}'")
|
||||
if database_config["DATABASE_PROVIDER"] != "postgres":
|
||||
prov = database_config["DATABASE_PROVIDER"]
|
||||
raise ValueError(f"Unknown database provider selected: '{prov}'")
|
||||
|
||||
if "POSTGRES_URI" in os.environ:
|
||||
database_config.update({"POSTGRES_URI": os.environ["POSTGRES_URI"]})
|
||||
else:
|
||||
msg = em.format(
|
||||
missing_value="POSTGRES_URI", field="database.provider", value="postgres"
|
||||
)
|
||||
raise ValueError(msg)
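The simplified flow above amounts to a single postgres-only guard; a standalone sketch of the equivalent logic (hypothetical helper, with `em` standing for the error-message template used in get_settings):

```python
import os


def load_database_config(config: dict, em: str) -> dict:
    # Hypothetical standalone version of the postgres-only guard above.
    database_config = {"DATABASE_PROVIDER": config["database"]["provider"]}

    if database_config["DATABASE_PROVIDER"] != "postgres":
        prov = database_config["DATABASE_PROVIDER"]
        raise ValueError(f"Unknown database provider selected: '{prov}'")

    if "POSTGRES_URI" not in os.environ:
        raise ValueError(
            em.format(missing_value="POSTGRES_URI", field="database.provider", value="postgres")
        )

    database_config["POSTGRES_URI"] = os.environ["POSTGRES_URI"]
    return database_config
```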
|
||||
|
||||
# load embedding config
|
||||
embedding_config = {
|
||||
@ -251,23 +241,15 @@ def get_settings() -> Settings:
|
||||
|
||||
# load vector store config
|
||||
vector_store_config = {"VECTOR_STORE_PROVIDER": config["vector_store"]["provider"]}
|
||||
match vector_store_config["VECTOR_STORE_PROVIDER"]:
|
||||
case "mongodb":
|
||||
vector_store_config.update(
|
||||
{
|
||||
"VECTOR_STORE_DATABASE_NAME": config["vector_store"]["database_name"],
|
||||
"VECTOR_STORE_COLLECTION_NAME": config["vector_store"]["collection_name"],
|
||||
}
|
||||
)
|
||||
case "pgvector":
|
||||
if "POSTGRES_URI" not in os.environ:
|
||||
msg = em.format(
|
||||
missing_value="POSTGRES_URI", field="vector_store.provider", value="pgvector"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
case _:
|
||||
prov = vector_store_config["VECTOR_STORE_PROVIDER"]
|
||||
raise ValueError(f"Unknown vector store provider selected: '{prov}'")
|
||||
if vector_store_config["VECTOR_STORE_PROVIDER"] != "pgvector":
|
||||
prov = vector_store_config["VECTOR_STORE_PROVIDER"]
|
||||
raise ValueError(f"Unknown vector store provider selected: '{prov}'")
|
||||
|
||||
if "POSTGRES_URI" not in os.environ:
|
||||
msg = em.format(
|
||||
missing_value="POSTGRES_URI", field="vector_store.provider", value="pgvector"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# load rules config
|
||||
rules_config = {
core/database/base_database.py
@@ -26,7 +26,7 @@ class BaseDatabase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_document_by_filename(self, filename: str, auth: AuthContext) -> Optional[Document]:
|
||||
async def get_document_by_filename(self, filename: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> Optional[Document]:
|
||||
"""
|
||||
Retrieve document metadata by filename if user has access.
|
||||
If multiple documents have the same filename, returns the most recently updated one.
|
||||
@ -34,6 +34,7 @@ class BaseDatabase(ABC):
|
||||
Args:
|
||||
filename: The filename to search for
|
||||
auth: Authentication context
|
||||
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
|
||||
|
||||
Returns:
|
||||
Document if found and accessible, None otherwise
|
||||
@ -41,14 +42,16 @@ class BaseDatabase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext) -> List[Document]:
|
||||
async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> List[Document]:
|
||||
"""
|
||||
Retrieve multiple documents by their IDs in a single batch operation.
|
||||
Only returns documents the user has access to.
|
||||
Can filter by system metadata fields like folder_name and end_user_id.
|
||||
|
||||
Args:
|
||||
document_ids: List of document IDs to retrieve
|
||||
auth: Authentication context
|
||||
system_filters: Optional filters for system metadata fields
|
||||
|
||||
Returns:
|
||||
List of Document objects that were found and user has access to
|
||||
@ -62,10 +65,21 @@ class BaseDatabase(ABC):
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
system_filters: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
List documents the user has access to.
|
||||
Supports pagination and filtering.
|
||||
|
||||
Args:
|
||||
auth: Authentication context
|
||||
skip: Number of documents to skip (for pagination)
|
||||
limit: Maximum number of documents to return
|
||||
filters: Optional metadata filters
|
||||
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
|
||||
|
||||
Returns:
|
||||
List of documents matching the criteria
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -89,9 +103,18 @@ class BaseDatabase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def find_authorized_and_filtered_documents(
|
||||
self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None
|
||||
self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None, system_filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[str]:
|
||||
"""Find document IDs matching filters that user has access to."""
|
||||
"""Find document IDs matching filters that user has access to.
|
||||
|
||||
Args:
|
||||
auth: Authentication context
|
||||
filters: Optional metadata filters
|
||||
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
|
||||
|
||||
Returns:
|
||||
List of document IDs matching the criteria
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -142,12 +165,13 @@ class BaseDatabase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_graph(self, name: str, auth: AuthContext) -> Optional[Graph]:
|
||||
async def get_graph(self, name: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> Optional[Graph]:
|
||||
"""Get a graph by name.
|
||||
|
||||
Args:
|
||||
name: Name of the graph
|
||||
auth: Authentication context
|
||||
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
|
||||
|
||||
Returns:
|
||||
Optional[Graph]: Graph if found and accessible, None otherwise
|
||||
@ -155,11 +179,12 @@ class BaseDatabase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def list_graphs(self, auth: AuthContext) -> List[Graph]:
|
||||
async def list_graphs(self, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> List[Graph]:
|
||||
"""List all graphs the user has access to.
|
||||
|
||||
Args:
|
||||
auth: Authentication context
|
||||
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
|
||||
|
||||
Returns:
|
||||
List[Graph]: List of graphs
core/database/mongo_database.py (file removed)
@@ -1,329 +0,0 @@
|
||||
from datetime import UTC, datetime
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from pymongo import ReturnDocument
|
||||
from pymongo.errors import PyMongoError
|
||||
|
||||
from .base_database import BaseDatabase
|
||||
from ..models.documents import Document
|
||||
from ..models.auth import AuthContext, EntityType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MongoDatabase(BaseDatabase):
|
||||
"""MongoDB implementation for document metadata storage."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
db_name: str,
|
||||
collection_name: str,
|
||||
):
|
||||
"""Initialize MongoDB connection for document storage."""
|
||||
self.client = AsyncIOMotorClient(uri)
|
||||
self.db = self.client[db_name]
|
||||
self.collection = self.db[collection_name]
|
||||
self.caches = self.db["caches"] # Collection for cache metadata
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize database indexes."""
|
||||
try:
|
||||
# Create indexes for common queries
|
||||
await self.collection.create_index("external_id", unique=True)
|
||||
await self.collection.create_index("owner.id")
|
||||
await self.collection.create_index("access_control.readers")
|
||||
await self.collection.create_index("access_control.writers")
|
||||
await self.collection.create_index("access_control.admins")
|
||||
await self.collection.create_index("system_metadata.created_at")
|
||||
|
||||
logger.info("MongoDB indexes created successfully")
|
||||
return True
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error creating MongoDB indexes: {str(e)}")
|
||||
return False
|
||||
|
||||
async def store_document(self, document: Document) -> bool:
|
||||
"""Store document metadata."""
|
||||
try:
|
||||
doc_dict = document.model_dump()
|
||||
|
||||
# Ensure system metadata
|
||||
doc_dict["system_metadata"]["created_at"] = datetime.now(UTC)
|
||||
doc_dict["system_metadata"]["updated_at"] = datetime.now(UTC)
|
||||
doc_dict["metadata"]["external_id"] = doc_dict["external_id"]
|
||||
|
||||
result = await self.collection.insert_one(doc_dict)
|
||||
return bool(result.inserted_id)
|
||||
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error storing document metadata: {str(e)}")
|
||||
return False
|
||||
|
||||
async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]:
|
||||
"""Retrieve document metadata by ID if user has access."""
|
||||
try:
|
||||
# Build access filter
|
||||
access_filter = self._build_access_filter(auth)
|
||||
|
||||
# Query document
|
||||
query = {"$and": [{"external_id": document_id}, access_filter]}
|
||||
logger.debug(f"Querying document with query: {query}")
|
||||
|
||||
doc_dict = await self.collection.find_one(query)
|
||||
logger.debug(f"Found document: {doc_dict}")
|
||||
return Document(**doc_dict) if doc_dict else None
|
||||
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error retrieving document metadata: {str(e)}")
|
||||
raise e
|
||||
|
||||
async def get_document_by_filename(self, filename: str, auth: AuthContext) -> Optional[Document]:
|
||||
"""Retrieve document metadata by filename if user has access.
|
||||
If multiple documents have the same filename, returns the most recently updated one.
|
||||
"""
|
||||
try:
|
||||
# Build access filter
|
||||
access_filter = self._build_access_filter(auth)
|
||||
|
||||
# Query document
|
||||
query = {"$and": [{"filename": filename}, access_filter]}
|
||||
logger.debug(f"Querying document by filename with query: {query}")
|
||||
|
||||
# Sort by updated_at in descending order to get the most recent one
|
||||
sort_criteria = [("system_metadata.updated_at", -1)]
|
||||
|
||||
doc_dict = await self.collection.find_one(query, sort=sort_criteria)
|
||||
logger.debug(f"Found document by filename: {doc_dict}")
|
||||
|
||||
return Document(**doc_dict) if doc_dict else None
|
||||
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error retrieving document metadata by filename: {str(e)}")
|
||||
raise e
|
||||
|
||||
async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext) -> List[Document]:
|
||||
"""
|
||||
Retrieve multiple documents by their IDs in a single batch operation.
|
||||
Only returns documents the user has access to.
|
||||
|
||||
Args:
|
||||
document_ids: List of document IDs to retrieve
|
||||
auth: Authentication context
|
||||
|
||||
Returns:
|
||||
List of Document objects that were found and user has access to
|
||||
"""
|
||||
try:
|
||||
if not document_ids:
|
||||
return []
|
||||
|
||||
# Build access filter
|
||||
access_filter = self._build_access_filter(auth)
|
||||
|
||||
# Query documents with both document IDs and access check in a single query
|
||||
query = {
|
||||
"$and": [
|
||||
{"external_id": {"$in": document_ids}},
|
||||
access_filter
|
||||
]
|
||||
}
|
||||
|
||||
logger.info(f"Batch retrieving {len(document_ids)} documents with a single query")
|
||||
|
||||
# Execute batch query
|
||||
cursor = self.collection.find(query)
|
||||
|
||||
documents = []
|
||||
async for doc_dict in cursor:
|
||||
documents.append(Document(**doc_dict))
|
||||
|
||||
logger.info(f"Found {len(documents)} documents in batch retrieval")
|
||||
return documents
|
||||
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error batch retrieving documents: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_documents(
|
||||
self,
|
||||
auth: AuthContext,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Document]:
|
||||
"""List accessible documents with pagination and filtering."""
|
||||
try:
|
||||
# Build query
|
||||
auth_filter = self._build_access_filter(auth)
|
||||
metadata_filter = self._build_metadata_filter(filters)
|
||||
query = {"$and": [auth_filter, metadata_filter]} if metadata_filter else auth_filter
|
||||
|
||||
# Execute paginated query
|
||||
cursor = self.collection.find(query).skip(skip).limit(limit)
|
||||
|
||||
documents = []
|
||||
async for doc_dict in cursor:
|
||||
documents.append(Document(**doc_dict))
|
||||
|
||||
return documents
|
||||
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error listing documents: {str(e)}")
|
||||
return []
|
||||
|
||||
async def update_document(
|
||||
self, document_id: str, updates: Dict[str, Any], auth: AuthContext
|
||||
) -> bool:
|
||||
"""Update document metadata if user has write access."""
|
||||
try:
|
||||
# Verify write access
|
||||
if not await self.check_access(document_id, auth, "write"):
|
||||
return False
|
||||
|
||||
# Update system metadata
|
||||
updates.setdefault("system_metadata", {})
|
||||
updates["system_metadata"]["updated_at"] = datetime.now(UTC)
|
||||
|
||||
result = await self.collection.find_one_and_update(
|
||||
{"external_id": document_id},
|
||||
{"$set": updates},
|
||||
return_document=ReturnDocument.AFTER,
|
||||
)
|
||||
|
||||
return bool(result)
|
||||
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error updating document metadata: {str(e)}")
|
||||
return False
|
||||
|
||||
async def delete_document(self, document_id: str, auth: AuthContext) -> bool:
|
||||
"""Delete document if user has admin access."""
|
||||
try:
|
||||
# Verify admin access
|
||||
if not await self.check_access(document_id, auth, "admin"):
|
||||
return False
|
||||
|
||||
result = await self.collection.delete_one({"external_id": document_id})
|
||||
return bool(result.deleted_count)
|
||||
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error deleting document: {str(e)}")
|
||||
return False
|
||||
|
||||
async def find_authorized_and_filtered_documents(
|
||||
self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[str]:
|
||||
"""Find document IDs matching filters and access permissions."""
|
||||
# Build query
|
||||
auth_filter = self._build_access_filter(auth)
|
||||
metadata_filter = self._build_metadata_filter(filters)
|
||||
query = {"$and": [auth_filter, metadata_filter]} if metadata_filter else auth_filter
|
||||
|
||||
# Get matching document IDs
|
||||
cursor = self.collection.find(query, {"external_id": 1})
|
||||
|
||||
document_ids = []
|
||||
async for doc in cursor:
|
||||
document_ids.append(doc["external_id"])
|
||||
|
||||
return document_ids
|
||||
|
||||
async def check_access(
|
||||
self, document_id: str, auth: AuthContext, required_permission: str = "read"
|
||||
) -> bool:
|
||||
"""Check if user has required permission for document."""
|
||||
try:
|
||||
doc = await self.collection.find_one({"external_id": document_id})
|
||||
if not doc:
|
||||
return False
|
||||
|
||||
access_control = doc.get("access_control", {})
|
||||
|
||||
# Check owner access
|
||||
owner = doc.get("owner", {})
|
||||
if owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id:
|
||||
return True
|
||||
|
||||
# Check permission-specific access
|
||||
permission_map = {"read": "readers", "write": "writers", "admin": "admins"}
|
||||
|
||||
permission_set = permission_map.get(required_permission)
|
||||
if not permission_set:
|
||||
return False
|
||||
|
||||
return auth.entity_id in access_control.get(permission_set, set())
|
||||
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error checking document access: {str(e)}")
|
||||
return False
|
||||
|
||||
def _build_access_filter(self, auth: AuthContext) -> Dict[str, Any]:
|
||||
"""Build MongoDB filter for access control."""
|
||||
base_filter = {
|
||||
"$or": [
|
||||
{"owner.id": auth.entity_id},
|
||||
{"access_control.readers": auth.entity_id},
|
||||
{"access_control.writers": auth.entity_id},
|
||||
{"access_control.admins": auth.entity_id},
|
||||
]
|
||||
}
|
||||
|
||||
if auth.entity_type == EntityType.DEVELOPER:
|
||||
# Add app-specific access for developers
|
||||
base_filter["$or"].append({"access_control.app_access": auth.app_id})
|
||||
|
||||
return base_filter
|
||||
|
||||
def _build_metadata_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build MongoDB filter for metadata."""
|
||||
if not filters:
|
||||
return {}
|
||||
filter_dict = {}
|
||||
for key, value in filters.items():
|
||||
filter_dict[f"metadata.{key}"] = value
|
||||
return filter_dict
|
||||
|
||||
async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""Store metadata for a cache in MongoDB.
|
||||
|
||||
Args:
|
||||
name: Name of the cache
|
||||
metadata: Cache metadata including model info and storage location
|
||||
|
||||
Returns:
|
||||
bool: Whether the operation was successful
|
||||
"""
|
||||
try:
|
||||
# Add timestamp and ensure name is included
|
||||
doc = {
|
||||
"name": name,
|
||||
"metadata": metadata,
|
||||
"created_at": datetime.now(UTC),
|
||||
"updated_at": datetime.now(UTC),
|
||||
}
|
||||
|
||||
# Upsert the document
|
||||
result = await self.caches.update_one({"name": name}, {"$set": doc}, upsert=True)
|
||||
return bool(result.modified_count or result.upserted_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store cache metadata: {e}")
|
||||
return False
|
||||
|
||||
async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get metadata for a cache from MongoDB.
|
||||
|
||||
Args:
|
||||
name: Name of the cache
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: Cache metadata if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
doc = await self.caches.find_one({"name": name})
|
||||
return doc["metadata"] if doc else None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cache metadata: {e}")
|
core/database/postgres_database.py
@@ -51,6 +51,7 @@ class GraphModel(Base):
|
||||
entities = Column(JSONB, default=list)
|
||||
relationships = Column(JSONB, default=list)
|
||||
graph_metadata = Column(JSONB, default=dict) # Renamed from 'metadata' to avoid conflict
|
||||
system_metadata = Column(JSONB, default=dict) # For folder_name and end_user_id
|
||||
document_ids = Column(JSONB, default=list)
|
||||
filters = Column(JSONB, nullable=True)
|
||||
created_at = Column(String) # ISO format string
|
||||
@ -63,6 +64,7 @@ class GraphModel(Base):
|
||||
Index("idx_graph_name", "name"),
|
||||
Index("idx_graph_owner", "owner", postgresql_using="gin"),
|
||||
Index("idx_graph_access_control", "access_control", postgresql_using="gin"),
|
||||
Index("idx_graph_system_metadata", "system_metadata", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
|
||||
@ -139,6 +141,68 @@ class PostgresDatabase(BaseDatabase):
|
||||
)
|
||||
)
|
||||
logger.info("Added storage_files column to documents table")
|
||||
|
||||
# Create indexes for folder_name and end_user_id in system_metadata for documents
|
||||
await conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_system_metadata_folder_name
|
||||
ON documents ((system_metadata->>'folder_name'));
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
await conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_system_metadata_end_user_id
|
||||
ON documents ((system_metadata->>'end_user_id'));
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Check if system_metadata column exists in graphs table
|
||||
result = await conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'graphs' AND column_name = 'system_metadata'
|
||||
"""
|
||||
)
|
||||
)
|
||||
if not result.first():
|
||||
# Add system_metadata column to graphs table
|
||||
await conn.execute(
|
||||
text(
|
||||
"""
|
||||
ALTER TABLE graphs
|
||||
ADD COLUMN IF NOT EXISTS system_metadata JSONB DEFAULT '{}'::jsonb
|
||||
"""
|
||||
)
|
||||
)
|
||||
logger.info("Added system_metadata column to graphs table")
|
||||
|
||||
# Create indexes for folder_name and end_user_id in system_metadata for graphs
|
||||
await conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_graph_system_metadata_folder_name
|
||||
ON graphs ((system_metadata->>'folder_name'));
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
await conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_graph_system_metadata_end_user_id
|
||||
ON graphs ((system_metadata->>'end_user_id'));
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Created indexes for folder_name and end_user_id in system_metadata")
|
||||
|
||||
logger.info("PostgreSQL tables and indexes created successfully")
|
||||
self._initialized = True
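The expression indexes created above exist to keep scoped lookups cheap; a sketch of the kind of query they serve, assuming an async SQLAlchemy session like the one used elsewhere in this class:

```python
from sqlalchemy import text


# Hypothetical scoped lookup; the WHERE clause matches the expression used in
# idx_system_metadata_folder_name, so PostgreSQL can use the index directly.
async def count_documents_in_folder(session, folder_name: str) -> int:
    result = await session.execute(
        text(
            "SELECT COUNT(*) FROM documents "
            "WHERE system_metadata->>'folder_name' = :folder_name"
        ),
        {"folder_name": folder_name},
    )
    return result.scalar_one()
```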
|
||||
@ -221,24 +285,42 @@ class PostgresDatabase(BaseDatabase):
|
||||
logger.error(f"Error retrieving document metadata: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_document_by_filename(self, filename: str, auth: AuthContext) -> Optional[Document]:
|
||||
async def get_document_by_filename(self, filename: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> Optional[Document]:
|
||||
"""Retrieve document metadata by filename if user has access.
|
||||
If multiple documents have the same filename, returns the most recently updated one.
|
||||
|
||||
Args:
|
||||
filename: The filename to search for
|
||||
auth: Authentication context
|
||||
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
|
||||
"""
|
||||
try:
|
||||
async with self.async_session() as session:
|
||||
# Build access filter
|
||||
access_filter = self._build_access_filter(auth)
|
||||
system_metadata_filter = self._build_system_metadata_filter(system_filters)
|
||||
|
||||
# Query document
|
||||
# Construct where clauses
|
||||
where_clauses = [
|
||||
f"({access_filter})",
|
||||
f"filename = '{filename.replace('\'', '\'\'')}'" # Escape single quotes
|
||||
]
|
||||
|
||||
if system_metadata_filter:
|
||||
where_clauses.append(f"({system_metadata_filter})")
|
||||
|
||||
final_where_clause = " AND ".join(where_clauses)
|
||||
|
||||
# Query document with system filters
|
||||
query = (
|
||||
select(DocumentModel)
|
||||
.where(DocumentModel.filename == filename)
|
||||
.where(text(f"({access_filter})"))
|
||||
.where(text(final_where_clause))
|
||||
# Order by updated_at in system_metadata to get the most recent document
|
||||
.order_by(text("system_metadata->>'updated_at' DESC"))
|
||||
)
|
||||
|
||||
logger.debug(f"Querying document by filename with system filters: {system_filters}")
|
||||
|
||||
result = await session.execute(query)
|
||||
doc_model = result.scalar_one_or_none()
|
||||
|
||||
@ -264,14 +346,16 @@ class PostgresDatabase(BaseDatabase):
|
||||
logger.error(f"Error retrieving document metadata by filename: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext) -> List[Document]:
|
||||
async def get_documents_by_id(self, document_ids: List[str], auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> List[Document]:
|
||||
"""
|
||||
Retrieve multiple documents by their IDs in a single batch operation.
|
||||
Only returns documents the user has access to.
|
||||
Can filter by system metadata fields like folder_name and end_user_id.
|
||||
|
||||
Args:
|
||||
document_ids: List of document IDs to retrieve
|
||||
auth: Authentication context
|
||||
system_filters: Optional filters for system metadata fields
|
||||
|
||||
Returns:
|
||||
List of Document objects that were found and user has access to
|
||||
@ -283,13 +367,21 @@ class PostgresDatabase(BaseDatabase):
|
||||
async with self.async_session() as session:
|
||||
# Build access filter
|
||||
access_filter = self._build_access_filter(auth)
|
||||
system_metadata_filter = self._build_system_metadata_filter(system_filters)
|
||||
|
||||
# Query documents with both document IDs and access check in a single query
|
||||
query = (
|
||||
select(DocumentModel)
|
||||
.where(DocumentModel.external_id.in_(document_ids))
|
||||
.where(text(f"({access_filter})"))
|
||||
)
|
||||
# Construct where clauses
|
||||
where_clauses = [
|
||||
f"({access_filter})",
|
||||
f"external_id IN ({', '.join([f'\'{doc_id}\'' for doc_id in document_ids])})"
|
||||
]
|
||||
|
||||
if system_metadata_filter:
|
||||
where_clauses.append(f"({system_metadata_filter})")
|
||||
|
||||
final_where_clause = " AND ".join(where_clauses)
|
||||
|
||||
# Query documents with document IDs, access check, and system filters in a single query
|
||||
query = select(DocumentModel).where(text(final_where_clause))
|
||||
|
||||
logger.info(f"Batch retrieving {len(document_ids)} documents with a single query")
|
||||
|
||||
@ -328,6 +420,7 @@ class PostgresDatabase(BaseDatabase):
|
||||
skip: int = 0,
|
||||
limit: int = 10000,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
system_filters: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Document]:
|
||||
"""List documents the user has access to."""
|
||||
try:
|
||||
@ -335,10 +428,18 @@ class PostgresDatabase(BaseDatabase):
|
||||
# Build query
|
||||
access_filter = self._build_access_filter(auth)
|
||||
metadata_filter = self._build_metadata_filter(filters)
|
||||
system_metadata_filter = self._build_system_metadata_filter(system_filters)
|
||||
|
||||
query = select(DocumentModel).where(text(f"({access_filter})"))
|
||||
where_clauses = [f"({access_filter})"]
|
||||
|
||||
if metadata_filter:
|
||||
query = query.where(text(metadata_filter))
|
||||
where_clauses.append(f"({metadata_filter})")
|
||||
|
||||
if system_metadata_filter:
|
||||
where_clauses.append(f"({system_metadata_filter})")
|
||||
|
||||
final_where_clause = " AND ".join(where_clauses)
|
||||
query = select(DocumentModel).where(text(final_where_clause))
|
||||
|
||||
query = query.offset(skip).limit(limit)
|
||||
|
||||
@ -373,9 +474,23 @@ class PostgresDatabase(BaseDatabase):
|
||||
try:
|
||||
if not await self.check_access(document_id, auth, "write"):
|
||||
return False
|
||||
|
||||
# Get existing document to preserve system_metadata
|
||||
existing_doc = await self.get_document(document_id, auth)
|
||||
if not existing_doc:
|
||||
return False
|
||||
|
||||
# Update system metadata
|
||||
updates.setdefault("system_metadata", {})
|
||||
|
||||
# Preserve folder_name and end_user_id if not explicitly overridden
|
||||
if existing_doc.system_metadata:
|
||||
if "folder_name" in existing_doc.system_metadata and "folder_name" not in updates["system_metadata"]:
|
||||
updates["system_metadata"]["folder_name"] = existing_doc.system_metadata["folder_name"]
|
||||
|
||||
if "end_user_id" in existing_doc.system_metadata and "end_user_id" not in updates["system_metadata"]:
|
||||
updates["system_metadata"]["end_user_id"] = existing_doc.system_metadata["end_user_id"]
|
||||
|
||||
updates["system_metadata"]["updated_at"] = datetime.now(UTC)
|
||||
|
||||
# Serialize datetime objects to ISO format strings
|
||||
@ -421,7 +536,7 @@ class PostgresDatabase(BaseDatabase):
|
||||
return False
|
||||
|
||||
async def find_authorized_and_filtered_documents(
|
||||
self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None
|
||||
self, auth: AuthContext, filters: Optional[Dict[str, Any]] = None, system_filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[str]:
|
||||
"""Find document IDs matching filters and access permissions."""
|
||||
try:
|
||||
@ -429,14 +544,24 @@ class PostgresDatabase(BaseDatabase):
|
||||
# Build query
|
||||
access_filter = self._build_access_filter(auth)
|
||||
metadata_filter = self._build_metadata_filter(filters)
|
||||
system_metadata_filter = self._build_system_metadata_filter(system_filters)
|
||||
|
||||
logger.debug(f"Access filter: {access_filter}")
|
||||
logger.debug(f"Metadata filter: {metadata_filter}")
|
||||
logger.debug(f"System metadata filter: {system_metadata_filter}")
|
||||
logger.debug(f"Original filters: {filters}")
|
||||
logger.debug(f"System filters: {system_filters}")
|
||||
|
||||
query = select(DocumentModel.external_id).where(text(f"({access_filter})"))
|
||||
where_clauses = [f"({access_filter})"]
|
||||
|
||||
if metadata_filter:
|
||||
query = query.where(text(metadata_filter))
|
||||
where_clauses.append(f"({metadata_filter})")
|
||||
|
||||
if system_metadata_filter:
|
||||
where_clauses.append(f"({system_metadata_filter})")
|
||||
|
||||
final_where_clause = " AND ".join(where_clauses)
|
||||
query = select(DocumentModel.external_id).where(text(final_where_clause))
|
||||
|
||||
logger.debug(f"Final query: {query}")
|
||||
|
||||
@@ -525,6 +650,25 @@ class PostgresDatabase(BaseDatabase):
                filter_conditions.append(f"doc_metadata->>'{key}' = '{value}'")

        return " AND ".join(filter_conditions)

    def _build_system_metadata_filter(self, system_filters: Optional[Dict[str, Any]]) -> str:
        """Build PostgreSQL filter for system metadata."""
        if not system_filters:
            return ""

        conditions = []
        for key, value in system_filters.items():
            if value is None:
                continue

            if isinstance(value, str):
                # Replace single quotes with double single quotes to escape them
                escaped_value = value.replace("'", "''")
                conditions.append(f"system_metadata->>'{key}' = '{escaped_value}'")
            else:
                conditions.append(f"system_metadata->>'{key}' = '{value}'")

        return " AND ".join(conditions)

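For reference, a standalone sketch of what this helper produces. This is illustration only, and it is slightly simplified in that it escapes every value through `str()`, whereas the method above only escapes string values:

```python
from typing import Any, Dict, Optional

def build_system_metadata_filter(system_filters: Optional[Dict[str, Any]]) -> str:
    """Standalone sketch of the helper above, for illustration only."""
    if not system_filters:
        return ""
    conditions = []
    for key, value in system_filters.items():
        if value is None:
            continue
        escaped = str(value).replace("'", "''")  # double up single quotes for SQL
        conditions.append(f"system_metadata->>'{key}' = '{escaped}'")
    return " AND ".join(conditions)

print(build_system_metadata_filter({"folder_name": "reports", "end_user_id": "alice@example.com"}))
# system_metadata->>'folder_name' = 'reports' AND system_metadata->>'end_user_id' = 'alice@example.com'
```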
async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""Store metadata for a cache in PostgreSQL.
|
||||
@ -618,12 +762,13 @@ class PostgresDatabase(BaseDatabase):
|
||||
logger.error(f"Error storing graph: {str(e)}")
|
||||
return False
|
||||
|
||||
async def get_graph(self, name: str, auth: AuthContext) -> Optional[Graph]:
|
||||
async def get_graph(self, name: str, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> Optional[Graph]:
|
||||
"""Get a graph by name.
|
||||
|
||||
Args:
|
||||
name: Name of the graph
|
||||
auth: Authentication context
|
||||
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
|
||||
|
||||
Returns:
|
||||
Optional[Graph]: Graph if found and accessible, None otherwise
|
||||
@ -637,7 +782,8 @@ class PostgresDatabase(BaseDatabase):
|
||||
# Build access filter
|
||||
access_filter = self._build_access_filter(auth)
|
||||
|
||||
# Query graph
|
||||
# We need to check if the documents in the graph match the system filters
|
||||
# First get the graph without system filters
|
||||
query = (
|
||||
select(GraphModel)
|
||||
.where(GraphModel.name == name)
|
||||
@ -648,6 +794,32 @@ class PostgresDatabase(BaseDatabase):
|
||||
graph_model = result.scalar_one_or_none()
|
||||
|
||||
if graph_model:
|
||||
# If system filters are provided, we need to filter the document_ids
|
||||
document_ids = graph_model.document_ids
|
||||
|
||||
if system_filters and document_ids:
|
||||
# Apply system_filters to document_ids
|
||||
system_metadata_filter = self._build_system_metadata_filter(system_filters)
|
||||
|
||||
if system_metadata_filter:
|
||||
# Get document IDs with system filters
|
||||
doc_id_placeholders = ", ".join([f"'{doc_id}'" for doc_id in document_ids])
|
||||
filter_query = f"""
|
||||
SELECT external_id FROM documents
|
||||
WHERE external_id IN ({doc_id_placeholders})
|
||||
AND ({system_metadata_filter})
|
||||
"""
|
||||
|
||||
filter_result = await session.execute(text(filter_query))
|
||||
filtered_doc_ids = [row[0] for row in filter_result.all()]
|
||||
|
||||
# If no documents match system filters, return None
|
||||
if not filtered_doc_ids:
|
||||
return None
|
||||
|
||||
# Update document_ids with filtered results
|
||||
document_ids = filtered_doc_ids
|
||||
|
||||
# Convert to Graph model
|
||||
graph_dict = {
|
||||
"id": graph_model.id,
|
||||
@ -655,7 +827,8 @@ class PostgresDatabase(BaseDatabase):
|
||||
"entities": graph_model.entities,
|
||||
"relationships": graph_model.relationships,
|
||||
"metadata": graph_model.graph_metadata, # Reference the renamed column
|
||||
"document_ids": graph_model.document_ids,
|
||||
"system_metadata": graph_model.system_metadata or {}, # Include system_metadata
|
||||
"document_ids": document_ids, # Use possibly filtered document_ids
|
||||
"filters": graph_model.filters,
|
||||
"created_at": graph_model.created_at,
|
||||
"updated_at": graph_model.updated_at,
|
||||
@ -670,11 +843,12 @@ class PostgresDatabase(BaseDatabase):
|
||||
logger.error(f"Error retrieving graph: {str(e)}")
|
||||
return None
|
||||
|
||||
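The document-ID narrowing that `get_graph` performs when system filters are present boils down to one extra SELECT against the `documents` table. A rough sketch of the query it builds, with hypothetical IDs and filter values:

```python
# Hypothetical values; get_graph builds the same shape from the graph's document_ids
# and the output of _build_system_metadata_filter(system_filters).
document_ids = ["doc-1", "doc-2", "doc-3"]
system_metadata_filter = "system_metadata->>'folder_name' = 'reports'"

doc_id_placeholders = ", ".join(f"'{doc_id}'" for doc_id in document_ids)
filter_query = f"""
    SELECT external_id FROM documents
    WHERE external_id IN ({doc_id_placeholders})
    AND ({system_metadata_filter})
"""
print(filter_query)
```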
async def list_graphs(self, auth: AuthContext) -> List[Graph]:
|
||||
async def list_graphs(self, auth: AuthContext, system_filters: Optional[Dict[str, Any]] = None) -> List[Graph]:
|
||||
"""List all graphs the user has access to.
|
||||
|
||||
Args:
|
||||
auth: Authentication context
|
||||
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id)
|
||||
|
||||
Returns:
|
||||
List[Graph]: List of graphs
|
||||
@ -693,23 +867,66 @@ class PostgresDatabase(BaseDatabase):
|
||||
|
||||
result = await session.execute(query)
|
||||
graph_models = result.scalars().all()
|
||||
|
||||
return [
|
||||
Graph(
|
||||
id=graph.id,
|
||||
name=graph.name,
|
||||
entities=graph.entities,
|
||||
relationships=graph.relationships,
|
||||
metadata=graph.graph_metadata, # Reference the renamed column
|
||||
document_ids=graph.document_ids,
|
||||
filters=graph.filters,
|
||||
created_at=graph.created_at,
|
||||
updated_at=graph.updated_at,
|
||||
owner=graph.owner,
|
||||
access_control=graph.access_control,
|
||||
)
|
||||
for graph in graph_models
|
||||
]
|
||||
|
||||
graphs = []
|
||||
|
||||
# If system filters are provided, we need to filter each graph's document_ids
|
||||
if system_filters:
|
||||
system_metadata_filter = self._build_system_metadata_filter(system_filters)
|
||||
|
||||
for graph_model in graph_models:
|
||||
document_ids = graph_model.document_ids
|
||||
|
||||
if document_ids and system_metadata_filter:
|
||||
# Get document IDs with system filters
|
||||
doc_id_placeholders = ", ".join([f"'{doc_id}'" for doc_id in document_ids])
|
||||
filter_query = f"""
|
||||
SELECT external_id FROM documents
|
||||
WHERE external_id IN ({doc_id_placeholders})
|
||||
AND ({system_metadata_filter})
|
||||
"""
|
||||
|
||||
filter_result = await session.execute(text(filter_query))
|
||||
filtered_doc_ids = [row[0] for row in filter_result.all()]
|
||||
|
||||
# Only include graphs that have documents matching the system filters
|
||||
if filtered_doc_ids:
|
||||
graph = Graph(
|
||||
id=graph_model.id,
|
||||
name=graph_model.name,
|
||||
entities=graph_model.entities,
|
||||
relationships=graph_model.relationships,
|
||||
metadata=graph_model.graph_metadata, # Reference the renamed column
|
||||
system_metadata=graph_model.system_metadata or {}, # Include system_metadata
|
||||
document_ids=filtered_doc_ids, # Use filtered document_ids
|
||||
filters=graph_model.filters,
|
||||
created_at=graph_model.created_at,
|
||||
updated_at=graph_model.updated_at,
|
||||
owner=graph_model.owner,
|
||||
access_control=graph_model.access_control,
|
||||
)
|
||||
graphs.append(graph)
|
||||
else:
|
||||
# No system filters, include all graphs
|
||||
graphs = [
|
||||
Graph(
|
||||
id=graph.id,
|
||||
name=graph.name,
|
||||
entities=graph.entities,
|
||||
relationships=graph.relationships,
|
||||
metadata=graph.graph_metadata, # Reference the renamed column
|
||||
system_metadata=graph.system_metadata or {}, # Include system_metadata
|
||||
document_ids=graph.document_ids,
|
||||
filters=graph.filters,
|
||||
created_at=graph.created_at,
|
||||
updated_at=graph.updated_at,
|
||||
owner=graph.owner,
|
||||
access_control=graph.access_control,
|
||||
)
|
||||
for graph in graph_models
|
||||
]
|
||||
|
||||
return graphs
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing graphs: {str(e)}")
|
||||
|
@@ -28,3 +28,5 @@ class CompletionRequest(BaseModel):
    max_tokens: Optional[int] = 1000
    temperature: Optional[float] = 0.7
    prompt_template: Optional[str] = None
    folder_name: Optional[str] = None
    end_user_id: Optional[str] = None
@@ -27,7 +27,7 @@ class StorageFileInfo(BaseModel):


class Document(BaseModel):
    """Represents a document stored in MongoDB documents collection"""
    """Represents a document stored in the database documents collection"""

    external_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    owner: Dict[str, str]
@@ -44,6 +44,8 @@ class Document(BaseModel):
            "created_at": datetime.now(UTC),
            "updated_at": datetime.now(UTC),
            "version": 1,
            "folder_name": None,
            "end_user_id": None,
        }
    )
    """metadata such as creation date etc."""
@@ -50,6 +50,14 @@ class Graph(BaseModel):
    entities: List[Entity] = Field(default_factory=list)
    relationships: List[Relationship] = Field(default_factory=list)
    metadata: Dict[str, Any] = Field(default_factory=dict)
    system_metadata: Dict[str, Any] = Field(
        default_factory=lambda: {
            "created_at": datetime.now(UTC),
            "updated_at": datetime.now(UTC),
            "folder_name": None,
            "end_user_id": None,
        }
    )
    document_ids: List[str] = Field(default_factory=list)
    filters: Optional[Dict[str, Any]] = None
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
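With these defaults in place, a document or graph ingested under a folder and end-user scope carries the scope directly in `system_metadata`. An illustrative shape, not captured from a real record:

```python
from datetime import UTC, datetime

# Illustrative system_metadata for a scoped document; the folder and user values are placeholders.
system_metadata = {
    "created_at": datetime.now(UTC),
    "updated_at": datetime.now(UTC),
    "version": 1,
    "folder_name": "reports",            # filled in when ingestion is called with folder_name
    "end_user_id": "alice@example.com",  # filled in when ingestion is called with end_user_id
}
```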
@@ -23,6 +23,8 @@ class RetrieveRequest(BaseModel):
    include_paths: Optional[bool] = Field(
        False, description="Whether to include relationship paths in the response"
    )
    folder_name: Optional[str] = Field(None, description="Optional folder scope for the operation")
    end_user_id: Optional[str] = Field(None, description="Optional end-user scope for the operation")


class CompletionQueryRequest(RetrieveRequest):
@@ -44,6 +46,8 @@ class IngestTextRequest(BaseModel):
    metadata: Dict[str, Any] = Field(default_factory=dict)
    rules: List[Dict[str, Any]] = Field(default_factory=list)
    use_colpali: Optional[bool] = None
    folder_name: Optional[str] = Field(None, description="Optional folder scope for the operation")
    end_user_id: Optional[str] = Field(None, description="Optional end-user scope for the operation")


class CreateGraphRequest(BaseModel):
@@ -66,6 +70,8 @@ class CreateGraphRequest(BaseModel):
            }
        }}
    )
    folder_name: Optional[str] = Field(None, description="Optional folder scope for the operation")
    end_user_id: Optional[str] = Field(None, description="Optional end-user scope for the operation")


class UpdateGraphRequest(BaseModel):
@@ -81,6 +87,8 @@ class UpdateGraphRequest(BaseModel):
        None,
        description="Optional customizations for entity extraction and resolution prompts"
    )
    folder_name: Optional[str] = Field(None, description="Optional folder scope for the operation")
    end_user_id: Optional[str] = Field(None, description="Optional end-user scope for the operation")


class BatchIngestResponse(BaseModel):
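Since `RetrieveRequest` (and therefore `CompletionQueryRequest`) now accepts the two scope fields, an API caller can scope a retrieval or query by adding them to the request body. A minimal sketch using `httpx`; the base URL, token, and field values are placeholders, while the field names come from the models above:

```python
import httpx

# Placeholder URL and token; folder_name / end_user_id are the new scope fields.
response = httpx.post(
    "http://localhost:8000/retrieve/chunks",
    json={
        "query": "What companies does Elon Musk lead?",
        "folder_name": "test_folder",
        "end_user_id": "alice@example.com",
        "k": 4,
    },
    headers={"Authorization": "Bearer <token>"},
)
print(response.json())
```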
@@ -95,6 +95,8 @@ class DocumentService:
        min_score: float = 0.0,
        use_reranking: Optional[bool] = None,
        use_colpali: Optional[bool] = None,
        folder_name: Optional[str] = None,
        end_user_id: Optional[str] = None,
    ) -> List[ChunkResult]:
        """Retrieve relevant chunks."""
        settings = get_settings()
@@ -106,7 +108,14 @@ class DocumentService:
        logger.info("Generated query embedding")

        # Find authorized documents
        doc_ids = await self.db.find_authorized_and_filtered_documents(auth, filters)
        # Build system filters for folder_name and end_user_id
        system_filters = {}
        if folder_name:
            system_filters["folder_name"] = folder_name
        if end_user_id:
            system_filters["end_user_id"] = end_user_id

        doc_ids = await self.db.find_authorized_and_filtered_documents(auth, filters, system_filters)
        if not doc_ids:
            logger.info("No authorized documents found")
            return []
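Callers of `DocumentService.retrieve_chunks` thread the scope through as two optional keyword arguments; everything else is unchanged. A minimal sketch, assuming an already initialized `document_service` and an `AuthContext`:

```python
async def scoped_retrieval(document_service, auth):
    # Scope retrieval to one folder and one end user; regular metadata filters still apply.
    return await document_service.retrieve_chunks(
        query="quarterly revenue",
        auth=auth,
        filters={"department": "finance"},
        k=4,
        folder_name="reports",
        end_user_id="alice@example.com",
    )
```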
@ -194,11 +203,13 @@ class DocumentService:
|
||||
min_score: float = 0.0,
|
||||
use_reranking: Optional[bool] = None,
|
||||
use_colpali: Optional[bool] = None,
|
||||
folder_name: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
) -> List[DocumentResult]:
|
||||
"""Retrieve relevant documents."""
|
||||
# Get chunks first
|
||||
chunks = await self.retrieve_chunks(
|
||||
query, auth, filters, k, min_score, use_reranking, use_colpali
|
||||
query, auth, filters, k, min_score, use_reranking, use_colpali, folder_name, end_user_id
|
||||
)
|
||||
# Convert to document results
|
||||
results = await self._create_document_results(auth, chunks)
|
||||
@ -209,7 +220,9 @@ class DocumentService:
|
||||
async def batch_retrieve_documents(
|
||||
self,
|
||||
document_ids: List[str],
|
||||
auth: AuthContext
|
||||
auth: AuthContext,
|
||||
folder_name: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
Retrieve multiple documents by their IDs in a single batch operation.
|
||||
@ -224,15 +237,24 @@ class DocumentService:
|
||||
if not document_ids:
|
||||
return []
|
||||
|
||||
# Build system filters for folder_name and end_user_id
|
||||
system_filters = {}
|
||||
if folder_name:
|
||||
system_filters["folder_name"] = folder_name
|
||||
if end_user_id:
|
||||
system_filters["end_user_id"] = end_user_id
|
||||
|
||||
# Use the database's batch retrieval method
|
||||
documents = await self.db.get_documents_by_id(document_ids, auth)
|
||||
documents = await self.db.get_documents_by_id(document_ids, auth, system_filters)
|
||||
logger.info(f"Batch retrieved {len(documents)} documents out of {len(document_ids)} requested")
|
||||
return documents
|
||||
|
||||
async def batch_retrieve_chunks(
|
||||
self,
|
||||
chunk_ids: List[ChunkSource],
|
||||
auth: AuthContext
|
||||
auth: AuthContext,
|
||||
folder_name: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None
|
||||
) -> List[ChunkResult]:
|
||||
"""
|
||||
Retrieve specific chunks by their document ID and chunk number in a single batch operation.
|
||||
@ -251,7 +273,7 @@ class DocumentService:
|
||||
doc_ids = list({source.document_id for source in chunk_ids})
|
||||
|
||||
# Find authorized documents in a single query
|
||||
authorized_docs = await self.batch_retrieve_documents(doc_ids, auth)
|
||||
authorized_docs = await self.batch_retrieve_documents(doc_ids, auth, folder_name, end_user_id)
|
||||
authorized_doc_ids = {doc.external_id for doc in authorized_docs}
|
||||
|
||||
# Filter sources to only include authorized documents
|
||||
@ -292,6 +314,8 @@ class DocumentService:
|
||||
hop_depth: int = 1,
|
||||
include_paths: bool = False,
|
||||
prompt_overrides: Optional["QueryPromptOverrides"] = None,
|
||||
folder_name: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
) -> CompletionResponse:
|
||||
"""Generate completion using relevant chunks as context.
|
||||
|
||||
@ -329,11 +353,13 @@ class DocumentService:
|
||||
hop_depth=hop_depth,
|
||||
include_paths=include_paths,
|
||||
prompt_overrides=prompt_overrides,
|
||||
folder_name=folder_name,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
# Standard retrieval without graph
|
||||
chunks = await self.retrieve_chunks(
|
||||
query, auth, filters, k, min_score, use_reranking, use_colpali
|
||||
query, auth, filters, k, min_score, use_reranking, use_colpali, folder_name, end_user_id
|
||||
)
|
||||
documents = await self._create_document_results(auth, chunks)
|
||||
|
||||
@ -374,6 +400,8 @@ class DocumentService:
|
||||
auth: AuthContext = None,
|
||||
rules: Optional[List[str]] = None,
|
||||
use_colpali: Optional[bool] = None,
|
||||
folder_name: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
) -> Document:
|
||||
"""Ingest a text document."""
|
||||
if "write" not in auth.permissions:
|
||||
@ -396,6 +424,12 @@ class DocumentService:
|
||||
"user_id": [auth.user_id] if auth.user_id else [], # Add user_id to access control for filtering (as a list)
|
||||
},
|
||||
)
|
||||
|
||||
# Add folder_name and end_user_id to system_metadata if provided
|
||||
if folder_name:
|
||||
doc.system_metadata["folder_name"] = folder_name
|
||||
if end_user_id:
|
||||
doc.system_metadata["end_user_id"] = end_user_id
|
||||
logger.debug(f"Created text document record with ID {doc.external_id}")
|
||||
|
||||
if settings.MODE == "cloud" and auth.user_id:
|
||||
@ -459,6 +493,8 @@ class DocumentService:
|
||||
auth: AuthContext,
|
||||
rules: Optional[List[str]] = None,
|
||||
use_colpali: Optional[bool] = None,
|
||||
folder_name: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
) -> Document:
|
||||
"""Ingest a file document."""
|
||||
if "write" not in auth.permissions:
|
||||
@ -527,6 +563,12 @@ class DocumentService:
|
||||
},
|
||||
additional_metadata=additional_metadata,
|
||||
)
|
||||
|
||||
# Add folder_name and end_user_id to system_metadata if provided
|
||||
if folder_name:
|
||||
doc.system_metadata["folder_name"] = folder_name
|
||||
if end_user_id:
|
||||
doc.system_metadata["end_user_id"] = end_user_id
|
||||
|
||||
if settings.MODE == "cloud" and auth.user_id:
|
||||
# Check limits before proceeding with parsing
|
||||
@ -730,7 +772,13 @@ class DocumentService:
|
||||
chunks: List[Chunk],
|
||||
embeddings: List[List[float]],
|
||||
) -> List[DocumentChunk]:
|
||||
"""Helper to create chunk objects"""
|
||||
"""Helper to create chunk objects
|
||||
|
||||
Note: folder_name and end_user_id are not needed in chunk metadata because:
|
||||
1. Filtering by these values happens at the document level in find_authorized_and_filtered_documents
|
||||
2. Vector search is only performed on already authorized and filtered documents
|
||||
3. This approach is more efficient as it reduces the size of chunk metadata
|
||||
"""
|
||||
return [
|
||||
c.to_document_chunk(chunk_number=i, embedding=embedding, document_id=doc_id)
|
||||
for i, (embedding, c) in enumerate(zip(embeddings, chunks))
|
||||
@ -1341,6 +1389,7 @@ class DocumentService:
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
documents: Optional[List[str]] = None,
|
||||
prompt_overrides: Optional[GraphPromptOverrides] = None,
|
||||
system_filters: Optional[Dict[str, Any]] = None,
|
||||
) -> Graph:
|
||||
"""Create a graph from documents.
|
||||
|
||||
@ -1352,6 +1401,8 @@ class DocumentService:
|
||||
auth: Authentication context
|
||||
filters: Optional metadata filters to determine which documents to include
|
||||
documents: Optional list of specific document IDs to include
|
||||
prompt_overrides: Optional customizations for entity extraction and resolution prompts
|
||||
system_filters: Optional system filters like folder_name and end_user_id for scoping
|
||||
|
||||
Returns:
|
||||
Graph: The created graph
|
||||
@ -1364,6 +1415,7 @@ class DocumentService:
|
||||
filters=filters,
|
||||
documents=documents,
|
||||
prompt_overrides=prompt_overrides,
|
||||
system_filters=system_filters,
|
||||
)
|
||||
|
||||
async def update_graph(
|
||||
@ -1373,6 +1425,7 @@ class DocumentService:
|
||||
additional_filters: Optional[Dict[str, Any]] = None,
|
||||
additional_documents: Optional[List[str]] = None,
|
||||
prompt_overrides: Optional[GraphPromptOverrides] = None,
|
||||
system_filters: Optional[Dict[str, Any]] = None,
|
||||
) -> Graph:
|
||||
"""Update an existing graph with new documents.
|
||||
|
||||
@ -1384,6 +1437,8 @@ class DocumentService:
|
||||
auth: Authentication context
|
||||
additional_filters: Optional additional metadata filters to determine which new documents to include
|
||||
additional_documents: Optional list of additional document IDs to include
|
||||
prompt_overrides: Optional customizations for entity extraction and resolution prompts
|
||||
system_filters: Optional system filters like folder_name and end_user_id for scoping
|
||||
|
||||
Returns:
|
||||
Graph: The updated graph
|
||||
@ -1396,6 +1451,7 @@ class DocumentService:
|
||||
additional_filters=additional_filters,
|
||||
additional_documents=additional_documents,
|
||||
prompt_overrides=prompt_overrides,
|
||||
system_filters=system_filters,
|
||||
)
|
||||
|
||||
async def delete_document(self, document_id: str, auth: AuthContext) -> bool:
|
||||
|
@ -68,6 +68,7 @@ class GraphService:
|
||||
additional_filters: Optional[Dict[str, Any]] = None,
|
||||
additional_documents: Optional[List[str]] = None,
|
||||
prompt_overrides: Optional[GraphPromptOverrides] = None,
|
||||
system_filters: Optional[Dict[str, Any]] = None,
|
||||
) -> Graph:
|
||||
"""Update an existing graph with new documents.
|
||||
|
||||
@ -81,10 +82,15 @@ class GraphService:
|
||||
additional_filters: Optional additional metadata filters to determine which new documents to include
|
||||
additional_documents: Optional list of specific additional document IDs to include
|
||||
prompt_overrides: Optional GraphPromptOverrides with customizations for prompts
|
||||
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id) to determine which documents to include
|
||||
|
||||
Returns:
|
||||
Graph: The updated graph
|
||||
"""
|
||||
# Initialize system_filters if None
|
||||
if system_filters is None:
|
||||
system_filters = {}
|
||||
|
||||
if "write" not in auth.permissions:
|
||||
raise PermissionError("User does not have write permission")
|
||||
|
||||
@ -99,7 +105,7 @@ class GraphService:
|
||||
|
||||
# Find new documents to process
|
||||
document_ids = await self._get_new_document_ids(
|
||||
auth, existing_graph, additional_filters, additional_documents
|
||||
auth, existing_graph, additional_filters, additional_documents, system_filters
|
||||
)
|
||||
|
||||
if not document_ids and not explicit_doc_ids:
|
||||
@ -123,7 +129,7 @@ class GraphService:
|
||||
|
||||
# Batch retrieve all documents in a single call
|
||||
document_objects = await document_service.batch_retrieve_documents(
|
||||
all_ids_to_retrieve, auth
|
||||
all_ids_to_retrieve, auth, system_filters.get("folder_name", None), system_filters.get("end_user_id", None)
|
||||
)
|
||||
|
||||
# Process explicit documents if needed
|
||||
@ -150,6 +156,8 @@ class GraphService:
|
||||
if doc_id not in {d.external_id for d in document_objects}
|
||||
],
|
||||
auth,
|
||||
system_filters.get("folder_name", None),
|
||||
system_filters.get("end_user_id", None)
|
||||
)
|
||||
logger.info(f"Additional filtered documents to include: {len(filtered_docs)}")
|
||||
document_objects.extend(filtered_docs)
|
||||
@ -190,23 +198,28 @@ class GraphService:
|
||||
existing_graph: Graph,
|
||||
additional_filters: Optional[Dict[str, Any]] = None,
|
||||
additional_documents: Optional[List[str]] = None,
|
||||
system_filters: Optional[Dict[str, Any]] = None,
|
||||
) -> Set[str]:
|
||||
"""Get IDs of new documents to add to the graph."""
|
||||
# Initialize system_filters if None
|
||||
if system_filters is None:
|
||||
system_filters = {}
|
||||
# Initialize with explicitly specified documents, ensuring it's a set
|
||||
document_ids = set(additional_documents or [])
|
||||
|
||||
# Process documents matching additional filters
|
||||
if additional_filters:
|
||||
filtered_docs = await self.db.get_documents(auth, filters=additional_filters)
|
||||
if additional_filters or system_filters:
|
||||
filtered_docs = await self.db.get_documents(auth, filters=additional_filters, system_filters=system_filters)
|
||||
filter_doc_ids = {doc.external_id for doc in filtered_docs}
|
||||
logger.info(f"Found {len(filter_doc_ids)} documents matching additional filters")
|
||||
logger.info(f"Found {len(filter_doc_ids)} documents matching additional filters and system filters")
|
||||
document_ids.update(filter_doc_ids)
|
||||
|
||||
# Process documents matching the original filters
|
||||
if existing_graph.filters:
|
||||
filtered_docs = await self.db.get_documents(auth, filters=existing_graph.filters)
|
||||
# Original filters shouldn't include system filters, as we're applying them separately
|
||||
filtered_docs = await self.db.get_documents(auth, filters=existing_graph.filters, system_filters=system_filters)
|
||||
orig_filter_doc_ids = {doc.external_id for doc in filtered_docs}
|
||||
logger.info(f"Found {len(orig_filter_doc_ids)} documents matching original filters")
|
||||
logger.info(f"Found {len(orig_filter_doc_ids)} documents matching original filters and system filters")
|
||||
document_ids.update(orig_filter_doc_ids)
|
||||
|
||||
# Get only the document IDs that are not already in the graph
|
||||
@ -384,6 +397,7 @@ class GraphService:
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
documents: Optional[List[str]] = None,
|
||||
prompt_overrides: Optional[GraphPromptOverrides] = None,
|
||||
system_filters: Optional[Dict[str, Any]] = None,
|
||||
) -> Graph:
|
||||
"""Create a graph from documents.
|
||||
|
||||
@ -397,10 +411,15 @@ class GraphService:
|
||||
filters: Optional metadata filters to determine which documents to include
|
||||
documents: Optional list of specific document IDs to include
|
||||
prompt_overrides: Optional GraphPromptOverrides with customizations for prompts
|
||||
system_filters: Optional system metadata filters (e.g. folder_name, end_user_id) to determine which documents to include
|
||||
|
||||
Returns:
|
||||
Graph: The created graph
|
||||
"""
|
||||
# Initialize system_filters if None
|
||||
if system_filters is None:
|
||||
system_filters = {}
|
||||
|
||||
if "write" not in auth.permissions:
|
||||
raise PermissionError("User does not have write permission")
|
||||
|
||||
@ -408,15 +427,28 @@ class GraphService:
|
||||
document_ids = set(documents or [])
|
||||
|
||||
# If filters were provided, get matching documents
|
||||
if filters:
|
||||
filtered_docs = await self.db.get_documents(auth, filters=filters)
|
||||
if filters or system_filters:
|
||||
filtered_docs = await self.db.get_documents(auth, filters=filters, system_filters=system_filters)
|
||||
document_ids.update(doc.external_id for doc in filtered_docs)
|
||||
|
||||
if not document_ids:
|
||||
raise ValueError("No documents found matching criteria")
|
||||
|
||||
# Convert system_filters for document retrieval
|
||||
folder_name = system_filters.get("folder_name") if system_filters else None
|
||||
end_user_id = system_filters.get("end_user_id") if system_filters else None
|
||||
|
||||
# Batch retrieve documents for authorization check
|
||||
document_objects = await document_service.batch_retrieve_documents(list(document_ids), auth)
|
||||
document_objects = await document_service.batch_retrieve_documents(
|
||||
list(document_ids),
|
||||
auth,
|
||||
folder_name,
|
||||
end_user_id
|
||||
)
|
||||
|
||||
# Log for debugging
|
||||
logger.info(f"Graph creation with folder_name={folder_name}, end_user_id={end_user_id}")
|
||||
logger.info(f"Documents retrieved: {len(document_objects)} out of {len(document_ids)} requested")
|
||||
if not document_objects:
|
||||
raise ValueError("No authorized documents found matching criteria")
|
||||
|
||||
@ -434,6 +466,13 @@ class GraphService:
|
||||
"admins": [auth.entity_id],
|
||||
},
|
||||
)
|
||||
|
||||
# Add folder_name and end_user_id to system_metadata if provided
|
||||
if system_filters:
|
||||
if "folder_name" in system_filters:
|
||||
graph.system_metadata["folder_name"] = system_filters["folder_name"]
|
||||
if "end_user_id" in system_filters:
|
||||
graph.system_metadata["end_user_id"] = system_filters["end_user_id"]
|
||||
|
||||
# Extract entities and relationships
|
||||
entities, relationships = await self._process_documents_for_entities(
|
||||
@ -868,6 +907,8 @@ class GraphService:
|
||||
hop_depth: int = 1,
|
||||
include_paths: bool = False,
|
||||
prompt_overrides: Optional[QueryPromptOverrides] = None,
|
||||
folder_name: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
) -> CompletionResponse:
|
||||
"""Generate completion using knowledge graph-enhanced retrieval.
|
||||
|
||||
@ -899,8 +940,15 @@ class GraphService:
|
||||
|
||||
# Validation is now handled by type annotations
|
||||
|
||||
# Get the knowledge graph
|
||||
graph = await self.db.get_graph(graph_name, auth)
|
||||
# Build system filters for scoping
|
||||
system_filters = {}
|
||||
if folder_name:
|
||||
system_filters["folder_name"] = folder_name
|
||||
if end_user_id:
|
||||
system_filters["end_user_id"] = end_user_id
|
||||
|
||||
logger.info(f"Querying graph with system_filters: {system_filters}")
|
||||
graph = await self.db.get_graph(graph_name, auth, system_filters=system_filters)
|
||||
if not graph:
|
||||
logger.warning(f"Graph '{graph_name}' not found or not accessible")
|
||||
# Fall back to standard retrieval if graph not found
|
||||
@ -915,12 +963,14 @@ class GraphService:
|
||||
use_reranking=use_reranking,
|
||||
use_colpali=use_colpali,
|
||||
graph_name=None,
|
||||
folder_name=folder_name,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
|
||||
# Parallel approach
|
||||
# 1. Standard vector search
|
||||
vector_chunks = await document_service.retrieve_chunks(
|
||||
query, auth, filters, k, min_score, use_reranking, use_colpali
|
||||
query, auth, filters, k, min_score, use_reranking, use_colpali, folder_name, end_user_id
|
||||
)
|
||||
logger.info(f"Vector search retrieved {len(vector_chunks)} chunks")
|
||||
|
||||
@ -990,7 +1040,7 @@ class GraphService:
|
||||
|
||||
# Get specific chunks containing these entities
|
||||
graph_chunks = await self._retrieve_entity_chunks(
|
||||
expanded_entities, auth, filters, document_service
|
||||
expanded_entities, auth, filters, document_service, folder_name, end_user_id
|
||||
)
|
||||
logger.info(f"Retrieved {len(graph_chunks)} chunks containing relevant entities")
|
||||
|
||||
@ -1015,6 +1065,8 @@ class GraphService:
|
||||
auth,
|
||||
graph_name,
|
||||
prompt_overrides,
|
||||
folder_name=folder_name,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
|
||||
return completion_response
|
||||
@ -1143,8 +1195,13 @@ class GraphService:
|
||||
auth: AuthContext,
|
||||
filters: Optional[Dict[str, Any]],
|
||||
document_service,
|
||||
folder_name: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
) -> List[ChunkResult]:
|
||||
"""Retrieve chunks containing the specified entities."""
|
||||
# Initialize filters if None
|
||||
if filters is None:
|
||||
filters = {}
|
||||
if not entities:
|
||||
return []
|
||||
|
||||
@ -1158,9 +1215,9 @@ class GraphService:
|
||||
|
||||
# Get unique document IDs for authorization check
|
||||
doc_ids = {doc_id for doc_id, _ in entity_chunk_sources}
|
||||
|
||||
# Check document authorization
|
||||
documents = await document_service.batch_retrieve_documents(list(doc_ids), auth)
|
||||
|
||||
# Check document authorization with system filters
|
||||
documents = await document_service.batch_retrieve_documents(list(doc_ids), auth, folder_name, end_user_id)
|
||||
|
||||
# Apply filters if needed
|
||||
authorized_doc_ids = {
|
||||
@ -1178,7 +1235,7 @@ class GraphService:
|
||||
|
||||
# Retrieve and return chunks if we have any valid sources
|
||||
return (
|
||||
await document_service.batch_retrieve_chunks(chunk_sources, auth)
|
||||
await document_service.batch_retrieve_chunks(chunk_sources, auth, folder_name=folder_name, end_user_id=end_user_id)
|
||||
if chunk_sources
|
||||
else []
|
||||
)
|
||||
@ -1198,7 +1255,7 @@ class GraphService:
|
||||
chunk.score = min(1.0, (getattr(chunk, "score", 0.7) or 0.7) * 1.05)
|
||||
|
||||
# Keep the higher-scored version
|
||||
if chunk_key not in all_chunks or chunk.score > all_chunks[chunk_key].score:
|
||||
if chunk_key not in all_chunks or chunk.score > (getattr(all_chunks.get(chunk_key), "score", 0) or 0):
|
||||
all_chunks[chunk_key] = chunk
|
||||
|
||||
# Convert to list, sort by score, and return top k
|
||||
@ -1330,6 +1387,8 @@ class GraphService:
|
||||
auth: Optional[AuthContext] = None,
|
||||
graph_name: Optional[str] = None,
|
||||
prompt_overrides: Optional[QueryPromptOverrides] = None,
|
||||
folder_name: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
) -> CompletionResponse:
|
||||
"""Generate completion using the retrieved chunks and optional path information."""
|
||||
if not chunks:
|
||||
@ -1370,6 +1429,8 @@ class GraphService:
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
prompt_template=custom_prompt_template,
|
||||
folder_name=folder_name,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
|
||||
# Get completion from model
|
||||
@ -1387,6 +1448,7 @@ class GraphService:
|
||||
|
||||
# Include graph metadata if paths were requested
|
||||
if include_paths:
|
||||
# Initialize metadata if it doesn't exist
|
||||
if not hasattr(response, "metadata") or response.metadata is None:
|
||||
response.metadata = {}
|
||||
|
||||
|
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
||||
TEST_DATA_DIR = Path(__file__).parent / "test_data"
|
||||
JWT_SECRET = "your-secret-key-for-signing-tokens"
|
||||
TEST_USER_ID = "test_user"
|
||||
TEST_POSTGRES_URI = "postgresql+asyncpg://postgres:postgres@localhost:5432/morphik_test"
|
||||
TEST_POSTGRES_URI = "postgresql+asyncpg://morphik@localhost:5432/morphik_test"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -261,6 +261,41 @@ async def test_ingest_text_document_with_metadata(client: AsyncClient, content:
|
||||
return data["external_id"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_text_document_folder_user(
|
||||
client: AsyncClient,
|
||||
content: str = "Test content for document ingestion with folder and user scoping",
|
||||
metadata: dict = None,
|
||||
folder_name: str = "test_folder",
|
||||
end_user_id: str = "test_user@example.com"
|
||||
):
|
||||
"""Test ingesting a text document with folder and user scoping"""
|
||||
headers = create_auth_header()
|
||||
|
||||
response = await client.post(
|
||||
"/ingest/text",
|
||||
json={
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": end_user_id
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "external_id" in data
|
||||
assert data["content_type"] == "text/plain"
|
||||
assert data["system_metadata"]["folder_name"] == folder_name
|
||||
assert data["system_metadata"]["end_user_id"] == end_user_id
|
||||
|
||||
for key, value in (metadata or {}).items():
|
||||
assert data["metadata"][key] == value
|
||||
|
||||
return data["external_id"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_pdf(client: AsyncClient):
|
||||
"""Test ingesting a pdf"""
|
||||
@ -1507,6 +1542,196 @@ async def test_query_with_graph(client: AsyncClient):
|
||||
assert response_no_graph.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_with_folder_and_user_scope(client: AsyncClient):
|
||||
"""Test knowledge graph with folder and user scoping."""
|
||||
headers = create_auth_header()
|
||||
|
||||
# Test folder
|
||||
folder_name = "test_graph_folder"
|
||||
|
||||
# Test user
|
||||
user_id = "graph_test_user@example.com"
|
||||
|
||||
# Ingest documents into folder with user scope using our helper function
|
||||
doc_id1 = await test_ingest_text_document_folder_user(
|
||||
client,
|
||||
content="Tesla is an electric vehicle manufacturer. Elon Musk is the CEO of Tesla.",
|
||||
metadata={"graph_scope_test": True},
|
||||
folder_name=folder_name,
|
||||
end_user_id=user_id
|
||||
)
|
||||
|
||||
doc_id2 = await test_ingest_text_document_folder_user(
|
||||
client,
|
||||
content="SpaceX develops spacecraft and rockets. Elon Musk is also the CEO of SpaceX.",
|
||||
metadata={"graph_scope_test": True},
|
||||
folder_name=folder_name,
|
||||
end_user_id=user_id
|
||||
)
|
||||
|
||||
# Also ingest a document outside the folder/user scope
|
||||
_ = await test_ingest_text_document_with_metadata(
|
||||
client,
|
||||
content="Elon Musk also founded Neuralink, a neurotechnology company.",
|
||||
metadata={"graph_scope_test": True}
|
||||
)
|
||||
|
||||
# Create a graph with folder and user scope
|
||||
graph_name = "test_scoped_graph"
|
||||
response = await client.post(
|
||||
"/graph/create",
|
||||
json={
|
||||
"name": graph_name,
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": user_id
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
graph = response.json()
|
||||
|
||||
# Verify graph was created with proper scoping
|
||||
assert graph["name"] == graph_name
|
||||
assert len(graph["document_ids"]) == 2
|
||||
assert all(doc_id in graph["document_ids"] for doc_id in [doc_id1, doc_id2])
|
||||
|
||||
# Verify we have the expected entities
|
||||
entity_labels = [entity["label"].lower() for entity in graph["entities"]]
|
||||
assert any("tesla" in label for label in entity_labels)
|
||||
assert any("spacex" in label for label in entity_labels)
|
||||
assert any("elon musk" in label for label in entity_labels)
|
||||
|
||||
# First, let's check the retrieved chunks directly to verify scope is working
|
||||
retrieve_response = await client.post(
|
||||
"/retrieve/chunks",
|
||||
json={
|
||||
"query": "What companies does Elon Musk lead?",
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": user_id
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert retrieve_response.status_code == 200
|
||||
retrieved_chunks = retrieve_response.json()
|
||||
|
||||
# Verify that none of the retrieved chunks contain "Neuralink"
|
||||
for chunk in retrieved_chunks:
|
||||
assert "neuralink" not in chunk["content"].lower()
|
||||
|
||||
# First try querying without a graph to see if RAG works with just folder/user scope
|
||||
response_no_graph = await client.post(
|
||||
"/query",
|
||||
json={
|
||||
"query": "What companies does Elon Musk lead?",
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": user_id
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response_no_graph.status_code == 200
|
||||
result_no_graph = response_no_graph.json()
|
||||
|
||||
# Verify the completion has the expected content
|
||||
completion_no_graph = result_no_graph["completion"].lower()
|
||||
print("Completion without graph:")
|
||||
print(completion_no_graph)
|
||||
assert "tesla" in completion_no_graph
|
||||
assert "spacex" in completion_no_graph
|
||||
assert "neuralink" not in completion_no_graph
|
||||
|
||||
# Now test querying with graph and folder/user scope
|
||||
response = await client.post(
|
||||
"/query",
|
||||
json={
|
||||
"query": "What companies does Elon Musk lead?",
|
||||
"graph_name": graph_name,
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": user_id
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
|
||||
# Log source chunks and graph information used for completion
|
||||
print("\nSource chunks for graph-based completion:")
|
||||
for source in result["sources"]:
|
||||
print(f"Document ID: {source['document_id']}, Chunk: {source['chunk_number']}")
|
||||
|
||||
# Check if there's graph metadata in the response
|
||||
if result.get("metadata") and "graph" in result.get("metadata", {}):
|
||||
print("\nGraph metadata used:")
|
||||
print(result["metadata"]["graph"])
|
||||
|
||||
# Verify the completion has the expected content
|
||||
completion = result["completion"].lower()
|
||||
print("\nCompletion with graph:")
|
||||
print(completion)
|
||||
assert "tesla" in completion
|
||||
assert "spacex" in completion
|
||||
|
||||
# Verify Neuralink isn't included (it was outside folder/user scope)
|
||||
assert "neuralink" not in completion
|
||||
|
||||
# Test updating the graph with folder and user scope
|
||||
doc_id3 = await test_ingest_text_document_folder_user(
|
||||
client,
|
||||
content="The Boring Company was founded by Elon Musk in 2016.",
|
||||
metadata={"graph_scope_test": True},
|
||||
folder_name=folder_name,
|
||||
end_user_id=user_id
|
||||
)
|
||||
|
||||
# Update the graph
|
||||
update_response = await client.post(
|
||||
f"/graph/{graph_name}/update",
|
||||
json={
|
||||
"additional_documents": [doc_id3],
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": user_id
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert update_response.status_code == 200
|
||||
updated_graph = update_response.json()
|
||||
|
||||
# Verify graph was updated
|
||||
assert updated_graph["name"] == graph_name
|
||||
assert len(updated_graph["document_ids"]) == 3
|
||||
assert all(doc_id in updated_graph["document_ids"] for doc_id in [doc_id1, doc_id2, doc_id3])
|
||||
|
||||
# Verify new entity was added
|
||||
updated_entity_labels = [entity["label"].lower() for entity in updated_graph["entities"]]
|
||||
assert any("boring company" in label for label in updated_entity_labels)
|
||||
|
||||
# Test querying with updated graph
|
||||
response = await client.post(
|
||||
"/query",
|
||||
json={
|
||||
"query": "List all companies founded or led by Elon Musk",
|
||||
"graph_name": graph_name,
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": user_id
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
updated_result = response.json()
|
||||
|
||||
# Verify the completion includes the new company
|
||||
updated_completion = updated_result["completion"].lower()
|
||||
assert "tesla" in updated_completion
|
||||
assert "spacex" in updated_completion
|
||||
assert "boring company" in updated_completion
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_ingest_with_shared_metadata(
|
||||
client: AsyncClient
|
||||
@ -1804,6 +2029,429 @@ async def test_batch_ingest_sequential_vs_parallel(
|
||||
assert len(result["errors"]) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_folder_scoping(client: AsyncClient):
|
||||
"""Test document operations with folder scoping."""
|
||||
headers = create_auth_header()
|
||||
|
||||
# Test folder 1
|
||||
folder1_name = "test_folder_1"
|
||||
folder1_content = "This is content in test folder 1."
|
||||
|
||||
# Test folder 2
|
||||
folder2_name = "test_folder_2"
|
||||
folder2_content = "This is different content in test folder 2."
|
||||
|
||||
# Ingest document into folder 1 using our helper function
|
||||
doc1_id = await test_ingest_text_document_folder_user(
|
||||
client,
|
||||
content=folder1_content,
|
||||
metadata={"folder_test": True},
|
||||
folder_name=folder1_name,
|
||||
end_user_id=None
|
||||
)
|
||||
|
||||
# Get the document to verify
|
||||
response = await client.get(f"/documents/{doc1_id}", headers=headers)
|
||||
assert response.status_code == 200
|
||||
doc1 = response.json()
|
||||
assert doc1["system_metadata"]["folder_name"] == folder1_name
|
||||
|
||||
# Ingest document into folder 2 using our helper function
|
||||
doc2_id = await test_ingest_text_document_folder_user(
|
||||
client,
|
||||
content=folder2_content,
|
||||
metadata={"folder_test": True},
|
||||
folder_name=folder2_name,
|
||||
end_user_id=None
|
||||
)
|
||||
|
||||
# Get the document to verify
|
||||
response = await client.get(f"/documents/{doc2_id}", headers=headers)
|
||||
assert response.status_code == 200
|
||||
doc2 = response.json()
|
||||
assert doc2["system_metadata"]["folder_name"] == folder2_name
|
||||
|
||||
# Verify we can get documents by folder
|
||||
response = await client.post(
|
||||
"/documents",
|
||||
json={"folder_test": True},
|
||||
headers=headers,
|
||||
params={"folder_name": folder1_name}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
folder1_docs = response.json()
|
||||
assert len(folder1_docs) == 1
|
||||
assert folder1_docs[0]["external_id"] == doc1_id
|
||||
|
||||
# Verify other folder's document isn't in results
|
||||
assert not any(doc["external_id"] == doc2_id for doc in folder1_docs)
|
||||
|
||||
# Test querying with folder scope
|
||||
response = await client.post(
|
||||
"/query",
|
||||
json={
|
||||
"query": "What folder is this content in?",
|
||||
"folder_name": folder1_name
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
assert "completion" in result
|
||||
|
||||
# Test folder-specific chunk retrieval
|
||||
response = await client.post(
|
||||
"/retrieve/chunks",
|
||||
json={
|
||||
"query": "folder content",
|
||||
"folder_name": folder2_name
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
chunks = response.json()
|
||||
assert len(chunks) > 0
|
||||
assert folder2_content in chunks[0]["content"]
|
||||
|
||||
# Test document update with folder preservation
|
||||
updated_content = "This is updated content in test folder 1."
|
||||
response = await client.post(
|
||||
f"/documents/{doc1_id}/update_text",
|
||||
json={
|
||||
"content": updated_content,
|
||||
"metadata": {"updated": True},
|
||||
"folder_name": folder1_name # This should match original folder
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
updated_doc = response.json()
|
||||
assert updated_doc["system_metadata"]["folder_name"] == folder1_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_scoping(client: AsyncClient):
|
||||
"""Test document operations with end-user scoping."""
|
||||
headers = create_auth_header()
|
||||
|
||||
# Test user 1
|
||||
user1_id = "test_user_1@example.com"
|
||||
user1_content = "This is content created by test user 1."
|
||||
|
||||
# Test user 2
|
||||
user2_id = "test_user_2@example.com"
|
||||
user2_content = "This is different content created by test user 2."
|
||||
|
||||
# Ingest document for user 1 using our helper function
|
||||
doc1_id = await test_ingest_text_document_folder_user(
|
||||
client,
|
||||
content=user1_content,
|
||||
metadata={"user_test": True},
|
||||
folder_name=None,
|
||||
end_user_id=user1_id
|
||||
)
|
||||
|
||||
# Get the document to verify
|
||||
response = await client.get(f"/documents/{doc1_id}", headers=headers)
|
||||
assert response.status_code == 200
|
||||
doc1 = response.json()
|
||||
assert doc1["system_metadata"]["end_user_id"] == user1_id
|
||||
|
||||
# Ingest document for user 2 using our helper function
|
||||
doc2_id = await test_ingest_text_document_folder_user(
|
||||
client,
|
||||
content=user2_content,
|
||||
metadata={"user_test": True},
|
||||
folder_name=None,
|
||||
end_user_id=user2_id
|
||||
)
|
||||
|
||||
# Get the document to verify
|
||||
response = await client.get(f"/documents/{doc2_id}", headers=headers)
|
||||
assert response.status_code == 200
|
||||
doc2 = response.json()
|
||||
assert doc2["system_metadata"]["end_user_id"] == user2_id
|
||||
|
||||
# Verify we can get documents by user
|
||||
response = await client.post(
|
||||
"/documents",
|
||||
json={"user_test": True},
|
||||
headers=headers,
|
||||
params={"end_user_id": user1_id}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
user1_docs = response.json()
|
||||
assert len(user1_docs) == 1
|
||||
assert user1_docs[0]["external_id"] == doc1_id
|
||||
|
||||
# Verify other user's document isn't in results
|
||||
assert not any(doc["external_id"] == doc2_id for doc in user1_docs)
|
||||
|
||||
# Test querying with user scope
|
||||
response = await client.post(
|
||||
"/query",
|
||||
json={
|
||||
"query": "What is my content?",
|
||||
"end_user_id": user1_id
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
assert "completion" in result
|
||||
|
||||
# Test updating document with user preservation
|
||||
updated_content = "This is updated content by test user 1."
|
||||
response = await client.post(
|
||||
f"/documents/{doc1_id}/update_text",
|
||||
json={
|
||||
"content": updated_content,
|
||||
"metadata": {"updated": True},
|
||||
"end_user_id": user1_id # Should preserve the user
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
updated_doc = response.json()
|
||||
assert updated_doc["system_metadata"]["end_user_id"] == user1_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combined_folder_and_user_scoping(client: AsyncClient):
|
||||
"""Test document operations with combined folder and user scoping."""
|
||||
headers = create_auth_header()
|
||||
|
||||
# Test folder
|
||||
folder_name = "test_combined_folder"
|
||||
|
||||
# Test users
|
||||
user1_id = "combined_test_user_1@example.com"
|
||||
user2_id = "combined_test_user_2@example.com"
|
||||
|
||||
# Ingest document for user 1 in folder using our new helper function
|
||||
user1_content = "This is content by user 1 in the combined test folder."
|
||||
doc1_id = await test_ingest_text_document_folder_user(
|
||||
client,
|
||||
content=user1_content,
|
||||
metadata={"combined_test": True},
|
||||
folder_name=folder_name,
|
||||
end_user_id=user1_id
|
||||
)
|
||||
|
||||
# Get the document to verify
|
||||
response = await client.get(f"/documents/{doc1_id}", headers=headers)
|
||||
assert response.status_code == 200
|
||||
doc1 = response.json()
|
||||
assert doc1["system_metadata"]["folder_name"] == folder_name
|
||||
assert doc1["system_metadata"]["end_user_id"] == user1_id
|
||||
|
||||
# Ingest document for user 2 in folder using our new helper function
|
||||
user2_content = "This is content by user 2 in the combined test folder."
|
||||
doc2_id = await test_ingest_text_document_folder_user(
|
||||
client,
|
||||
content=user2_content,
|
||||
metadata={"combined_test": True},
|
||||
folder_name=folder_name,
|
||||
end_user_id=user2_id
|
||||
)
|
||||
|
||||
# Get the document to verify
|
||||
response = await client.get(f"/documents/{doc2_id}", headers=headers)
|
||||
assert response.status_code == 200
|
||||
doc2 = response.json()
|
||||
assert doc2["system_metadata"]["folder_name"] == folder_name
|
||||
assert doc2["system_metadata"]["end_user_id"] == user2_id
|
||||
|
||||
# Get all documents in folder
|
||||
response = await client.post(
|
||||
"/documents",
|
||||
json={"combined_test": True},
|
||||
headers=headers,
|
||||
params={"folder_name": folder_name}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
folder_docs = response.json()
|
||||
assert len(folder_docs) == 2
|
||||
|
||||
# Get user 1's documents in the folder
|
||||
response = await client.post(
|
||||
"/documents",
|
||||
json={"combined_test": True},
|
||||
headers=headers,
|
||||
params={"folder_name": folder_name, "end_user_id": user1_id}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
user1_folder_docs = response.json()
|
||||
assert len(user1_folder_docs) == 1
|
||||
assert user1_folder_docs[0]["external_id"] == doc1_id
|
||||
|
||||
# Test querying with combined scope
|
||||
response = await client.post(
|
||||
"/query",
|
||||
json={
|
||||
"query": "What is in this folder for this user?",
|
||||
"folder_name": folder_name,
|
||||
"end_user_id": user2_id
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
    assert "completion" in result

    # Test retrieving chunks with combined scope
    response = await client.post(
        "/retrieve/chunks",
        json={
            "query": "combined test folder",
            "folder_name": folder_name,
            "end_user_id": user1_id
        },
        headers=headers,
    )

    assert response.status_code == 200
    chunks = response.json()
    assert len(chunks) > 0
    # Should only have user 1's content
    assert any(user1_content in chunk["content"] for chunk in chunks)
    assert not any(user2_content in chunk["content"] for chunk in chunks)


@pytest.mark.asyncio
async def test_system_metadata_filter_behavior(client: AsyncClient):
    """Test detailed behavior of system_metadata filtering."""
    headers = create_auth_header()

    # Create documents with different system metadata combinations

    # Document with folder only
    folder_only_content = "This document has only folder in system metadata."
    folder_only_id = await test_ingest_text_document_folder_user(
        client,
        content=folder_only_content,
        metadata={"filter_test": True},
        folder_name="test_filter_folder",
        end_user_id=None  # Only folder, no user
    )

    # Get the document to verify
    response = await client.get(f"/documents/{folder_only_id}", headers=headers)
    assert response.status_code == 200
    folder_only_doc = response.json()

    # Document with user only
    user_only_content = "This document has only user in system metadata."
    user_only_id = await test_ingest_text_document_folder_user(
        client,
        content=user_only_content,
        metadata={"filter_test": True},
        folder_name=None,  # No folder, only user
        end_user_id="test_filter_user@example.com"
    )

    # Get the document to verify
    response = await client.get(f"/documents/{user_only_id}", headers=headers)
    assert response.status_code == 200
    user_only_doc = response.json()

    # Document with both folder and user
    combined_content = "This document has both folder and user in system metadata."
    combined_id = await test_ingest_text_document_folder_user(
        client,
        content=combined_content,
        metadata={"filter_test": True},
        folder_name="test_filter_folder",
        end_user_id="test_filter_user@example.com"
    )

    # Get the document to verify
    response = await client.get(f"/documents/{combined_id}", headers=headers)
    assert response.status_code == 200
    combined_doc = response.json()

    # Test queries with different filter combinations

    # Filter by folder only
    response = await client.post(
        "/documents",
        json={"filter_test": True},
        headers=headers,
        params={"folder_name": "test_filter_folder"}
    )

    assert response.status_code == 200
    folder_filtered_docs = response.json()
    folder_doc_ids = [doc["external_id"] for doc in folder_filtered_docs]
    assert folder_only_id in folder_doc_ids
    assert combined_id in folder_doc_ids
    assert user_only_id not in folder_doc_ids

    # Filter by user only
    response = await client.post(
        "/documents",
        json={"filter_test": True},
        headers=headers,
        params={"end_user_id": "test_filter_user@example.com"}
    )

    assert response.status_code == 200
    user_filtered_docs = response.json()
    user_doc_ids = [doc["external_id"] for doc in user_filtered_docs]
    assert user_only_id in user_doc_ids
    assert combined_id in user_doc_ids
    assert folder_only_id not in user_doc_ids

    # Filter by both folder and user
    response = await client.post(
        "/documents",
        json={"filter_test": True},
        headers=headers,
        params={
            "folder_name": "test_filter_folder",
            "end_user_id": "test_filter_user@example.com"
        }
    )

    assert response.status_code == 200
    combined_filtered_docs = response.json()
    combined_doc_ids = [doc["external_id"] for doc in combined_filtered_docs]
    assert len(combined_filtered_docs) == 1
    assert combined_id in combined_doc_ids
    assert folder_only_id not in combined_doc_ids
    assert user_only_id not in combined_doc_ids

    # Test with chunk retrieval
    response = await client.post(
        "/retrieve/chunks",
        json={
            "query": "system metadata",
            "folder_name": "test_filter_folder",
            "end_user_id": "test_filter_user@example.com"
        },
        headers=headers,
    )

    assert response.status_code == 200
    chunks = response.json()
    assert len(chunks) > 0
    # Should only have the combined document content
    assert any(combined_content in chunk["content"] for chunk in chunks)
    assert not any(folder_only_content in chunk["content"] for chunk in chunks)
    assert not any(user_only_content in chunk["content"] for chunk in chunks)
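
# The assertions above rely on the API combining folder and end-user scope into a
# single system-metadata filter. A minimal sketch of how such a filter could be
# composed (an assumption for illustration; the actual server-side query building
# may differ):
#
# from typing import Any, Dict, Optional
#
# def build_system_metadata_filter(
#     folder_name: Optional[str], end_user_id: Optional[str]
# ) -> Dict[str, Any]:
#     """Combine optional folder and end-user scope into one filter dict."""
#     system_filter: Dict[str, Any] = {}
#     if folder_name:
#         system_filter["folder_name"] = folder_name
#     if end_user_id:
#         system_filter["end_user_id"] = end_user_id
#     return system_filter
#
# # With both values set, only documents carrying both keys match, which is why the
# # combined query above returns exactly one document.
# build_system_metadata_filter("test_filter_folder", "test_filter_user@example.com")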

@pytest.mark.asyncio
async def test_delete_document(client: AsyncClient):
    """Test deleting a document and verifying it's gone."""
@ -1,183 +0,0 @@
from typing import List, Optional, Tuple
import logging
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo.errors import PyMongoError

from .base_vector_store import BaseVectorStore
from core.models.chunk import DocumentChunk

logger = logging.getLogger(__name__)


class MongoDBAtlasVectorStore(BaseVectorStore):
    """MongoDB Atlas Vector Search implementation."""

    def __init__(
        self,
        uri: str,
        database_name: str,
        collection_name: str = "document_chunks",
        index_name: str = "vector_index",
    ):
        """Initialize MongoDB connection for vector storage."""
        self.client = AsyncIOMotorClient(uri)
        self.db = self.client[database_name]
        self.collection = self.db[collection_name]
        self.index_name = index_name

    async def initialize(self):
        """Initialize vector search index if needed."""
        try:
            # Create basic indexes
            await self.collection.create_index("document_id")
            await self.collection.create_index("chunk_number")

            # Note: Vector search index must be created via Atlas UI or API
            # as it requires specific configuration

            logger.info("MongoDB vector store indexes initialized")
            return True
        except PyMongoError as e:
            logger.error(f"Error initializing vector store indexes: {str(e)}")
            return False

    async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
        """Store document chunks with their embeddings."""
        try:
            if not chunks:
                return True, []

            # Convert chunks to dicts
            documents = []
            for chunk in chunks:
                doc = chunk.model_dump()
                # Ensure we have required fields
                if not doc.get("embedding"):
                    logger.error(
                        f"Missing embedding for chunk {chunk.document_id}-{chunk.chunk_number}"
                    )
                    continue
                documents.append(doc)

            if documents:
                # Use ordered=False to continue even if some inserts fail
                result = await self.collection.insert_many(documents, ordered=False)
                return len(result.inserted_ids) > 0, [str(id) for id in result.inserted_ids]
            else:
                logger.error(f"No documents to store - here is the input: {chunks}")
                return False, []

        except PyMongoError as e:
            logger.error(f"Error storing embeddings: {str(e)}")
            return False, []

    async def query_similar(
        self,
        query_embedding: List[float],
        k: int,
        doc_ids: Optional[List[str]] = None,
    ) -> List[DocumentChunk]:
        """Find similar chunks using MongoDB Atlas Vector Search."""
        try:
            logger.debug(
                f"Searching in database {self.db.name} collection {self.collection.name}"
            )
            logger.debug(f"Query vector looks like: {query_embedding}")
            logger.debug(f"Doc IDs: {doc_ids}")
            logger.debug(f"K is: {k}")
            logger.debug(f"Index is: {self.index_name}")

            # Vector search pipeline
            pipeline = [
                {
                    "$vectorSearch": {
                        "index": self.index_name,
                        "path": "embedding",
                        "queryVector": query_embedding,
                        "numCandidates": k * 40,  # Get more candidates
                        "limit": k,
                        "filter": {"document_id": {"$in": doc_ids}} if doc_ids else {},
                    }
                },
                {
                    "$project": {
                        "score": {"$meta": "vectorSearchScore"},
                        "document_id": 1,
                        "chunk_number": 1,
                        "content": 1,
                        "metadata": 1,
                        "_id": 0,
                    }
                },
            ]

            # Execute search
            cursor = self.collection.aggregate(pipeline)
            chunks = []

            async for result in cursor:
                chunk = DocumentChunk(
                    document_id=result["document_id"],
                    chunk_number=result["chunk_number"],
                    content=result["content"],
                    embedding=[],  # Don't send embeddings back
                    metadata=result.get("metadata", {}),
                    score=result.get("score", 0.0),
                )
                chunks.append(chunk)

            return chunks

        except PyMongoError as e:
            logger.error(f"MongoDB error: {e._message}")
            logger.error(f"Error querying similar chunks: {str(e)}")
            raise e

    async def get_chunks_by_id(
        self,
        chunk_identifiers: List[Tuple[str, int]],
    ) -> List[DocumentChunk]:
        """
        Retrieve specific chunks by document ID and chunk number in a single database query.

        Args:
            chunk_identifiers: List of (document_id, chunk_number) tuples

        Returns:
            List of DocumentChunk objects
        """
        try:
            if not chunk_identifiers:
                return []

            # Create a query with $or to find multiple chunks in a single query
            query = {"$or": []}
            for doc_id, chunk_num in chunk_identifiers:
                query["$or"].append({
                    "document_id": doc_id,
                    "chunk_number": chunk_num
                })

            logger.info(f"Batch retrieving {len(chunk_identifiers)} chunks with a single query")

            # Find all matching chunks in a single database query
            cursor = self.collection.find(query)
            chunks = []

            async for result in cursor:
                chunk = DocumentChunk(
                    document_id=result["document_id"],
                    chunk_number=result["chunk_number"],
                    content=result["content"],
                    embedding=[],  # Don't send embeddings back
                    metadata=result.get("metadata", {}),
                    score=0.0,  # No relevance score for direct retrieval
                )
                chunks.append(chunk)

            logger.info(f"Found {len(chunks)} chunks in batch retrieval")
            return chunks

        except PyMongoError as e:
            logger.error(f"Error retrieving chunks by ID: {str(e)}")
            return []
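
# With the Atlas implementation removed, similarity search is served by the
# PostgreSQL/pgvector path. A rough, hypothetical equivalent of the pipeline above
# against pgvector (the table and column names here are assumptions for
# illustration, not Morphik's actual schema):
#
# import asyncpg  # assumed driver for this sketch
#
# async def query_similar_pgvector(pool: asyncpg.Pool, query_embedding: list[float], k: int):
#     """Return the k nearest chunks by cosine distance using pgvector's <=> operator."""
#     embedding_literal = "[" + ",".join(str(x) for x in query_embedding) + "]"
#     rows = await pool.fetch(
#         """
#         SELECT document_id, chunk_number, content, metadata,
#                1 - (embedding <=> $1::vector) AS score
#         FROM vector_embeddings
#         ORDER BY embedding <=> $1::vector
#         LIMIT $2
#         """,
#         embedding_literal,
#         k,
#     )
#     return [dict(r) for r in rows]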
82
examples/multi_app_user_scoping.py
82
examples/multi_app_user_scoping.py
@ -0,0 +1,82 @@
import os
from dotenv import load_dotenv
from morphik import Morphik

# Load environment variables
load_dotenv()

# Connect to Morphik
db = Morphik(os.getenv("MORPHIK_URI"), timeout=10000, is_local=True)

print("========== Customer Support Example ==========")
# Create a folder for application data
app_folder = db.create_folder("customer-support")
print(f"Created folder: {app_folder.name}")

# Ingest documents into the folder
folder_doc = app_folder.ingest_text(
    "Customer reported an issue with login functionality. Steps to reproduce: "
    "1. Go to login page, 2. Enter credentials, 3. Click login button.",
    filename="ticket-001.txt",
    metadata={"category": "bug", "priority": "high", "status": "open"}
)
print(f"Ingested document into folder: {folder_doc.external_id}")

# Perform a query in the folder context
folder_response = app_folder.query(
    "What issues have been reported?",
    k=2
)
print("\nFolder Query Results:")
print(folder_response.completion)

# Get statistics for the folder
folder_docs = app_folder.list_documents()
print(f"\nFolder Statistics: {len(folder_docs)} documents in '{app_folder.name}'")

print("\n========== User Scoping Example ==========")
# Create a user scope
user_email = "support@example.com"
user = db.signin(user_email)
print(f"Created user scope for: {user.end_user_id}")

# Ingest a document as this user
user_doc = user.ingest_text(
    "User requested information about premium features. They are interested in the collaboration tools.",
    filename="inquiry-001.txt",
    metadata={"category": "inquiry", "priority": "medium", "status": "open"}
)
print(f"Ingested document as user: {user_doc.external_id}")

# Query as this user
user_response = user.query(
    "What customer inquiries do we have?",
    k=2
)
print("\nUser Query Results:")
print(user_response.completion)

# Get documents for this user
user_docs = user.list_documents()
print(f"\nUser Statistics: {len(user_docs)} documents for user '{user.end_user_id}'")

print("\n========== Combined Folder and User Scoping ==========")
# Create a user scoped to a specific folder
folder_user = app_folder.signin(user_email)
print(f"Created user scope for {folder_user.end_user_id} in folder {folder_user.folder_name}")

# Ingest a document as this user in the folder context
folder_user_doc = folder_user.ingest_text(
    "Customer called to follow up on ticket-001. They are still experiencing the login issue on Chrome.",
    filename="ticket-002.txt",
    metadata={"category": "follow-up", "priority": "high", "status": "open"}
)
print(f"Ingested document as user in folder: {folder_user_doc.external_id}")

# Query as this user in the folder context
folder_user_response = folder_user.query(
    "What high priority issues require attention?",
    k=2
)
print("\nFolder User Query Results:")
print(folder_user_response.completion)
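
# A natural follow-up to this example is comparing what each handle can see. This
# sketch only reuses objects created above (counts depend on what has been ingested):
#
# folder_scope_docs = app_folder.list_documents()   # everything in "customer-support"
# user_scope_docs = user.list_documents()           # everything owned by support@example.com
# folder_user_docs = folder_user.list_documents()   # intersection of folder and user scope
#
# print(f"Folder scope: {len(folder_scope_docs)} docs")
# print(f"User scope: {len(user_scope_docs)} docs")
# print(f"Folder + user scope: {len(folder_user_docs)} docs")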
@ -9,15 +9,12 @@ import boto3
import botocore
import tomli  # for reading toml files
from dotenv import find_dotenv, load_dotenv
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure, OperationFailure
from pymongo.operations import SearchIndexModel

# Force reload of environment variables
load_dotenv(find_dotenv(), override=True)

# Set up argument parser
parser = argparse.ArgumentParser(description="Setup S3 bucket and MongoDB collections")
parser = argparse.ArgumentParser(description="Setup S3 bucket")
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
parser.add_argument("--quiet", action="store_true", help="Only show warning and error logs")
args = parser.parse_args()
@ -48,16 +45,6 @@ with open(config_path, "rb") as f:
STORAGE_PROVIDER = CONFIG["storage"]["provider"]
DATABASE_PROVIDER = CONFIG["database"]["provider"]

# MongoDB specific config
if "mongodb" in CONFIG["database"]:
    DATABASE_NAME = CONFIG["database"]["mongodb"]["database_name"]
    DOCUMENTS_COLLECTION = "documents"
    CHUNKS_COLLECTION = "document_chunks"
if "mongodb" in CONFIG["vector_store"]:
    VECTOR_DIMENSIONS = CONFIG["embedding"]["dimensions"]
    VECTOR_INDEX_NAME = "vector_index"
    SIMILARITY_METRIC = CONFIG["embedding"]["similarity_metric"]

# Extract storage-specific configuration
if STORAGE_PROVIDER == "aws-s3":
    DEFAULT_REGION = CONFIG["storage"]["region"]
@ -117,69 +104,6 @@ def bucket_exists(s3_client, bucket_name):
    # raise e


def setup_mongodb():
    """
    Set up MongoDB database, documents collection, and vector index on documents_chunk collection.
    """
    # Load MongoDB URI from .env file
    mongo_uri = os.getenv("MONGODB_URI")
    if not mongo_uri:
        raise ValueError("MONGODB_URI not found in .env file.")

    try:
        # Connect to MongoDB
        client = MongoClient(mongo_uri)
        client.admin.command("ping")  # Check connection
        LOGGER.info("Connected to MongoDB successfully.")

        # Create or access the database
        db = client[DATABASE_NAME]
        LOGGER.info(f"Database '{DATABASE_NAME}' ready.")

        # Create 'documents' collection
        if DOCUMENTS_COLLECTION not in db.list_collection_names():
            db.create_collection(DOCUMENTS_COLLECTION)
            LOGGER.info(f"Collection '{DOCUMENTS_COLLECTION}' created.")
        else:
            LOGGER.info(f"Collection '{DOCUMENTS_COLLECTION}' already exists.")

        # Create 'documents_chunk' collection with vector index
        if CHUNKS_COLLECTION not in db.list_collection_names():
            db.create_collection(CHUNKS_COLLECTION)
            LOGGER.info(f"Collection '{CHUNKS_COLLECTION}' created.")
        else:
            LOGGER.info(f"Collection '{CHUNKS_COLLECTION}' already exists.")

        vector_index_definition = {
            "fields": [
                {
                    "numDimensions": VECTOR_DIMENSIONS,
                    "path": "embedding",
                    "similarity": SIMILARITY_METRIC,
                    "type": "vector",
                },
                {"path": "document_id", "type": "filter"},
            ]
        }
        vector_index = SearchIndexModel(
            name=VECTOR_INDEX_NAME,
            definition=vector_index_definition,
            type="vectorSearch",
        )
        db[CHUNKS_COLLECTION].create_search_index(model=vector_index)
        LOGGER.info("Vector index 'vector_index' created on 'documents_chunk' collection.")

    except ConnectionFailure:
        LOGGER.error("Failed to connect to MongoDB. Check your MongoDB URI and network connection.")
    except OperationFailure as e:
        LOGGER.error(f"MongoDB operation failed: {e}")
    except Exception as e:
        LOGGER.error(f"Unexpected error: {e}")
    finally:
        client.close()
        LOGGER.info("MongoDB connection closed.")


def setup():
    # Setup S3 if configured
    if STORAGE_PROVIDER == "aws-s3":
@ -188,16 +112,11 @@ def setup():
        LOGGER.info("S3 bucket setup completed.")

    # Setup database based on provider
    match DATABASE_PROVIDER:
        case "mongodb":
            LOGGER.info("Setting up MongoDB...")
            setup_mongodb()
            LOGGER.info("MongoDB setup completed.")
        case "postgres":
            LOGGER.info("Postgres is setup on database intialization - nothing to do here!")
        case _:
            LOGGER.error(f"Unsupported database provider: {DATABASE_PROVIDER}")
            raise ValueError(f"Unsupported database provider: {DATABASE_PROVIDER}")
    if DATABASE_PROVIDER != "postgres":
        LOGGER.error(f"Unsupported database provider: {DATABASE_PROVIDER}")
        raise ValueError(f"Unsupported database provider: {DATABASE_PROVIDER}")

    LOGGER.info("Postgres is setup on database initialization - nothing to do here!")

    LOGGER.info("Setup completed successfully. Feel free to start the server now!")
@ -151,7 +151,6 @@ matplotlib-inline==0.1.7
mdurl==0.1.2
monotonic==1.6
more-itertools==10.5.0
motor==3.4.0
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
@ -237,7 +236,6 @@ pydeck==0.9.1
pyee==12.1.1
Pygments==2.18.0
PyJWT==2.9.0
pymongo==4.7.1
pypandoc==1.13
pyparsing==3.1.2
pypdf==4.3.1
@ -1,72 +0,0 @@
from pymongo import MongoClient
from dotenv import load_dotenv
import os
import datetime


def test_mongo_operations():
    # Load environment variables
    load_dotenv()

    # Get MongoDB URI from environment variable
    mongo_uri = os.getenv("MONGODB_URI")
    if not mongo_uri:
        raise ValueError("MONGODB_URI environment variable not set")

    try:
        # Connect to MongoDB
        client = MongoClient(mongo_uri)

        # Test connection
        client.admin.command("ping")
        print("✅ Connected successfully to MongoDB")

        # Get database and collection
        db = client.brandsyncaidb  # Using a test database
        collection = db.kb_chunked_embeddings

        # Insert a single document
        test_doc = {
            "name": "Test Document",
            "timestamp": datetime.datetime.now(),
            "value": 42,
        }

        result = collection.insert_one(test_doc)
        print(f"✅ Inserted document with ID: {result.inserted_id}")

        # Insert multiple documents
        test_docs = [
            {"name": "Doc 1", "value": 1},
            {"name": "Doc 2", "value": 2},
            {"name": "Doc 3", "value": 3},
        ]

        result = collection.insert_many(test_docs)
        print(f"✅ Inserted {len(result.inserted_ids)} documents")

        # Retrieve documents
        print("\nRetrieving documents:")
        for doc in collection.find():
            print(f"Found document: {doc}")

        # Find specific documents
        print("\nFinding documents with value >= 2:")
        query = {"value": {"$gte": 2}}
        for doc in collection.find(query):
            print(f"Found document: {doc}")

        # Clean up - delete all test documents
        # DON'T DELETE IF It'S BRANDSYNCAI
        # result = collection.delete_many({})
        print(f"\n✅ Cleaned up {result.deleted_count} test documents")

    except Exception as e:
        print(f"❌ Error: {str(e)}")
    finally:
        client.close()
        print("\n✅ Connection closed")


if __name__ == "__main__":
    test_mongo_operations()
@ -12,4 +12,4 @@ __all__ = [
    "Document",
]

__version__ = "0.1.0"
__version__ = "0.1.2"
507
sdks/python/morphik/_internal.py
507
sdks/python/morphik/_internal.py
@ -0,0 +1,507 @@
import base64
import io
import json
from io import BytesIO, IOBase
from PIL import Image
from PIL.Image import Image as PILImage
from pathlib import Path
from typing import Dict, Any, List, Optional, Union, Tuple, BinaryIO
from urllib.parse import urlparse

import jwt
from pydantic import BaseModel, Field

from .models import (
    Document,
    ChunkResult,
    DocumentResult,
    CompletionResponse,
    IngestTextRequest,
    ChunkSource,
    Graph,
    # Prompt override models
    GraphPromptOverrides,
)
from .rules import Rule

# Type alias for rules
RuleOrDict = Union[Rule, Dict[str, Any]]


class FinalChunkResult(BaseModel):
    content: str | PILImage = Field(..., description="Chunk content")
    score: float = Field(..., description="Relevance score")
    document_id: str = Field(..., description="Parent document ID")
    chunk_number: int = Field(..., description="Chunk sequence number")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata")
    content_type: str = Field(..., description="Content type")
    filename: Optional[str] = Field(None, description="Original filename")
    download_url: Optional[str] = Field(None, description="URL to download full document")

    class Config:
        arbitrary_types_allowed = True


class _MorphikClientLogic:
    """
    Internal shared logic for Morphik clients.

    This class contains the shared logic between synchronous and asynchronous clients.
    It handles URL generation, request preparation, and response parsing.
    """

    def __init__(self, uri: Optional[str] = None, timeout: int = 30, is_local: bool = False):
        """Initialize shared client logic"""
        self._timeout = timeout
        self._is_local = is_local

        if uri:
            self._setup_auth(uri)
        else:
            self._base_url = "http://localhost:8000"
            self._auth_token = None

    def _setup_auth(self, uri: str) -> None:
        """Setup authentication from URI"""
        parsed = urlparse(uri)
        if not parsed.netloc:
            raise ValueError("Invalid URI format")

        # Split host and auth parts
        auth, host = parsed.netloc.split("@")
        _, self._auth_token = auth.split(":")

        # Set base URL
        self._base_url = f"{'http' if self._is_local else 'https'}://{host}"

        # Basic token validation
        jwt.decode(self._auth_token, options={"verify_signature": False})
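
    # _setup_auth expects the URI's netloc to carry credentials before the host. A
    # hedged illustration of the shape it accepts (scheme, owner, token, and host
    # below are made-up values, not real credentials):
    #
    #     uri = "morphik://owner_id:eyJhbGciOi...@api.example.com"
    #     parsed = urlparse(uri)
    #     auth, host = parsed.netloc.split("@")   # "owner_id:eyJhbGciOi...", "api.example.com"
    #     _, token = auth.split(":")              # bearer token used for Authorization headers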
    def _convert_rule(self, rule: RuleOrDict) -> Dict[str, Any]:
        """Convert a rule to a dictionary format"""
        if hasattr(rule, "to_dict"):
            return rule.to_dict()
        return rule

    def _get_url(self, endpoint: str) -> str:
        """Get the full URL for an API endpoint"""
        return f"{self._base_url}/{endpoint.lstrip('/')}"

    def _get_headers(self) -> Dict[str, str]:
        """Get base headers for API requests"""
        headers = {"Content-Type": "application/json"}
        return headers

    # Request preparation methods

    def _prepare_ingest_text_request(
        self,
        content: str,
        filename: Optional[str],
        metadata: Optional[Dict[str, Any]],
        rules: Optional[List[RuleOrDict]],
        use_colpali: bool,
        folder_name: Optional[str],
        end_user_id: Optional[str],
    ) -> Dict[str, Any]:
        """Prepare request for ingest_text endpoint"""
        rules_dict = [self._convert_rule(r) for r in (rules or [])]
        payload = {
            "content": content,
            "filename": filename,
            "metadata": metadata or {},
            "rules": rules_dict,
            "use_colpali": use_colpali,
        }
        if folder_name:
            payload["folder_name"] = folder_name
        if end_user_id:
            payload["end_user_id"] = end_user_id
        return payload

    def _prepare_file_for_upload(
        self,
        file: Union[str, bytes, BinaryIO, Path],
        filename: Optional[str] = None,
    ) -> Tuple[BinaryIO, str]:
        """
        Process file input and return file object and filename.
        Handles different file input types (str, Path, bytes, file-like object).
        """
        if isinstance(file, (str, Path)):
            file_path = Path(file)
            if not file_path.exists():
                raise ValueError(f"File not found: {file}")
            filename = file_path.name if filename is None else filename
            with open(file_path, "rb") as f:
                content = f.read()
                file_obj = BytesIO(content)
        elif isinstance(file, bytes):
            if filename is None:
                raise ValueError("filename is required when ingesting bytes")
            file_obj = BytesIO(file)
        else:
            if filename is None:
                raise ValueError("filename is required when ingesting file object")
            file_obj = file

        return file_obj, filename

    def _prepare_files_for_upload(
        self,
        files: List[Union[str, bytes, BinaryIO, Path]],
    ) -> List[Tuple[str, Tuple[str, BinaryIO]]]:
        """
        Process multiple files and return a list of file objects in the format
        expected by the API: [("files", (filename, file_obj)), ...]
        """
        file_objects = []
        for file in files:
            if isinstance(file, (str, Path)):
                path = Path(file)
                file_objects.append(("files", (path.name, open(path, "rb"))))
            elif isinstance(file, bytes):
                file_objects.append(("files", ("file.bin", BytesIO(file))))
            else:
                file_objects.append(("files", (getattr(file, "name", "file.bin"), file)))

        return file_objects

    def _prepare_ingest_file_form_data(
        self,
        metadata: Optional[Dict[str, Any]],
        rules: Optional[List[RuleOrDict]],
        folder_name: Optional[str],
        end_user_id: Optional[str],
    ) -> Dict[str, Any]:
        """Prepare form data for ingest_file endpoint"""
        form_data = {
            "metadata": json.dumps(metadata or {}),
            "rules": json.dumps([self._convert_rule(r) for r in (rules or [])]),
        }
        if folder_name:
            form_data["folder_name"] = folder_name
        if end_user_id:
            form_data["end_user_id"] = end_user_id
        return form_data

    def _prepare_ingest_files_form_data(
        self,
        metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]],
        rules: Optional[List[RuleOrDict]],
        use_colpali: bool,
        parallel: bool,
        folder_name: Optional[str],
        end_user_id: Optional[str],
    ) -> Dict[str, Any]:
        """Prepare form data for ingest_files endpoint"""
        # Convert rules appropriately based on whether it's a flat list or list of lists
        if rules:
            if all(isinstance(r, list) for r in rules):
                # List of lists - per-file rules
                converted_rules = [
                    [self._convert_rule(r) for r in rule_list] for rule_list in rules
                ]
            else:
                # Flat list - shared rules for all files
                converted_rules = [self._convert_rule(r) for r in rules]
        else:
            converted_rules = []

        data = {
            "metadata": json.dumps(metadata or {}),
            "rules": json.dumps(converted_rules),
            "use_colpali": str(use_colpali).lower() if use_colpali is not None else None,
            "parallel": str(parallel).lower(),
        }

        if folder_name:
            data["folder_name"] = folder_name
        if end_user_id:
            data["end_user_id"] = end_user_id

        return data
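
    # The rules handling above accepts either one shared rule list for every file or
    # a list of per-file rule lists. A sketch of the two call shapes (the rule dicts
    # are placeholders; real Rule objects from morphik.rules would normally be used):
    #
    #     shared_rules = [{"type": "metadata_extraction", "schema": {"title": "string"}}]
    #     per_file_rules = [
    #         [{"type": "metadata_extraction", "schema": {"title": "string"}}],  # rules for file 1
    #         [],                                                                # no rules for file 2
    #     ]
    #
    #     logic = _MorphikClientLogic()
    #     form_shared = logic._prepare_ingest_files_form_data(
    #         metadata={}, rules=shared_rules, use_colpali=True, parallel=True,
    #         folder_name="customer-support", end_user_id=None,
    #     )
    #     form_per_file = logic._prepare_ingest_files_form_data(
    #         metadata={}, rules=per_file_rules, use_colpali=True, parallel=True,
    #         folder_name=None, end_user_id="support@example.com",
    #     )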
    def _prepare_query_request(
        self,
        query: str,
        filters: Optional[Dict[str, Any]],
        k: int,
        min_score: float,
        max_tokens: Optional[int],
        temperature: Optional[float],
        use_colpali: bool,
        graph_name: Optional[str],
        hop_depth: int,
        include_paths: bool,
        prompt_overrides: Optional[Dict],
        folder_name: Optional[str],
        end_user_id: Optional[str],
    ) -> Dict[str, Any]:
        """Prepare request for query endpoint"""
        payload = {
            "query": query,
            "filters": filters,
            "k": k,
            "min_score": min_score,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "use_colpali": use_colpali,
            "graph_name": graph_name,
            "hop_depth": hop_depth,
            "include_paths": include_paths,
            "prompt_overrides": prompt_overrides,
        }
        if folder_name:
            payload["folder_name"] = folder_name
        if end_user_id:
            payload["end_user_id"] = end_user_id
        # Filter out None values before sending
        return {k_p: v_p for k_p, v_p in payload.items() if v_p is not None}
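
    # Because the final dict comprehension drops every None value, optional arguments
    # left unset never reach the API. An illustration with made-up values:
    #
    #     logic = _MorphikClientLogic()
    #     payload = logic._prepare_query_request(
    #         query="What issues have been reported?",
    #         filters=None, k=4, min_score=0.0,
    #         max_tokens=None, temperature=None, use_colpali=True,
    #         graph_name=None, hop_depth=1, include_paths=False,
    #         prompt_overrides=None, folder_name="customer-support", end_user_id=None,
    #     )
    #     # Only non-None keys remain: query, k, min_score, use_colpali, hop_depth,
    #     # include_paths (False is not None), and folder_name.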
    def _prepare_retrieve_chunks_request(
        self,
        query: str,
        filters: Optional[Dict[str, Any]],
        k: int,
        min_score: float,
        use_colpali: bool,
        folder_name: Optional[str],
        end_user_id: Optional[str],
    ) -> Dict[str, Any]:
        """Prepare request for retrieve_chunks endpoint"""
        request = {
            "query": query,
            "filters": filters,
            "k": k,
            "min_score": min_score,
            "use_colpali": use_colpali,
        }
        if folder_name:
            request["folder_name"] = folder_name
        if end_user_id:
            request["end_user_id"] = end_user_id
        return request

    def _prepare_retrieve_docs_request(
        self,
        query: str,
        filters: Optional[Dict[str, Any]],
        k: int,
        min_score: float,
        use_colpali: bool,
        folder_name: Optional[str],
        end_user_id: Optional[str],
    ) -> Dict[str, Any]:
        """Prepare request for retrieve_docs endpoint"""
        request = {
            "query": query,
            "filters": filters,
            "k": k,
            "min_score": min_score,
            "use_colpali": use_colpali,
        }
        if folder_name:
            request["folder_name"] = folder_name
        if end_user_id:
            request["end_user_id"] = end_user_id
        return request

    def _prepare_list_documents_request(
        self,
        skip: int,
        limit: int,
        filters: Optional[Dict[str, Any]],
        folder_name: Optional[str],
        end_user_id: Optional[str],
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Prepare request for list_documents endpoint"""
        params = {
            "skip": skip,
            "limit": limit,
        }
        if folder_name:
            params["folder_name"] = folder_name
        if end_user_id:
            params["end_user_id"] = end_user_id
        data = filters or {}
        return params, data

    def _prepare_batch_get_documents_request(
        self, document_ids: List[str], folder_name: Optional[str], end_user_id: Optional[str]
    ) -> Dict[str, Any]:
        """Prepare request for batch_get_documents endpoint"""
        if folder_name or end_user_id:
            request = {"document_ids": document_ids}
            if folder_name:
                request["folder_name"] = folder_name
            if end_user_id:
                request["end_user_id"] = end_user_id
            return request
        return document_ids  # Return just IDs list if no scoping is needed

    def _prepare_batch_get_chunks_request(
        self,
        sources: List[Union[ChunkSource, Dict[str, Any]]],
        folder_name: Optional[str],
        end_user_id: Optional[str],
    ) -> Dict[str, Any]:
        """Prepare request for batch_get_chunks endpoint"""
        source_dicts = []
        for source in sources:
            if isinstance(source, dict):
                source_dicts.append(source)
            else:
                source_dicts.append(source.model_dump())

        if folder_name or end_user_id:
            request = {"sources": source_dicts}
            if folder_name:
                request["folder_name"] = folder_name
            if end_user_id:
                request["end_user_id"] = end_user_id
            return request
        return source_dicts  # Return just sources list if no scoping is needed

    def _prepare_create_graph_request(
        self,
        name: str,
        filters: Optional[Dict[str, Any]],
        documents: Optional[List[str]],
        prompt_overrides: Optional[Union[GraphPromptOverrides, Dict[str, Any]]],
        folder_name: Optional[str],
        end_user_id: Optional[str],
    ) -> Dict[str, Any]:
        """Prepare request for create_graph endpoint"""
        # Convert prompt_overrides to dict if it's a model
        if prompt_overrides and isinstance(prompt_overrides, GraphPromptOverrides):
            prompt_overrides = prompt_overrides.model_dump(exclude_none=True)

        request = {
            "name": name,
            "filters": filters,
            "documents": documents,
            "prompt_overrides": prompt_overrides,
        }
        if folder_name:
            request["folder_name"] = folder_name
        if end_user_id:
            request["end_user_id"] = end_user_id
        return request

    def _prepare_update_graph_request(
        self,
        name: str,
        additional_filters: Optional[Dict[str, Any]],
        additional_documents: Optional[List[str]],
        prompt_overrides: Optional[Union[GraphPromptOverrides, Dict[str, Any]]],
        folder_name: Optional[str],
        end_user_id: Optional[str],
    ) -> Dict[str, Any]:
        """Prepare request for update_graph endpoint"""
        # Convert prompt_overrides to dict if it's a model
        if prompt_overrides and isinstance(prompt_overrides, GraphPromptOverrides):
            prompt_overrides = prompt_overrides.model_dump(exclude_none=True)

        request = {
            "additional_filters": additional_filters,
            "additional_documents": additional_documents,
            "prompt_overrides": prompt_overrides,
        }
        if folder_name:
            request["folder_name"] = folder_name
        if end_user_id:
            request["end_user_id"] = end_user_id
        return request

    def _prepare_update_document_with_text_request(
        self,
        document_id: str,
        content: str,
        filename: Optional[str],
        metadata: Optional[Dict[str, Any]],
        rules: Optional[List],
        update_strategy: str,
        use_colpali: Optional[bool],
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Prepare request for update_document_with_text endpoint"""
        request = IngestTextRequest(
            content=content,
            filename=filename,
            metadata=metadata or {},
            rules=[self._convert_rule(r) for r in (rules or [])],
            use_colpali=use_colpali if use_colpali is not None else True,
        )

        params = {}
        if update_strategy != "add":
            params["update_strategy"] = update_strategy

        return params, request.model_dump()

    # Response parsing methods

    def _parse_document_response(self, response_json: Dict[str, Any]) -> Document:
        """Parse document response"""
        return Document(**response_json)

    def _parse_completion_response(self, response_json: Dict[str, Any]) -> CompletionResponse:
        """Parse completion response"""
        return CompletionResponse(**response_json)

    def _parse_document_list_response(self, response_json: List[Dict[str, Any]]) -> List[Document]:
        """Parse document list response"""
        docs = [Document(**doc) for doc in response_json]
        return docs

    def _parse_document_result_list_response(
        self, response_json: List[Dict[str, Any]]
    ) -> List[DocumentResult]:
        """Parse document result list response"""
        return [DocumentResult(**r) for r in response_json]

    def _parse_chunk_result_list_response(
        self, response_json: List[Dict[str, Any]]
    ) -> List[FinalChunkResult]:
        """Parse chunk result list response"""
        chunks = [ChunkResult(**r) for r in response_json]

        final_chunks = []
        for chunk in chunks:
            content = chunk.content
            if chunk.metadata.get("is_image"):
                try:
                    # Handle data URI format "data:image/png;base64,..."
                    if content.startswith("data:"):
                        # Extract the base64 part after the comma
                        content = content.split(",", 1)[1]

                    # Now decode the base64 string
                    image_bytes = base64.b64decode(content)
                    content = Image.open(io.BytesIO(image_bytes))
                except Exception:
                    # Fall back to using the content as text
                    content = chunk.content

            final_chunks.append(
                FinalChunkResult(
                    content=content,
                    score=chunk.score,
                    document_id=chunk.document_id,
                    chunk_number=chunk.chunk_number,
                    metadata=chunk.metadata,
                    content_type=chunk.content_type,
                    filename=chunk.filename,
                    download_url=chunk.download_url,
                )
            )

        return final_chunks
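
    # Because of the decoding above, chunk content is either text or a PIL image. A
    # sketch of handling both cases (final_chunks is assumed to come from
    # _parse_chunk_result_list_response; the output filename is illustrative):
    #
    #     from PIL.Image import Image as PILImage
    #
    #     for chunk in final_chunks:
    #         if isinstance(chunk.content, PILImage):
    #             chunk.content.save(f"{chunk.document_id}-{chunk.chunk_number}.png")
    #         else:
    #             print(chunk.content[:80])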
    def _parse_graph_response(self, response_json: Dict[str, Any]) -> Graph:
        """Parse graph response"""
        return Graph(**response_json)

    def _parse_graph_list_response(self, response_json: List[Dict[str, Any]]) -> List[Graph]:
        """Parse graph list response"""
        return [Graph(**graph) for graph in response_json]

File diff suppressed because it is too large
@ -21,10 +21,10 @@ class Document(BaseModel):
        default_factory=dict, description="Access control information"
    )
    chunk_ids: List[str] = Field(default_factory=list, description="IDs of document chunks")

    # Client reference for update methods
    _client = None

    def update_with_text(
        self,
        content: str,
@ -36,7 +36,7 @@ class Document(BaseModel):
    ) -> "Document":
        """
        Update this document with new text content using the specified strategy.

        Args:
            content: The new content to add
            filename: Optional new filename for the document
@ -44,13 +44,15 @@ class Document(BaseModel):
            rules: Optional list of rules to apply to the content
            update_strategy: Strategy for updating the document (currently only 'add' is supported)
            use_colpali: Whether to use multi-vector embedding

        Returns:
            Document: Updated document metadata
        """
        if self._client is None:
            raise ValueError("Document instance not connected to a client. Use a document returned from a Morphik client method.")
            raise ValueError(
                "Document instance not connected to a client. Use a document returned from a Morphik client method."
            )

        return self._client.update_document_with_text(
            document_id=self.external_id,
            content=content,
@ -58,9 +60,9 @@ class Document(BaseModel):
            metadata=metadata,
            rules=rules,
            update_strategy=update_strategy,
            use_colpali=use_colpali
            use_colpali=use_colpali,
        )
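
    # A document returned by a Morphik client keeps a private client reference, so
    # updates can be issued from the model itself. A hedged usage sketch (content,
    # filename, and metadata are illustrative values):
    #
    #     import os
    #     from morphik import Morphik
    #
    #     db = Morphik(os.getenv("MORPHIK_URI"), is_local=True)
    #     doc = db.ingest_text("Initial ticket text.", filename="ticket-003.txt")
    #
    #     # Appends new content using the default 'add' strategy via the attached client.
    #     updated = doc.update_with_text(
    #         "Follow-up: the customer confirmed the fix.",
    #         metadata={"status": "resolved"},
    #     )
    #     print(updated.external_id)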
    def update_with_file(
        self,
        file: "Union[str, bytes, BinaryIO, Path]",
@ -72,7 +74,7 @@ class Document(BaseModel):
    ) -> "Document":
        """
        Update this document with content from a file using the specified strategy.

        Args:
            file: File to add (path string, bytes, file object, or Path)
            filename: Name of the file
@ -80,13 +82,15 @@ class Document(BaseModel):
            rules: Optional list of rules to apply to the content
            update_strategy: Strategy for updating the document (currently only 'add' is supported)
            use_colpali: Whether to use multi-vector embedding

        Returns:
            Document: Updated document metadata
        """
        if self._client is None:
            raise ValueError("Document instance not connected to a client. Use a document returned from a Morphik client method.")
            raise ValueError(
                "Document instance not connected to a client. Use a document returned from a Morphik client method."
            )

        return self._client.update_document_with_file(
            document_id=self.external_id,
            file=file,
@ -94,28 +98,29 @@ class Document(BaseModel):
            metadata=metadata,
            rules=rules,
            update_strategy=update_strategy,
            use_colpali=use_colpali
            use_colpali=use_colpali,
        )

    def update_metadata(
        self,
        metadata: Dict[str, Any],
    ) -> "Document":
        """
        Update this document's metadata only.

        Args:
            metadata: Metadata to update

        Returns:
            Document: Updated document metadata
        """
        if self._client is None:
            raise ValueError("Document instance not connected to a client. Use a document returned from a Morphik client method.")
            raise ValueError(
                "Document instance not connected to a client. Use a document returned from a Morphik client method."
            )

        return self._client.update_document_metadata(
            document_id=self.external_id,
            metadata=metadata
            document_id=self.external_id, metadata=metadata
        )

@ -159,7 +164,7 @@ class DocumentResult(BaseModel):

class ChunkSource(BaseModel):
    """Source information for a chunk used in completion"""

    document_id: str = Field(..., description="ID of the source document")
    chunk_number: int = Field(..., description="Chunk number within the document")
    score: Optional[float] = Field(None, description="Relevance score")
@ -194,7 +199,9 @@ class Entity(BaseModel):
    type: str = Field(..., description="Entity type")
    properties: Dict[str, Any] = Field(default_factory=dict, description="Entity properties")
    document_ids: List[str] = Field(default_factory=list, description="Source document IDs")
    chunk_sources: Dict[str, List[int]] = Field(default_factory=dict, description="Source chunk numbers by document ID")
    chunk_sources: Dict[str, List[int]] = Field(
        default_factory=dict, description="Source chunk numbers by document ID"
    )

    def __hash__(self):
        return hash(self.id)
@ -213,7 +220,9 @@ class Relationship(BaseModel):
    target_id: str = Field(..., description="Target entity ID")
    type: str = Field(..., description="Relationship type")
    document_ids: List[str] = Field(default_factory=list, description="Source document IDs")
    chunk_sources: Dict[str, List[int]] = Field(default_factory=dict, description="Source chunk numbers by document ID")
    chunk_sources: Dict[str, List[int]] = Field(
        default_factory=dict, description="Source chunk numbers by document ID"
    )

    def __hash__(self):
        return hash(self.id)
@ -230,10 +239,14 @@ class Graph(BaseModel):
    id: str = Field(..., description="Unique graph identifier")
    name: str = Field(..., description="Graph name")
    entities: List[Entity] = Field(default_factory=list, description="Entities in the graph")
    relationships: List[Relationship] = Field(default_factory=list, description="Relationships in the graph")
    relationships: List[Relationship] = Field(
        default_factory=list, description="Relationships in the graph"
    )
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Graph metadata")
    document_ids: List[str] = Field(default_factory=list, description="Source document IDs")
    filters: Optional[Dict[str, Any]] = Field(None, description="Document filters used to create the graph")
    filters: Optional[Dict[str, Any]] = Field(
        None, description="Document filters used to create the graph"
    )
    created_at: datetime = Field(..., description="Creation timestamp")
    updated_at: datetime = Field(..., description="Last update timestamp")
    owner: Dict[str, str] = Field(default_factory=dict, description="Graph owner information")
File diff suppressed because it is too large
@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "morphik"
version = "0.1.0"
version = "0.1.2"
authors = [
    { name = "Morphik", email = "founders@morphik.ai" },
]