This commit is contained in:
Adityavardhan Agrawal 2024-12-04 20:26:14 -05:00
parent 46a7c45b4e
commit 251e38828a
10 changed files with 129 additions and 282 deletions

View File

@ -1,7 +0,0 @@
JWT_SECRET_KEY=test-secret
MONGODB_URI=mongodb://localhost:27017/databridge_test
DATABRIDGE_TEST_URI=databridge://test_dev:your_test_token@localhost:8000
DATABRIDGE_HOST=localhost:8000
OPENAI_API_KEY=your_test_key
AWS_ACCESS_KEY=test
AWS_SECRET_ACCESS_KEY=test

View File

@ -1,13 +1,12 @@
import json
from datetime import datetime, UTC
from typing import List, Optional, Union, Dict, Set
from typing import List, Union
from fastapi import (
FastAPI,
Form,
HTTPException,
Depends,
Header,
APIRouter,
UploadFile
)
from fastapi.middleware.cors import CORSMiddleware
@ -28,7 +27,6 @@ from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from core.storage.s3_storage import S3Storage
from core.parser.unstructured_parser import UnstructuredAPIParser
from core.embedding_model.openai_embedding_model import OpenAIEmbeddingModel
from core.services.uri_service import get_uri_service
# Initialize FastAPI app
@ -98,14 +96,14 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext:
status_code=401,
detail="Invalid authorization header"
)
token = authorization[7:] # Remove "Bearer "
payload = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM]
)
if datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC):
raise HTTPException(status_code=401, detail="Token expired")
@ -143,7 +141,7 @@ async def ingest_file(
try:
metadata_dict = json.loads(metadata)
doc = await document_service.ingest_file(file, metadata_dict, auth)
return doc # Should just send a success response, not sure why we're sending a document #TODO: discuss with bhau
return doc # Should just send a success response, not sure why we're sending a document #TODO: discuss with bhau
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except json.JSONDecodeError:
@ -160,8 +158,6 @@ async def query_documents(
"""Query documents with specified return type."""
try:
return await document_service.query(request, auth)
# except AttributeError as e:
# raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Query failed: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
@ -193,64 +189,6 @@ async def get_document(
raise HTTPException(status_code=404, detail="Document not found")
return doc
except HTTPException as e:
raise e # Return the HTTPException as is
raise e # Return the HTTPException as is
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
auth_router = APIRouter(prefix="/auth", tags=["auth"])
@auth_router.post("/developer-token")
async def create_developer_token(
dev_id: str,
app_id: Optional[str] = None,
expiry_days: int = 30,
permissions: Optional[Set[str]] = None,
auth: AuthContext = Depends(verify_token)
) -> Dict[str, str]:
"""Create a developer access URI."""
# Verify requesting user has admin permissions
if "admin" not in auth.permissions:
raise HTTPException(
status_code=403,
detail="Admin permissions required"
)
uri_service = get_uri_service()
uri = uri_service.create_developer_uri(
dev_id=dev_id,
app_id=app_id,
expiry_days=expiry_days,
permissions=permissions
)
return {"uri": uri}
@auth_router.post("/user-token")
async def create_user_token(
user_id: str,
expiry_days: int = 30,
permissions: Optional[Set[str]] = None,
auth: AuthContext = Depends(verify_token)
) -> Dict[str, str]:
"""Create a user access URI."""
# Verify requesting user has admin permissions
if "admin" not in auth.permissions:
raise HTTPException(
status_code=403,
detail="Admin permissions required"
)
uri_service = get_uri_service()
uri = uri_service.create_user_uri(
user_id=user_id,
expiry_days=expiry_days,
permissions=permissions
)
return {"uri": uri}
# Add to your main FastAPI app
app.include_router(auth_router)

View File

@ -1,4 +1,4 @@
from typing import Optional, Dict, Any
from typing import Dict, Any
from pydantic import Field
from pydantic_settings import BaseSettings
from functools import lru_cache
@ -6,45 +6,48 @@ from functools import lru_cache
class Settings(BaseSettings):
"""DataBridge configuration settings."""
# MongoDB settings
MONGODB_URI: str = Field(..., env="MONGODB_URI")
DATABRIDGE_DB: str = Field(..., env="DATABRIDGE_DB")
# Collection names
DOCUMENTS_COLLECTION: str = Field("documents", env="DOCUMENTS_COLLECTION")
CHUNKS_COLLECTION: str = Field("document_chunks", env="CHUNKS_COLLECTION")
# Vector search settings
VECTOR_INDEX_NAME: str = Field("vector_index", env="VECTOR_INDEX_NAME")
# API Keys
OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY")
UNSTRUCTURED_API_KEY: str = Field(..., env="UNSTRUCTURED_API_KEY")
# Optional API keys for alternative models
ANTHROPIC_API_KEY: Optional[str] = Field(None, env="ANTHROPIC_API_KEY")
COHERE_API_KEY: Optional[str] = Field(None, env="COHERE_API_KEY")
VOYAGE_API_KEY: Optional[str] = Field(None, env="VOYAGE_API_KEY")
ANTHROPIC_API_KEY: str | None = Field(None, env="ANTHROPIC_API_KEY")
COHERE_API_KEY: str | None = Field(None, env="COHERE_API_KEY")
VOYAGE_API_KEY: str | None = Field(None, env="VOYAGE_API_KEY")
# Model settings
EMBEDDING_MODEL: str = Field("text-embedding-3-small", env="EMBEDDING_MODEL")
EMBEDDING_MODEL: str = Field(
"text-embedding-3-small",
env="EMBEDDING_MODEL"
)
# Document processing settings
CHUNK_SIZE: int = Field(1000, env="CHUNK_SIZE")
CHUNK_OVERLAP: int = Field(200, env="CHUNK_OVERLAP")
DEFAULT_K: int = Field(4, env="DEFAULT_K")
# Storage settings
AWS_ACCESS_KEY: str = Field(..., env="AWS_ACCESS_KEY")
AWS_SECRET_ACCESS_KEY: str = Field(..., env="AWS_SECRET_ACCESS_KEY")
AWS_REGION: str = Field("us-east-2", env="AWS_REGION")
S3_BUCKET: str = Field("databridge-storage", env="S3_BUCKET")
# Auth settings
JWT_SECRET_KEY: str = Field(..., env="JWT_SECRET_KEY")
JWT_ALGORITHM: str = Field("HS256", env="JWT_ALGORITHM")
# Server settings
HOST: str = Field("localhost", env="HOST")
PORT: int = Field(8000, env="PORT")
@ -115,4 +118,4 @@ class Settings(BaseSettings):
@lru_cache()
def get_settings() -> Settings:
"""Get cached settings instance."""
return Settings()
return Settings()

View File

@ -1,6 +1,6 @@
from typing import List, Optional, Dict, Any
import logging
from datetime import UTC, datetime
import logging
from typing import Dict, List, Optional, Any
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import ReturnDocument
@ -37,7 +37,7 @@ class MongoDatabase(BaseDatabase):
await self.collection.create_index("access_control.writers")
await self.collection.create_index("access_control.admins")
await self.collection.create_index("system_metadata.created_at")
logger.info("MongoDB indexes created successfully")
return True
except PyMongoError as e:
@ -48,14 +48,14 @@ class MongoDatabase(BaseDatabase):
"""Store document metadata."""
try:
doc_dict = document.model_dump()
# Ensure system metadata
doc_dict["system_metadata"]["created_at"] = datetime.now(UTC)
doc_dict["system_metadata"]["updated_at"] = datetime.now(UTC)
result = await self.collection.insert_one(doc_dict)
return bool(result.inserted_id)
except PyMongoError as e:
logger.error(f"Error storing document metadata: {str(e)}")
return False
@ -65,7 +65,7 @@ class MongoDatabase(BaseDatabase):
try:
# Build access filter
access_filter = self._build_access_filter(auth)
# Query document
query = {
"$and": [
@ -78,11 +78,10 @@ class MongoDatabase(BaseDatabase):
doc_dict = await self.collection.find_one(query)
logger.debug(f"Found document: {doc_dict}")
return Document(**doc_dict) if doc_dict else None
except PyMongoError as e:
logger.error(f"Error retrieving document metadata: {str(e)}")
raise e
# return None
async def get_documents(
self,
@ -100,13 +99,13 @@ class MongoDatabase(BaseDatabase):
# Execute paginated query
cursor = self.collection.find(query).skip(skip).limit(limit)
documents = []
async for doc_dict in cursor:
documents.append(Document(**doc_dict))
return documents
except PyMongoError as e:
logger.error(f"Error listing documents: {str(e)}")
return []
@ -132,9 +131,9 @@ class MongoDatabase(BaseDatabase):
{"$set": updates},
return_document=ReturnDocument.AFTER
)
return bool(result)
except PyMongoError as e:
logger.error(f"Error updating document metadata: {str(e)}")
return False
@ -148,7 +147,7 @@ class MongoDatabase(BaseDatabase):
result = await self.collection.delete_one({"external_id": document_id})
return bool(result.deleted_count)
except PyMongoError as e:
logger.error(f"Error deleting document: {str(e)}")
return False
@ -167,13 +166,13 @@ class MongoDatabase(BaseDatabase):
# Get matching document IDs
cursor = self.collection.find(query, {"external_id": 1})
document_ids = []
async for doc in cursor:
document_ids.append(doc["external_id"])
return document_ids
except PyMongoError as e:
logger.error(f"Error finding documents: {str(e)}")
return []
@ -191,10 +190,11 @@ class MongoDatabase(BaseDatabase):
return False
access_control = doc.get("access_control", {})
# Check owner access
owner = doc.get("owner", {})
if (owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id):
if (owner.get("type") == auth.entity_type and
owner.get("id") == auth.entity_id):
return True
# Check permission-specific access
@ -203,13 +203,13 @@ class MongoDatabase(BaseDatabase):
"write": "writers",
"admin": "admins"
}
permission_set = permission_map.get(required_permission)
if not permission_set:
return False
return auth.entity_id in access_control.get(permission_set, set())
except PyMongoError as e:
logger.error(f"Error checking document access: {str(e)}")
return False
@ -232,7 +232,7 @@ class MongoDatabase(BaseDatabase):
)
return base_filter
def _build_metadata_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]:
"""Build MongoDB filter for metadata."""
if not filters:

View File

@ -1,8 +1,7 @@
import base64
from collections import defaultdict
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Union
import logging
from fastapi import UploadFile
import base64
from core.database.base_database import BaseDatabase
from core.embedding_model.base_embedding_model import BaseEmbeddingModel

View File

@ -1,24 +1,23 @@
# core/auth/uri_service.py
from typing import Optional, Set
from datetime import datetime, timedelta, UTC
import jwt
from functools import lru_cache
from ..models.auth import EntityType, AuthContext
from ..config import Settings
# Currently unused. Will be used for uri generation.
class URIService:
"""Service for creating and validating DataBridge URIs with authentication tokens"""
"""Service for creating and validating DataBridge URIs with authentication tokens."""
def __init__(self, settings: Settings):
self.secret_key = settings.JWT_SECRET_KEY
self.algorithm = settings.JWT_ALGORITHM
self.host = settings.HOST
self.port = settings.PORT
def create_developer_uri(
self,
self,
dev_id: str,
app_id: Optional[str] = None,
expiry_days: int = 30,
@ -26,7 +25,7 @@ class URIService:
) -> str:
"""
Create URI for developer access.
Args:
dev_id: Developer ID
app_id: Optional application ID for app-specific access
@ -39,21 +38,21 @@ class URIService:
"permissions": list(permissions or {"read"}),
"exp": datetime.now(UTC) + timedelta(days=expiry_days)
}
if app_id:
payload["app_id"] = app_id
token = jwt.encode(
payload,
self.secret_key,
algorithm=self.algorithm
)
# Construct the URI
owner_id = f"{dev_id}.{app_id}" if app_id else dev_id
host_port = f"{self.host}:{self.port}" if self.port != 80 else self.host
return f"databridge://{owner_id}:{token}@{host_port}"
def create_user_uri(
self,
user_id: str,
@ -62,9 +61,9 @@ class URIService:
) -> str:
"""
Create URI for end-user access.
Args:
user_id: User ID
user_id: User ID
expiry_days: Token validity in days
permissions: Set of permissions to grant
"""
@ -78,46 +77,37 @@ class URIService:
self.secret_key,
algorithm=self.algorithm
)
host_port = f"{self.host}:{self.port}" if self.port != 80 else self.host
return f"databridge://{user_id}:{token}@{host_port}"
def validate_uri(self, uri: str) -> Optional[AuthContext]:
"""
Validate a DataBridge URI and return auth context if valid.
Args:
uri: DataBridge URI to validate
Returns:
AuthContext if valid, None otherwise
"""
try:
# Extract token from URI
token = uri.split("://")[1].split("@")[0].split(":")[1]
# Decode and validate token
payload = jwt.decode(
token,
self.secret_key,
algorithms=[self.algorithm]
)
return AuthContext(
entity_type=EntityType(payload["type"]),
entity_id=payload["entity_id"],
app_id=payload.get("app_id"),
permissions=set(payload.get("permissions", ["read"]))
)
except (jwt.InvalidTokenError, IndexError, ValueError):
return None
@lru_cache()
def get_uri_service(settings: Settings = None) -> URIService:
"""Get cached URIService instance."""
if settings is None:
from ..config import get_settings
settings = get_settings()
return URIService(settings)

View File

@ -1,37 +0,0 @@
from pathlib import Path
import sys
import pytest
from typing import Generator
import os
from dotenv import load_dotenv
root_dir = Path(__file__).parent.parent.parent
sdk_path = str(root_dir / "sdks" / "python")
core_path = str(root_dir)
sys.path.extend([sdk_path, core_path])
from core.config import get_settings
from databridge import DataBridge
# Load test environment variables
load_dotenv(".env.test")
@pytest.fixture(scope="session")
def settings():
"""Get test settings"""
return get_settings()
@pytest.fixture
async def db() -> Generator[DataBridge, None, None]:
"""DataBridge client fixture"""
uri = os.getenv("DATABRIDGE_TEST_URI")
if not uri:
raise ValueError("DATABRIDGE_TEST_URI not set")
client = DataBridge(uri)
try:
yield client
finally:
await client.close()

View File

@ -35,23 +35,16 @@ def setup_test_environment(event_loop):
"""Setup test environment and create test files"""
# Create test data directory if it doesn't exist
TEST_DATA_DIR.mkdir(exist_ok=True)
# Create a test text file
text_file = TEST_DATA_DIR / "test.txt"
if not text_file.exists():
text_file.write_text("This is a test document for DataBridge testing.")
# Create a small test PDF if it doesn't exist
pdf_file = TEST_DATA_DIR / "test.pdf"
if not pdf_file.exists():
# Create a minimal PDF for testing
try:
from reportlab.pdfgen import canvas
c = canvas.Canvas(str(pdf_file))
c.drawString(100, 750, "Test PDF Document")
c.save()
except ImportError:
pytest.skip("reportlab not installed, skipping PDF tests")
pytest.skip("PDF file not available, skipping PDF tests")
def create_test_token(
@ -64,7 +57,7 @@ def create_test_token(
"""Create a test JWT token"""
if not permissions:
permissions = ["read", "write", "admin"]
payload = {
"type": entity_type,
"entity_id": entity_id,
@ -98,14 +91,20 @@ async def test_app(event_loop: asyncio.AbstractEventLoop) -> FastAPI:
@pytest.fixture
async def client(test_app: FastAPI, event_loop: asyncio.AbstractEventLoop) -> AsyncGenerator[AsyncClient, None]:
async def client(
test_app: FastAPI,
event_loop: asyncio.AbstractEventLoop
) -> AsyncGenerator[AsyncClient, None]:
"""Create async test client"""
async with AsyncClient(app=test_app, base_url="http://test") as client:
yield client
@pytest.mark.asyncio
async def test_ingest_text_document(client: AsyncClient, content: str = "Test content for document ingestion"):
async def test_ingest_text_document(
client: AsyncClient,
content: str = "Test content for document ingestion"
):
"""Test ingesting a text document"""
headers = create_auth_header()
@ -132,14 +131,14 @@ async def test_ingest_pdf(client: AsyncClient):
"""Test ingesting a pdf"""
headers = create_auth_header()
pdf_path = TEST_DATA_DIR / "test.pdf"
if not pdf_path.exists():
pytest.skip("Test PDF file not available")
content_type, _ = mimetypes.guess_type(pdf_path)
if not content_type:
content_type = "application/octet-stream"
with open(pdf_path, "rb") as f:
response = await client.post(
"/ingest/file",
@ -147,13 +146,13 @@ async def test_ingest_pdf(client: AsyncClient):
data={"metadata": json.dumps({"test": True, "type": "pdf"})},
headers=headers
)
assert response.status_code == 200
data = response.json()
assert "external_id" in data
assert data["content_type"] == "application/pdf"
assert "storage_info" in data
return data["external_id"]
@ -161,7 +160,7 @@ async def test_ingest_pdf(client: AsyncClient):
async def test_ingest_invalid_text_request(client: AsyncClient):
"""Test ingestion with invalid text request missing required content field"""
headers = create_auth_header()
response = await client.post(
"/ingest/text",
json={
@ -172,11 +171,11 @@ async def test_ingest_invalid_text_request(client: AsyncClient):
assert response.status_code == 422 # Validation error
@pytest.mark.asyncio
@pytest.mark.asyncio
async def test_ingest_invalid_file_request(client: AsyncClient):
"""Test ingestion with invalid file request missing file"""
headers = create_auth_header()
response = await client.post(
"/ingest/file",
files={}, # Missing file
@ -190,14 +189,14 @@ async def test_ingest_invalid_file_request(client: AsyncClient):
async def test_ingest_invalid_metadata(client: AsyncClient):
"""Test ingestion with invalid metadata JSON"""
headers = create_auth_header()
pdf_path = TEST_DATA_DIR / "test.pdf"
if pdf_path.exists():
files = {
"file": ("test.pdf", open(pdf_path, "rb"), "application/pdf")
}
response = await client.post(
"/ingest/file",
"/ingest/file",
files=files,
data={"metadata": "invalid json"},
headers=headers
@ -209,7 +208,7 @@ async def test_ingest_invalid_metadata(client: AsyncClient):
async def test_ingest_oversized_content(client: AsyncClient):
"""Test ingestion with oversized content"""
headers = create_auth_header()
large_content = "x" * (10 * 1024 * 1024) # 10MB
response = await client.post(
"/ingest/text",
@ -229,7 +228,7 @@ async def test_auth_missing_header(client: AsyncClient):
assert response.status_code == 401
@pytest.mark.asyncio
@pytest.mark.asyncio
async def test_auth_invalid_token(client: AsyncClient):
"""Test authentication with invalid token"""
headers = {"Authorization": "Bearer invalid_token"}
@ -264,12 +263,15 @@ async def test_auth_insufficient_permissions(client: AsyncClient):
async def test_query_chunks(client: AsyncClient):
"""Test querying document chunks"""
# First ingest a document to query
doc_id = await test_ingest_text_document(client, content="The quick brown fox jumps over the lazy dog")
doc_id = await test_ingest_text_document(
client,
content="The quick brown fox jumps over the lazy dog"
)
headers = create_auth_header()
# Sleep to allow time for document to be indexed
await asyncio.sleep(1)
response = await client.post(
"/query",
json={
@ -280,7 +282,7 @@ async def test_query_chunks(client: AsyncClient):
headers=headers
)
logger.info(f"Query response: {response.json()}")
assert response.status_code == 200
results = list(response.json())
logger.info(f"Query results: {results}")
@ -293,8 +295,16 @@ async def test_query_chunks(client: AsyncClient):
async def test_query_documents(client: AsyncClient):
"""Test querying for full documents"""
# First ingest a document to query
doc_id = await test_ingest_text_document(client, content="Headaches can significantly impact daily life and wellbeing. Common triggers include stress, dehydration, and poor sleep habits. While over-the-counter pain relievers may provide temporary relief, it's important to identify and address the root causes. Maintaining good health through proper nutrition, regular exercise, and stress management can help prevent chronic headaches.")
content = (
"Headaches can significantly impact daily life and wellbeing. "
"Common triggers include stress, dehydration, and poor sleep habits. "
"While over-the-counter pain relievers may provide temporary relief, "
"it's important to identify and address the root causes. "
"Maintaining good health through proper nutrition, regular exercise, "
"and stress management can help prevent chronic headaches."
)
doc_id = await test_ingest_text_document(client, content=content)
headers = create_auth_header()
response = await client.post(
"/query",
@ -305,7 +315,7 @@ async def test_query_documents(client: AsyncClient):
},
headers=headers
)
assert response.status_code == 200
results = list(response.json())
assert len(results) > 0
@ -320,10 +330,10 @@ async def test_list_documents(client: AsyncClient):
# First ingest some documents
doc_id1 = await test_ingest_text_document(client)
doc_id2 = await test_ingest_text_document(client)
headers = create_auth_header()
response = await client.get("/documents", headers=headers)
assert response.status_code == 200
docs = response.json()
assert len(docs) >= 2
@ -337,10 +347,10 @@ async def test_get_document(client: AsyncClient):
"""Test getting a specific document"""
# First ingest a document
doc_id = await test_ingest_text_document(client)
headers = create_auth_header()
response = await client.get(f"/documents/{doc_id}", headers=headers)
assert response.status_code == 200
doc = response.json()
assert doc["external_id"] == doc_id

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, List, Optional, Tuple
from core.models.auth import AuthContext
from core.models.documents import DocumentChunk

View File

@ -1,12 +1,10 @@
import json
from typing import List, Dict, Any, Optional
from typing import List, Optional
import logging
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo.errors import PyMongoError
from .base_vector_store import BaseVectorStore
from core.models.documents import DocumentChunk
from core.models.auth import AuthContext, EntityType
logger = logging.getLogger(__name__)
@ -33,10 +31,10 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
# Create basic indexes
await self.collection.create_index("document_id")
await self.collection.create_index("chunk_number")
# Note: Vector search index must be created via Atlas UI or API
# as it requires specific configuration
logger.info("MongoDB vector store indexes initialized")
return True
except PyMongoError as e:
@ -55,13 +53,18 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
doc = chunk.model_dump()
# Ensure we have required fields
if not doc.get('embedding'):
logger.error(f"Missing embedding for chunk {chunk.document_id}-{chunk.chunk_number}")
logger.error(
f"Missing embedding for chunk "
f"{chunk.document_id}-{chunk.chunk_number}"
)
continue
documents.append(doc)
if documents:
# Use ordered=False to continue even if some inserts fail
result = await self.collection.insert_many(documents, ordered=False)
result = await self.collection.insert_many(
documents, ordered=False
)
return len(result.inserted_ids) > 0, result
return False, None
@ -77,7 +80,10 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
) -> List[DocumentChunk]:
"""Find similar chunks using MongoDB Atlas Vector Search."""
try:
logger.debug(f"Searching in database {self.db.name} collection {self.collection.name}")
logger.debug(
f"Searching in database {self.db.name} "
f"collection {self.collection.name}"
)
logger.debug(f"Query vector looks like: {query_embedding}")
logger.debug(f"Doc IDs: {doc_ids}")
logger.debug(f"K is: {k}")
@ -90,7 +96,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
"index": self.index_name,
"path": "embedding",
"queryVector": query_embedding,
"numCandidates": k*40, # Get more candidates for better results
"numCandidates": k*40, # Get more candidates
"limit": k,
"filter": {"document_id": {"$in": doc_ids}} if doc_ids else {}
}
@ -110,7 +116,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
# Execute search
cursor = self.collection.aggregate(pipeline)
chunks = []
async for result in cursor:
chunk = DocumentChunk(
document_id=result["document_id"],
@ -128,58 +134,3 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
logger.error(f"MongoDB error: {e._message}")
logger.error(f"Error querying similar chunks: {str(e)}")
raise e
def _build_access_filter(self, auth: AuthContext) -> Dict[str, Any]:
"""Build MongoDB filter for access control."""
base_filter = {
"$or": [
{"owner.id": auth.entity_id},
{"access_control.readers": auth.entity_id},
{"access_control.writers": auth.entity_id},
{"access_control.admins": auth.entity_id}
]
}
if auth.entity_type == EntityType.DEVELOPER and auth.app_id:
# Add app-specific access for developers
base_filter["$or"].append(
{"access_control.app_access": auth.app_id}
)
return base_filter
def _build_metadata_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]:
"""Build MongoDB filter for metadata fields."""
if not filters:
return {}
return filters
metadata_filter = {}
for key, value in filters.items():
metadata_key = f"metadata.{key}"
if isinstance(value, (str, int, float, bool)):
metadata_filter[metadata_key] = value
elif isinstance(value, list):
metadata_filter[metadata_key] = {"$in": value}
elif isinstance(value, dict):
valid_ops = {
"gt": "$gt",
"gte": "$gte",
"lt": "$lt",
"lte": "$lte",
"ne": "$ne"
}
mongo_ops = {}
for op, val in value.items():
if op not in valid_ops:
logger.warning(f"Skipping invalid operator: {op}")
continue
mongo_ops[valid_ops[op]] = val
if mongo_ops:
metadata_filter[metadata_key] = mongo_ops
else:
logger.warning(f"Skipping unsupported filter value type for key {key}")
return metadata_filter