From 251e38828a22c8548747e3ccaa62f59ea0058163 Mon Sep 17 00:00:00 2001 From: Adityavardhan Agrawal Date: Wed, 4 Dec 2024 20:26:14 -0500 Subject: [PATCH] clean up --- core/.env.test | 7 --- core/api.py | 72 ++-------------------- core/config.py | 35 ++++++----- core/database/mongo_database.py | 46 +++++++------- core/services/document_service.py | 5 +- core/services/uri_service.py | 46 ++++++-------- core/tests/conftest.py | 37 ----------- core/tests/integration/test_api.py | 80 +++++++++++++----------- core/vector_store/base_vector_store.py | 2 +- core/vector_store/mongo_vector_store.py | 81 +++++-------------------- 10 files changed, 129 insertions(+), 282 deletions(-) delete mode 100644 core/.env.test delete mode 100644 core/tests/conftest.py diff --git a/core/.env.test b/core/.env.test deleted file mode 100644 index ae63240..0000000 --- a/core/.env.test +++ /dev/null @@ -1,7 +0,0 @@ -JWT_SECRET_KEY=test-secret -MONGODB_URI=mongodb://localhost:27017/databridge_test -DATABRIDGE_TEST_URI=databridge://test_dev:your_test_token@localhost:8000 -DATABRIDGE_HOST=localhost:8000 -OPENAI_API_KEY=your_test_key -AWS_ACCESS_KEY=test -AWS_SECRET_ACCESS_KEY=test \ No newline at end of file diff --git a/core/api.py b/core/api.py index 783ebd9..607a6eb 100644 --- a/core/api.py +++ b/core/api.py @@ -1,13 +1,12 @@ import json from datetime import datetime, UTC -from typing import List, Optional, Union, Dict, Set +from typing import List, Union from fastapi import ( FastAPI, Form, HTTPException, Depends, Header, - APIRouter, UploadFile ) from fastapi.middleware.cors import CORSMiddleware @@ -28,7 +27,6 @@ from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore from core.storage.s3_storage import S3Storage from core.parser.unstructured_parser import UnstructuredAPIParser from core.embedding_model.openai_embedding_model import OpenAIEmbeddingModel -from core.services.uri_service import get_uri_service # Initialize FastAPI app @@ -98,14 +96,14 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext: status_code=401, detail="Invalid authorization header" ) - + token = authorization[7:] # Remove "Bearer " payload = jwt.decode( token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM] ) - + if datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC): raise HTTPException(status_code=401, detail="Token expired") @@ -143,7 +141,7 @@ async def ingest_file( try: metadata_dict = json.loads(metadata) doc = await document_service.ingest_file(file, metadata_dict, auth) - return doc # Should just send a success response, not sure why we're sending a document #TODO: discuss with bhau + return doc # Should just send a success response, not sure why we're sending a document #TODO: discuss with bhau except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) except json.JSONDecodeError: @@ -160,8 +158,6 @@ async def query_documents( """Query documents with specified return type.""" try: return await document_service.query(request, auth) - # except AttributeError as e: - # raise HTTPException(status_code=400, detail=str(e)) except Exception as e: logger.error(f"Query failed: {str(e)}") raise HTTPException(status_code=400, detail=str(e)) @@ -193,64 +189,6 @@ async def get_document( raise HTTPException(status_code=404, detail="Document not found") return doc except HTTPException as e: - raise e # Return the HTTPException as is + raise e # Return the HTTPException as is except Exception as e: raise HTTPException(status_code=400, detail=str(e)) - - -auth_router = APIRouter(prefix="/auth", tags=["auth"]) - - -@auth_router.post("/developer-token") -async def create_developer_token( - dev_id: str, - app_id: Optional[str] = None, - expiry_days: int = 30, - permissions: Optional[Set[str]] = None, - auth: AuthContext = Depends(verify_token) -) -> Dict[str, str]: - """Create a developer access URI.""" - # Verify requesting user has admin permissions - if "admin" not in auth.permissions: - raise HTTPException( - status_code=403, - detail="Admin permissions required" - ) - - uri_service = get_uri_service() - uri = uri_service.create_developer_uri( - dev_id=dev_id, - app_id=app_id, - expiry_days=expiry_days, - permissions=permissions - ) - - return {"uri": uri} - - -@auth_router.post("/user-token") -async def create_user_token( - user_id: str, - expiry_days: int = 30, - permissions: Optional[Set[str]] = None, - auth: AuthContext = Depends(verify_token) -) -> Dict[str, str]: - """Create a user access URI.""" - # Verify requesting user has admin permissions - if "admin" not in auth.permissions: - raise HTTPException( - status_code=403, - detail="Admin permissions required" - ) - - uri_service = get_uri_service() - uri = uri_service.create_user_uri( - user_id=user_id, - expiry_days=expiry_days, - permissions=permissions - ) - - return {"uri": uri} - -# Add to your main FastAPI app -app.include_router(auth_router) diff --git a/core/config.py b/core/config.py index bc8c571..99e68de 100644 --- a/core/config.py +++ b/core/config.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict, Any +from typing import Dict, Any from pydantic import Field from pydantic_settings import BaseSettings from functools import lru_cache @@ -6,45 +6,48 @@ from functools import lru_cache class Settings(BaseSettings): """DataBridge configuration settings.""" - + # MongoDB settings MONGODB_URI: str = Field(..., env="MONGODB_URI") DATABRIDGE_DB: str = Field(..., env="DATABRIDGE_DB") - + # Collection names DOCUMENTS_COLLECTION: str = Field("documents", env="DOCUMENTS_COLLECTION") CHUNKS_COLLECTION: str = Field("document_chunks", env="CHUNKS_COLLECTION") - + # Vector search settings VECTOR_INDEX_NAME: str = Field("vector_index", env="VECTOR_INDEX_NAME") - + # API Keys OPENAI_API_KEY: str = Field(..., env="OPENAI_API_KEY") UNSTRUCTURED_API_KEY: str = Field(..., env="UNSTRUCTURED_API_KEY") - + # Optional API keys for alternative models - ANTHROPIC_API_KEY: Optional[str] = Field(None, env="ANTHROPIC_API_KEY") - COHERE_API_KEY: Optional[str] = Field(None, env="COHERE_API_KEY") - VOYAGE_API_KEY: Optional[str] = Field(None, env="VOYAGE_API_KEY") - + ANTHROPIC_API_KEY: str | None = Field(None, env="ANTHROPIC_API_KEY") + COHERE_API_KEY: str | None = Field(None, env="COHERE_API_KEY") + VOYAGE_API_KEY: str | None = Field(None, env="VOYAGE_API_KEY") + # Model settings - EMBEDDING_MODEL: str = Field("text-embedding-3-small", env="EMBEDDING_MODEL") - + EMBEDDING_MODEL: str = Field( + "text-embedding-3-small", + env="EMBEDDING_MODEL" + ) + # Document processing settings CHUNK_SIZE: int = Field(1000, env="CHUNK_SIZE") CHUNK_OVERLAP: int = Field(200, env="CHUNK_OVERLAP") DEFAULT_K: int = Field(4, env="DEFAULT_K") - + # Storage settings AWS_ACCESS_KEY: str = Field(..., env="AWS_ACCESS_KEY") AWS_SECRET_ACCESS_KEY: str = Field(..., env="AWS_SECRET_ACCESS_KEY") AWS_REGION: str = Field("us-east-2", env="AWS_REGION") S3_BUCKET: str = Field("databridge-storage", env="S3_BUCKET") - + # Auth settings JWT_SECRET_KEY: str = Field(..., env="JWT_SECRET_KEY") JWT_ALGORITHM: str = Field("HS256", env="JWT_ALGORITHM") - + # Server settings HOST: str = Field("localhost", env="HOST") PORT: int = Field(8000, env="PORT") @@ -115,4 +118,4 @@ class Settings(BaseSettings): @lru_cache() def get_settings() -> Settings: """Get cached settings instance.""" - return Settings() \ No newline at end of file + return Settings() diff --git a/core/database/mongo_database.py b/core/database/mongo_database.py index ae486f4..f581d86 100644 --- a/core/database/mongo_database.py +++ b/core/database/mongo_database.py @@ -1,6 +1,6 @@ -from typing import List, Optional, Dict, Any -import logging from datetime import UTC, datetime +import logging +from typing import Dict, List, Optional, Any from motor.motor_asyncio import AsyncIOMotorClient from pymongo import ReturnDocument @@ -37,7 +37,7 @@ class MongoDatabase(BaseDatabase): await self.collection.create_index("access_control.writers") await self.collection.create_index("access_control.admins") await self.collection.create_index("system_metadata.created_at") - + logger.info("MongoDB indexes created successfully") return True except PyMongoError as e: @@ -48,14 +48,14 @@ class MongoDatabase(BaseDatabase): """Store document metadata.""" try: doc_dict = document.model_dump() - + # Ensure system metadata doc_dict["system_metadata"]["created_at"] = datetime.now(UTC) doc_dict["system_metadata"]["updated_at"] = datetime.now(UTC) result = await self.collection.insert_one(doc_dict) return bool(result.inserted_id) - + except PyMongoError as e: logger.error(f"Error storing document metadata: {str(e)}") return False @@ -65,7 +65,7 @@ class MongoDatabase(BaseDatabase): try: # Build access filter access_filter = self._build_access_filter(auth) - + # Query document query = { "$and": [ @@ -78,11 +78,10 @@ class MongoDatabase(BaseDatabase): doc_dict = await self.collection.find_one(query) logger.debug(f"Found document: {doc_dict}") return Document(**doc_dict) if doc_dict else None - + except PyMongoError as e: logger.error(f"Error retrieving document metadata: {str(e)}") raise e - # return None async def get_documents( self, @@ -100,13 +99,13 @@ class MongoDatabase(BaseDatabase): # Execute paginated query cursor = self.collection.find(query).skip(skip).limit(limit) - + documents = [] async for doc_dict in cursor: documents.append(Document(**doc_dict)) - + return documents - + except PyMongoError as e: logger.error(f"Error listing documents: {str(e)}") return [] @@ -132,9 +131,9 @@ class MongoDatabase(BaseDatabase): {"$set": updates}, return_document=ReturnDocument.AFTER ) - + return bool(result) - + except PyMongoError as e: logger.error(f"Error updating document metadata: {str(e)}") return False @@ -148,7 +147,7 @@ class MongoDatabase(BaseDatabase): result = await self.collection.delete_one({"external_id": document_id}) return bool(result.deleted_count) - + except PyMongoError as e: logger.error(f"Error deleting document: {str(e)}") return False @@ -167,13 +166,13 @@ class MongoDatabase(BaseDatabase): # Get matching document IDs cursor = self.collection.find(query, {"external_id": 1}) - + document_ids = [] async for doc in cursor: document_ids.append(doc["external_id"]) - + return document_ids - + except PyMongoError as e: logger.error(f"Error finding documents: {str(e)}") return [] @@ -191,10 +190,11 @@ class MongoDatabase(BaseDatabase): return False access_control = doc.get("access_control", {}) - + # Check owner access owner = doc.get("owner", {}) - if (owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id): + if (owner.get("type") == auth.entity_type and + owner.get("id") == auth.entity_id): return True # Check permission-specific access @@ -203,13 +203,13 @@ class MongoDatabase(BaseDatabase): "write": "writers", "admin": "admins" } - + permission_set = permission_map.get(required_permission) if not permission_set: return False - + return auth.entity_id in access_control.get(permission_set, set()) - + except PyMongoError as e: logger.error(f"Error checking document access: {str(e)}") return False @@ -232,7 +232,7 @@ class MongoDatabase(BaseDatabase): ) return base_filter - + def _build_metadata_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]: """Build MongoDB filter for metadata.""" if not filters: diff --git a/core/services/document_service.py b/core/services/document_service.py index 1844b3f..8ee84e5 100644 --- a/core/services/document_service.py +++ b/core/services/document_service.py @@ -1,8 +1,7 @@ -import base64 -from collections import defaultdict -from typing import Any, Dict, List, Union, Optional +from typing import Any, Dict, List, Union import logging from fastapi import UploadFile +import base64 from core.database.base_database import BaseDatabase from core.embedding_model.base_embedding_model import BaseEmbeddingModel diff --git a/core/services/uri_service.py b/core/services/uri_service.py index 5b423ab..a487d03 100644 --- a/core/services/uri_service.py +++ b/core/services/uri_service.py @@ -1,24 +1,23 @@ -# core/auth/uri_service.py from typing import Optional, Set from datetime import datetime, timedelta, UTC import jwt -from functools import lru_cache from ..models.auth import EntityType, AuthContext from ..config import Settings +# Currently unused. Will be used for uri generation. class URIService: - """Service for creating and validating DataBridge URIs with authentication tokens""" - + """Service for creating and validating DataBridge URIs with authentication tokens.""" + def __init__(self, settings: Settings): self.secret_key = settings.JWT_SECRET_KEY self.algorithm = settings.JWT_ALGORITHM self.host = settings.HOST self.port = settings.PORT - + def create_developer_uri( - self, + self, dev_id: str, app_id: Optional[str] = None, expiry_days: int = 30, @@ -26,7 +25,7 @@ class URIService: ) -> str: """ Create URI for developer access. - + Args: dev_id: Developer ID app_id: Optional application ID for app-specific access @@ -39,21 +38,21 @@ class URIService: "permissions": list(permissions or {"read"}), "exp": datetime.now(UTC) + timedelta(days=expiry_days) } - + if app_id: payload["app_id"] = app_id - + token = jwt.encode( payload, self.secret_key, algorithm=self.algorithm ) - + # Construct the URI owner_id = f"{dev_id}.{app_id}" if app_id else dev_id host_port = f"{self.host}:{self.port}" if self.port != 80 else self.host return f"databridge://{owner_id}:{token}@{host_port}" - + def create_user_uri( self, user_id: str, @@ -62,9 +61,9 @@ class URIService: ) -> str: """ Create URI for end-user access. - + Args: - user_id: User ID + user_id: User ID expiry_days: Token validity in days permissions: Set of permissions to grant """ @@ -78,46 +77,37 @@ class URIService: self.secret_key, algorithm=self.algorithm ) - + host_port = f"{self.host}:{self.port}" if self.port != 80 else self.host return f"databridge://{user_id}:{token}@{host_port}" def validate_uri(self, uri: str) -> Optional[AuthContext]: """ Validate a DataBridge URI and return auth context if valid. - + Args: uri: DataBridge URI to validate - + Returns: AuthContext if valid, None otherwise """ try: # Extract token from URI token = uri.split("://")[1].split("@")[0].split(":")[1] - + # Decode and validate token payload = jwt.decode( token, self.secret_key, algorithms=[self.algorithm] ) - + return AuthContext( entity_type=EntityType(payload["type"]), entity_id=payload["entity_id"], app_id=payload.get("app_id"), permissions=set(payload.get("permissions", ["read"])) ) - + except (jwt.InvalidTokenError, IndexError, ValueError): return None - - -@lru_cache() -def get_uri_service(settings: Settings = None) -> URIService: - """Get cached URIService instance.""" - if settings is None: - from ..config import get_settings - settings = get_settings() - return URIService(settings) diff --git a/core/tests/conftest.py b/core/tests/conftest.py deleted file mode 100644 index d28d0f1..0000000 --- a/core/tests/conftest.py +++ /dev/null @@ -1,37 +0,0 @@ -from pathlib import Path -import sys -import pytest -from typing import Generator -import os -from dotenv import load_dotenv - -root_dir = Path(__file__).parent.parent.parent -sdk_path = str(root_dir / "sdks" / "python") -core_path = str(root_dir) - -sys.path.extend([sdk_path, core_path]) - -from core.config import get_settings -from databridge import DataBridge -# Load test environment variables -load_dotenv(".env.test") - - -@pytest.fixture(scope="session") -def settings(): - """Get test settings""" - return get_settings() - - -@pytest.fixture -async def db() -> Generator[DataBridge, None, None]: - """DataBridge client fixture""" - uri = os.getenv("DATABRIDGE_TEST_URI") - if not uri: - raise ValueError("DATABRIDGE_TEST_URI not set") - - client = DataBridge(uri) - try: - yield client - finally: - await client.close() diff --git a/core/tests/integration/test_api.py b/core/tests/integration/test_api.py index 5c17fea..7006166 100644 --- a/core/tests/integration/test_api.py +++ b/core/tests/integration/test_api.py @@ -35,23 +35,16 @@ def setup_test_environment(event_loop): """Setup test environment and create test files""" # Create test data directory if it doesn't exist TEST_DATA_DIR.mkdir(exist_ok=True) - + # Create a test text file text_file = TEST_DATA_DIR / "test.txt" if not text_file.exists(): text_file.write_text("This is a test document for DataBridge testing.") - + # Create a small test PDF if it doesn't exist pdf_file = TEST_DATA_DIR / "test.pdf" if not pdf_file.exists(): - # Create a minimal PDF for testing - try: - from reportlab.pdfgen import canvas - c = canvas.Canvas(str(pdf_file)) - c.drawString(100, 750, "Test PDF Document") - c.save() - except ImportError: - pytest.skip("reportlab not installed, skipping PDF tests") + pytest.skip("PDF file not available, skipping PDF tests") def create_test_token( @@ -64,7 +57,7 @@ def create_test_token( """Create a test JWT token""" if not permissions: permissions = ["read", "write", "admin"] - + payload = { "type": entity_type, "entity_id": entity_id, @@ -98,14 +91,20 @@ async def test_app(event_loop: asyncio.AbstractEventLoop) -> FastAPI: @pytest.fixture -async def client(test_app: FastAPI, event_loop: asyncio.AbstractEventLoop) -> AsyncGenerator[AsyncClient, None]: +async def client( + test_app: FastAPI, + event_loop: asyncio.AbstractEventLoop +) -> AsyncGenerator[AsyncClient, None]: """Create async test client""" async with AsyncClient(app=test_app, base_url="http://test") as client: yield client @pytest.mark.asyncio -async def test_ingest_text_document(client: AsyncClient, content: str = "Test content for document ingestion"): +async def test_ingest_text_document( + client: AsyncClient, + content: str = "Test content for document ingestion" +): """Test ingesting a text document""" headers = create_auth_header() @@ -132,14 +131,14 @@ async def test_ingest_pdf(client: AsyncClient): """Test ingesting a pdf""" headers = create_auth_header() pdf_path = TEST_DATA_DIR / "test.pdf" - + if not pdf_path.exists(): pytest.skip("Test PDF file not available") content_type, _ = mimetypes.guess_type(pdf_path) if not content_type: content_type = "application/octet-stream" - + with open(pdf_path, "rb") as f: response = await client.post( "/ingest/file", @@ -147,13 +146,13 @@ async def test_ingest_pdf(client: AsyncClient): data={"metadata": json.dumps({"test": True, "type": "pdf"})}, headers=headers ) - + assert response.status_code == 200 data = response.json() assert "external_id" in data assert data["content_type"] == "application/pdf" assert "storage_info" in data - + return data["external_id"] @@ -161,7 +160,7 @@ async def test_ingest_pdf(client: AsyncClient): async def test_ingest_invalid_text_request(client: AsyncClient): """Test ingestion with invalid text request missing required content field""" headers = create_auth_header() - + response = await client.post( "/ingest/text", json={ @@ -172,11 +171,11 @@ async def test_ingest_invalid_text_request(client: AsyncClient): assert response.status_code == 422 # Validation error -@pytest.mark.asyncio +@pytest.mark.asyncio async def test_ingest_invalid_file_request(client: AsyncClient): """Test ingestion with invalid file request missing file""" headers = create_auth_header() - + response = await client.post( "/ingest/file", files={}, # Missing file @@ -190,14 +189,14 @@ async def test_ingest_invalid_file_request(client: AsyncClient): async def test_ingest_invalid_metadata(client: AsyncClient): """Test ingestion with invalid metadata JSON""" headers = create_auth_header() - + pdf_path = TEST_DATA_DIR / "test.pdf" if pdf_path.exists(): files = { "file": ("test.pdf", open(pdf_path, "rb"), "application/pdf") } response = await client.post( - "/ingest/file", + "/ingest/file", files=files, data={"metadata": "invalid json"}, headers=headers @@ -209,7 +208,7 @@ async def test_ingest_invalid_metadata(client: AsyncClient): async def test_ingest_oversized_content(client: AsyncClient): """Test ingestion with oversized content""" headers = create_auth_header() - + large_content = "x" * (10 * 1024 * 1024) # 10MB response = await client.post( "/ingest/text", @@ -229,7 +228,7 @@ async def test_auth_missing_header(client: AsyncClient): assert response.status_code == 401 -@pytest.mark.asyncio +@pytest.mark.asyncio async def test_auth_invalid_token(client: AsyncClient): """Test authentication with invalid token""" headers = {"Authorization": "Bearer invalid_token"} @@ -264,12 +263,15 @@ async def test_auth_insufficient_permissions(client: AsyncClient): async def test_query_chunks(client: AsyncClient): """Test querying document chunks""" # First ingest a document to query - doc_id = await test_ingest_text_document(client, content="The quick brown fox jumps over the lazy dog") - + doc_id = await test_ingest_text_document( + client, + content="The quick brown fox jumps over the lazy dog" + ) + headers = create_auth_header() # Sleep to allow time for document to be indexed await asyncio.sleep(1) - + response = await client.post( "/query", json={ @@ -280,7 +282,7 @@ async def test_query_chunks(client: AsyncClient): headers=headers ) logger.info(f"Query response: {response.json()}") - + assert response.status_code == 200 results = list(response.json()) logger.info(f"Query results: {results}") @@ -293,8 +295,16 @@ async def test_query_chunks(client: AsyncClient): async def test_query_documents(client: AsyncClient): """Test querying for full documents""" # First ingest a document to query - doc_id = await test_ingest_text_document(client, content="Headaches can significantly impact daily life and wellbeing. Common triggers include stress, dehydration, and poor sleep habits. While over-the-counter pain relievers may provide temporary relief, it's important to identify and address the root causes. Maintaining good health through proper nutrition, regular exercise, and stress management can help prevent chronic headaches.") - + content = ( + "Headaches can significantly impact daily life and wellbeing. " + "Common triggers include stress, dehydration, and poor sleep habits. " + "While over-the-counter pain relievers may provide temporary relief, " + "it's important to identify and address the root causes. " + "Maintaining good health through proper nutrition, regular exercise, " + "and stress management can help prevent chronic headaches." + ) + doc_id = await test_ingest_text_document(client, content=content) + headers = create_auth_header() response = await client.post( "/query", @@ -305,7 +315,7 @@ async def test_query_documents(client: AsyncClient): }, headers=headers ) - + assert response.status_code == 200 results = list(response.json()) assert len(results) > 0 @@ -320,10 +330,10 @@ async def test_list_documents(client: AsyncClient): # First ingest some documents doc_id1 = await test_ingest_text_document(client) doc_id2 = await test_ingest_text_document(client) - + headers = create_auth_header() response = await client.get("/documents", headers=headers) - + assert response.status_code == 200 docs = response.json() assert len(docs) >= 2 @@ -337,10 +347,10 @@ async def test_get_document(client: AsyncClient): """Test getting a specific document""" # First ingest a document doc_id = await test_ingest_text_document(client) - + headers = create_auth_header() response = await client.get(f"/documents/{doc_id}", headers=headers) - + assert response.status_code == 200 doc = response.json() assert doc["external_id"] == doc_id diff --git a/core/vector_store/base_vector_store.py b/core/vector_store/base_vector_store.py index 9790c8d..f538226 100644 --- a/core/vector_store/base_vector_store.py +++ b/core/vector_store/base_vector_store.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, List, Optional, Tuple from core.models.auth import AuthContext from core.models.documents import DocumentChunk diff --git a/core/vector_store/mongo_vector_store.py b/core/vector_store/mongo_vector_store.py index d7800c7..9dcb389 100644 --- a/core/vector_store/mongo_vector_store.py +++ b/core/vector_store/mongo_vector_store.py @@ -1,12 +1,10 @@ -import json -from typing import List, Dict, Any, Optional +from typing import List, Optional import logging from motor.motor_asyncio import AsyncIOMotorClient from pymongo.errors import PyMongoError from .base_vector_store import BaseVectorStore from core.models.documents import DocumentChunk -from core.models.auth import AuthContext, EntityType logger = logging.getLogger(__name__) @@ -33,10 +31,10 @@ class MongoDBAtlasVectorStore(BaseVectorStore): # Create basic indexes await self.collection.create_index("document_id") await self.collection.create_index("chunk_number") - + # Note: Vector search index must be created via Atlas UI or API # as it requires specific configuration - + logger.info("MongoDB vector store indexes initialized") return True except PyMongoError as e: @@ -55,13 +53,18 @@ class MongoDBAtlasVectorStore(BaseVectorStore): doc = chunk.model_dump() # Ensure we have required fields if not doc.get('embedding'): - logger.error(f"Missing embedding for chunk {chunk.document_id}-{chunk.chunk_number}") + logger.error( + f"Missing embedding for chunk " + f"{chunk.document_id}-{chunk.chunk_number}" + ) continue documents.append(doc) if documents: # Use ordered=False to continue even if some inserts fail - result = await self.collection.insert_many(documents, ordered=False) + result = await self.collection.insert_many( + documents, ordered=False + ) return len(result.inserted_ids) > 0, result return False, None @@ -77,7 +80,10 @@ class MongoDBAtlasVectorStore(BaseVectorStore): ) -> List[DocumentChunk]: """Find similar chunks using MongoDB Atlas Vector Search.""" try: - logger.debug(f"Searching in database {self.db.name} collection {self.collection.name}") + logger.debug( + f"Searching in database {self.db.name} " + f"collection {self.collection.name}" + ) logger.debug(f"Query vector looks like: {query_embedding}") logger.debug(f"Doc IDs: {doc_ids}") logger.debug(f"K is: {k}") @@ -90,7 +96,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore): "index": self.index_name, "path": "embedding", "queryVector": query_embedding, - "numCandidates": k*40, # Get more candidates for better results + "numCandidates": k*40, # Get more candidates "limit": k, "filter": {"document_id": {"$in": doc_ids}} if doc_ids else {} } @@ -110,7 +116,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore): # Execute search cursor = self.collection.aggregate(pipeline) chunks = [] - + async for result in cursor: chunk = DocumentChunk( document_id=result["document_id"], @@ -128,58 +134,3 @@ class MongoDBAtlasVectorStore(BaseVectorStore): logger.error(f"MongoDB error: {e._message}") logger.error(f"Error querying similar chunks: {str(e)}") raise e - - def _build_access_filter(self, auth: AuthContext) -> Dict[str, Any]: - """Build MongoDB filter for access control.""" - base_filter = { - "$or": [ - {"owner.id": auth.entity_id}, - {"access_control.readers": auth.entity_id}, - {"access_control.writers": auth.entity_id}, - {"access_control.admins": auth.entity_id} - ] - } - - if auth.entity_type == EntityType.DEVELOPER and auth.app_id: - # Add app-specific access for developers - base_filter["$or"].append( - {"access_control.app_access": auth.app_id} - ) - - return base_filter - - def _build_metadata_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]: - """Build MongoDB filter for metadata fields.""" - if not filters: - return {} - return filters - - metadata_filter = {} - - for key, value in filters.items(): - metadata_key = f"metadata.{key}" - - if isinstance(value, (str, int, float, bool)): - metadata_filter[metadata_key] = value - elif isinstance(value, list): - metadata_filter[metadata_key] = {"$in": value} - elif isinstance(value, dict): - valid_ops = { - "gt": "$gt", - "gte": "$gte", - "lt": "$lt", - "lte": "$lte", - "ne": "$ne" - } - mongo_ops = {} - for op, val in value.items(): - if op not in valid_ops: - logger.warning(f"Skipping invalid operator: {op}") - continue - mongo_ops[valid_ops[op]] = val - if mongo_ops: - metadata_filter[metadata_key] = mongo_ops - else: - logger.warning(f"Skipping unsupported filter value type for key {key}") - - return metadata_filter