mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
pass all tests apart from querying
This commit is contained in:
parent
983a4ee854
commit
000887a4dc
28
core/api.py
28
core/api.py
@ -1,9 +1,8 @@
|
||||
import json
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Union, Dict, Set
|
||||
from typing import List, Optional, Union, Dict, Set
|
||||
from fastapi import (
|
||||
FastAPI,
|
||||
File,
|
||||
Form,
|
||||
HTTPException,
|
||||
Depends,
|
||||
@ -13,7 +12,7 @@ from fastapi import (
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import jwt
|
||||
|
||||
import logging
|
||||
from core.models.request import IngestTextRequest, QueryRequest
|
||||
from core.models.documents import (
|
||||
Document,
|
||||
@ -34,6 +33,7 @@ from core.services.uri_service import get_uri_service
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(title="DataBridge API")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
@ -119,7 +119,7 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/documents/text", response_model=Document)
|
||||
@app.post("/ingest/text", response_model=Document)
|
||||
async def ingest_text(
|
||||
request: IngestTextRequest,
|
||||
auth: AuthContext = Depends(verify_token)
|
||||
@ -127,20 +127,25 @@ async def ingest_text(
|
||||
"""Ingest a text document."""
|
||||
try:
|
||||
return await document_service.ingest_text(request, auth)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/documents/file", response_model=Document)
|
||||
@app.post("/ingest/file", response_model=Document)
|
||||
async def ingest_file(
|
||||
file: UploadFile = File(...),
|
||||
file: UploadFile,
|
||||
metadata: str = Form("{}"), # JSON string of metadata
|
||||
auth: AuthContext = Depends(verify_token)
|
||||
) -> Document:
|
||||
"""Ingest a file document."""
|
||||
try:
|
||||
metadata_dict = json.loads(metadata)
|
||||
return await document_service.ingest_file(file, metadata_dict, auth)
|
||||
doc = await document_service.ingest_file(file, metadata_dict, auth)
|
||||
return doc # Should just send a success response, not sure why we're sending a document #TODO: discuss with bhau
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(400, "Invalid metadata JSON")
|
||||
except Exception as e:
|
||||
@ -180,9 +185,12 @@ async def get_document(
|
||||
"""Get document by ID."""
|
||||
try:
|
||||
doc = await document_service.db.get_document(document_id, auth)
|
||||
logger.info(f"Found document: {doc}")
|
||||
if not doc:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
return doc
|
||||
except HTTPException as e:
|
||||
raise e # Return the HTTPException as is
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@ -193,9 +201,9 @@ auth_router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
@auth_router.post("/developer-token")
|
||||
async def create_developer_token(
|
||||
dev_id: str,
|
||||
app_id: str = None,
|
||||
app_id: Optional[str] = None,
|
||||
expiry_days: int = 30,
|
||||
permissions: Set[str] = None,
|
||||
permissions: Optional[Set[str]] = None,
|
||||
auth: AuthContext = Depends(verify_token)
|
||||
) -> Dict[str, str]:
|
||||
"""Create a developer access URI."""
|
||||
@ -221,7 +229,7 @@ async def create_developer_token(
|
||||
async def create_user_token(
|
||||
user_id: str,
|
||||
expiry_days: int = 30,
|
||||
permissions: Set[str] = None,
|
||||
permissions: Optional[Set[str]] = None,
|
||||
auth: AuthContext = Depends(verify_token)
|
||||
) -> Dict[str, str]:
|
||||
"""Create a user access URI."""
|
||||
|
@ -73,13 +73,16 @@ class MongoDatabase(BaseDatabase):
|
||||
access_filter
|
||||
]
|
||||
}
|
||||
logger.debug(f"Querying document with query: {query}")
|
||||
|
||||
doc_dict = await self.collection.find_one(query)
|
||||
logger.debug(f"Found document: {doc_dict}")
|
||||
return Document(**doc_dict) if doc_dict else None
|
||||
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error retrieving document metadata: {str(e)}")
|
||||
return None
|
||||
raise e
|
||||
# return None
|
||||
|
||||
async def get_documents(
|
||||
self,
|
||||
|
@ -10,9 +10,9 @@ class IngestTextRequest(BaseModel):
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
"""Query request model - remains unchanged"""
|
||||
query: str
|
||||
"""Query request model with validation"""
|
||||
query: str = Field(..., min_length=1)
|
||||
return_type: QueryReturnType = QueryReturnType.CHUNKS
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
k: int = 4
|
||||
min_score: float = 0.0
|
||||
k: int = Field(default=4, gt=0)
|
||||
min_score: float = Field(default=0.0)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from typing import List, Union
|
||||
from fastapi import UploadFile
|
||||
from typing import List
|
||||
import io
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from unstructured.partition.auto import partition
|
||||
import logging
|
||||
@ -32,18 +32,12 @@ class UnstructuredAPIParser(BaseParser):
|
||||
logger.error(f"Failed to split text: {str(e)}")
|
||||
raise
|
||||
|
||||
async def parse_file(self, file: Union[UploadFile, bytes], content_type: str) -> List[str]:
|
||||
async def parse_file(self, file: bytes, content_type: str) -> List[str]:
|
||||
"""Parse file content using unstructured"""
|
||||
try:
|
||||
# Handle different file input types
|
||||
if isinstance(file, UploadFile):
|
||||
file_content = await file.read()
|
||||
else:
|
||||
file_content = file
|
||||
|
||||
# Parse with unstructured
|
||||
elements = partition(
|
||||
file=file_content,
|
||||
file=io.BytesIO(file),
|
||||
content_type=content_type,
|
||||
api_key=self.api_key
|
||||
)
|
||||
@ -58,4 +52,4 @@ class UnstructuredAPIParser(BaseParser):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse file: {str(e)}")
|
||||
raise
|
||||
raise e
|
||||
|
@ -45,6 +45,8 @@ class DocumentService:
|
||||
auth: AuthContext
|
||||
) -> Document:
|
||||
"""Ingest a text document."""
|
||||
if "write" not in auth.permissions:
|
||||
raise PermissionError("User does not have write permission")
|
||||
try:
|
||||
# 1. Create document record
|
||||
doc = Document(
|
||||
@ -90,7 +92,7 @@ class DocumentService:
|
||||
except Exception as e:
|
||||
logger.error(f"Text document ingestion failed: {str(e)}")
|
||||
# TODO: Clean up any stored data on failure
|
||||
raise
|
||||
raise e
|
||||
|
||||
async def ingest_file(
|
||||
self,
|
||||
@ -99,6 +101,8 @@ class DocumentService:
|
||||
auth: AuthContext
|
||||
) -> Document:
|
||||
"""Ingest a file document."""
|
||||
if "write" not in auth.permissions:
|
||||
raise PermissionError("User does not have write permission")
|
||||
try:
|
||||
# 1. Create document record
|
||||
doc = Document(
|
||||
@ -129,7 +133,7 @@ class DocumentService:
|
||||
"key": storage_info[1]
|
||||
}
|
||||
logger.info(
|
||||
f"Stored file in bucket {storage_info[0]} with key {storage_info[1]}"
|
||||
f"Stored file in bucket `{storage_info[0]}` with key `{storage_info[1]}`"
|
||||
)
|
||||
|
||||
# 3. Parse content into chunks
|
||||
@ -152,7 +156,7 @@ class DocumentService:
|
||||
logger.info(f"Created {len(chunk_objects)} chunk objects")
|
||||
|
||||
# 6. Store everything
|
||||
await self._store_chunks_and_doc(chunk_objects, doc)
|
||||
doc.chunk_ids = await self._store_chunks_and_doc(chunk_objects, doc)
|
||||
logger.info(f"Successfully stored file document {doc.external_id}")
|
||||
|
||||
return doc
|
||||
@ -228,10 +232,11 @@ class DocumentService:
|
||||
self,
|
||||
chunk_objects: List[DocumentChunk],
|
||||
doc: Document
|
||||
) -> None:
|
||||
) -> List[str]:
|
||||
"""Helper to store chunks and document"""
|
||||
# Store chunks in vector store
|
||||
if not await self.vector_store.store_embeddings(chunk_objects):
|
||||
success, result = await self.vector_store.store_embeddings(chunk_objects)
|
||||
if not success:
|
||||
raise Exception("Failed to store chunk embeddings")
|
||||
logger.debug("Stored chunk embeddings in vector store")
|
||||
|
||||
@ -240,6 +245,8 @@ class DocumentService:
|
||||
raise Exception("Failed to store document metadata")
|
||||
logger.debug("Stored document metadata in database")
|
||||
|
||||
return [str(id) for id in result.inserted_ids]
|
||||
|
||||
async def _create_chunk_results(
|
||||
self,
|
||||
auth: AuthContext,
|
||||
|
@ -97,7 +97,7 @@ class S3Storage(BaseStorage):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading base64 content to S3: {e}")
|
||||
raise
|
||||
raise e
|
||||
|
||||
async def download_file(self, bucket: str, key: str) -> bytes:
|
||||
"""Download file from S3."""
|
||||
|
@ -1,15 +1,18 @@
|
||||
import base64
|
||||
import asyncio
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
import jwt
|
||||
from datetime import datetime, timedelta, UTC
|
||||
from typing import AsyncGenerator, Dict, Any
|
||||
from typing import AsyncGenerator, Dict
|
||||
from httpx import AsyncClient
|
||||
from fastapi import FastAPI
|
||||
from core.api import app, get_settings
|
||||
from core.models.auth import EntityType
|
||||
from core.database.mongo_database import MongoDatabase
|
||||
from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
|
||||
import mimetypes
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Test configuration
|
||||
TEST_DATA_DIR = Path(__file__).parent / "test_data"
|
||||
@ -17,8 +20,18 @@ JWT_SECRET = "your-secret-key-for-signing-tokens"
|
||||
TEST_USER_ID = "test_user"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for the test session"""
|
||||
policy = asyncio.get_event_loop_policy()
|
||||
loop = policy.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def setup_test_environment():
|
||||
def setup_test_environment(event_loop):
|
||||
"""Setup test environment and create test files"""
|
||||
# Create test data directory if it doesn't exist
|
||||
TEST_DATA_DIR.mkdir(exist_ok=True)
|
||||
@ -49,7 +62,7 @@ def create_test_token(
|
||||
expired: bool = False
|
||||
) -> str:
|
||||
"""Create a test JWT token"""
|
||||
if permissions is None:
|
||||
if not permissions:
|
||||
permissions = ["read", "write", "admin"]
|
||||
|
||||
payload = {
|
||||
@ -76,7 +89,7 @@ def create_auth_header(
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_app() -> FastAPI:
|
||||
async def test_app(event_loop: asyncio.AbstractEventLoop) -> FastAPI:
|
||||
"""Create test FastAPI application"""
|
||||
# Configure test settings
|
||||
settings = get_settings()
|
||||
@ -85,21 +98,21 @@ async def test_app() -> FastAPI:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(test_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
|
||||
async def client(test_app: FastAPI, event_loop: asyncio.AbstractEventLoop) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create async test client"""
|
||||
async with AsyncClient(app=test_app, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_text_document(client: AsyncClient):
|
||||
async def test_ingest_text_document(client: AsyncClient, content: str = "Test content for document ingestion"):
|
||||
"""Test ingesting a text document"""
|
||||
headers = create_auth_header()
|
||||
|
||||
response = await client.post(
|
||||
"/documents/text",
|
||||
"/ingest/text",
|
||||
json={
|
||||
"content": "Test content for document ingestion",
|
||||
"content": content,
|
||||
"metadata": {"test": True, "type": "text"}
|
||||
},
|
||||
headers=headers
|
||||
@ -115,28 +128,25 @@ async def test_ingest_text_document(client: AsyncClient):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_file_document(client: AsyncClient):
|
||||
"""Test ingesting a file (PDF) document"""
|
||||
async def test_ingest_pdf(client: AsyncClient):
|
||||
"""Test ingesting a pdf"""
|
||||
headers = create_auth_header()
|
||||
pdf_path = TEST_DATA_DIR / "test.pdf"
|
||||
|
||||
if not pdf_path.exists():
|
||||
pytest.skip("Test PDF file not available")
|
||||
|
||||
content_type, _ = mimetypes.guess_type(pdf_path)
|
||||
if not content_type:
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
# Create form data with file and metadata
|
||||
files = {
|
||||
"file": ("test.pdf", open(pdf_path, "rb"), "application/pdf")
|
||||
}
|
||||
data = {
|
||||
"metadata": json.dumps({"test": True, "type": "pdf"})
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
"/documents/file",
|
||||
files=files,
|
||||
data=data,
|
||||
headers=headers
|
||||
)
|
||||
with open(pdf_path, "rb") as f:
|
||||
response = await client.post(
|
||||
"/ingest/file",
|
||||
files={"file": (pdf_path.name, f, content_type)},
|
||||
data={"metadata": json.dumps({"test": True, "type": "pdf"})},
|
||||
headers=headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@ -148,47 +158,61 @@ async def test_ingest_file_document(client: AsyncClient):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_error_handling(client: AsyncClient):
|
||||
"""Test ingestion error cases"""
|
||||
async def test_ingest_invalid_text_request(client: AsyncClient):
|
||||
"""Test ingestion with invalid text request missing required content field"""
|
||||
headers = create_auth_header()
|
||||
|
||||
# Test invalid text request
|
||||
response = await client.post(
|
||||
"/documents/text",
|
||||
"/ingest/text",
|
||||
json={
|
||||
"wrong_field": "Test content" # Missing required content field
|
||||
},
|
||||
headers=headers
|
||||
)
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_invalid_file_request(client: AsyncClient):
|
||||
"""Test ingestion with invalid file request missing file"""
|
||||
headers = create_auth_header()
|
||||
|
||||
# Test invalid file request
|
||||
response = await client.post(
|
||||
"/documents/file",
|
||||
"/ingest/file",
|
||||
files={}, # Missing file
|
||||
data={"metadata": "{}"},
|
||||
headers=headers
|
||||
)
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_invalid_metadata(client: AsyncClient):
|
||||
"""Test ingestion with invalid metadata JSON"""
|
||||
headers = create_auth_header()
|
||||
|
||||
# Test invalid metadata JSON
|
||||
pdf_path = TEST_DATA_DIR / "test.pdf"
|
||||
if pdf_path.exists():
|
||||
files = {
|
||||
"file": ("test.pdf", open(pdf_path, "rb"), "application/pdf")
|
||||
}
|
||||
response = await client.post(
|
||||
"/documents/file",
|
||||
"/ingest/file",
|
||||
files=files,
|
||||
data={"metadata": "invalid json"},
|
||||
headers=headers
|
||||
)
|
||||
assert response.status_code == 400 # Bad request
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_oversized_content(client: AsyncClient):
|
||||
"""Test ingestion with oversized content"""
|
||||
headers = create_auth_header()
|
||||
|
||||
# Test oversized content
|
||||
large_content = "x" * (10 * 1024 * 1024) # 10MB
|
||||
response = await client.post(
|
||||
"/documents/text",
|
||||
"/ingest/text",
|
||||
json={
|
||||
"content": large_content,
|
||||
"metadata": {}
|
||||
@ -199,26 +223,34 @@ async def test_ingest_error_handling(client: AsyncClient):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_errors(client: AsyncClient):
|
||||
"""Test authentication error cases"""
|
||||
# Test missing auth header
|
||||
response = await client.post("/documents/text")
|
||||
async def test_auth_missing_header(client: AsyncClient):
|
||||
"""Test authentication with missing auth header"""
|
||||
response = await client.post("/ingest/text")
|
||||
assert response.status_code == 401
|
||||
|
||||
# Test invalid token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_invalid_token(client: AsyncClient):
|
||||
"""Test authentication with invalid token"""
|
||||
headers = {"Authorization": "Bearer invalid_token"}
|
||||
response = await client.post("/documents/file", headers=headers)
|
||||
response = await client.post("/ingest/file", headers=headers)
|
||||
assert response.status_code == 401
|
||||
|
||||
# Test expired token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_expired_token(client: AsyncClient):
|
||||
"""Test authentication with expired token"""
|
||||
headers = create_auth_header(expired=True)
|
||||
response = await client.post("/documents/text", headers=headers)
|
||||
response = await client.post("/ingest/text", headers=headers)
|
||||
assert response.status_code == 401
|
||||
|
||||
# Test insufficient permissions
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_insufficient_permissions(client: AsyncClient):
|
||||
"""Test authentication with insufficient permissions"""
|
||||
headers = create_auth_header(permissions=["read"])
|
||||
response = await client.post(
|
||||
"/documents/text",
|
||||
"/ingest/text",
|
||||
json={
|
||||
"content": "Test content",
|
||||
"metadata": {}
|
||||
@ -227,7 +259,6 @@ async def test_auth_errors(client: AsyncClient):
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks(client: AsyncClient):
|
||||
"""Test querying document chunks"""
|
||||
@ -244,6 +275,7 @@ async def test_query_chunks(client: AsyncClient):
|
||||
},
|
||||
headers=headers
|
||||
)
|
||||
logger.info(f"Query response: {response.json()}")
|
||||
|
||||
assert response.status_code == 200
|
||||
results = response.json()
|
||||
@ -312,15 +344,17 @@ async def test_get_document(client: AsyncClient):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling(client: AsyncClient):
|
||||
"""Test error handling scenarios"""
|
||||
async def test_invalid_document_id(client: AsyncClient):
|
||||
"""Test error handling for invalid document ID"""
|
||||
headers = create_auth_header()
|
||||
|
||||
# Test invalid document ID
|
||||
response = await client.get("/documents/invalid_id", headers=headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
# Test invalid query parameters
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_query_params(client: AsyncClient):
|
||||
"""Test error handling for invalid query parameters"""
|
||||
headers = create_auth_header()
|
||||
response = await client.post(
|
||||
"/query",
|
||||
json={
|
||||
@ -329,16 +363,20 @@ async def test_error_handling(client: AsyncClient):
|
||||
},
|
||||
headers=headers
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
# Test oversized content
|
||||
large_content = "x" * (10 * 1024 * 1024) # 10MB
|
||||
response = await client.post(
|
||||
"/documents",
|
||||
json={
|
||||
"content": large_content,
|
||||
"content_type": "text/plain"
|
||||
},
|
||||
headers=headers
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_query_oversized_content(client: AsyncClient):
|
||||
# """Test error handling for oversized content"""
|
||||
# headers = create_auth_header()
|
||||
# large_content = "x" * (10 * 1024 * 1024) # 10MB
|
||||
# response = await client.post(
|
||||
# "/documents",
|
||||
# json={
|
||||
# "content": large_content,
|
||||
# "content_type": "text/plain"
|
||||
# },
|
||||
# headers=headers
|
||||
# )
|
||||
# assert response.status_code == 400
|
||||
|
@ -1,12 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from core.models.auth import AuthContext
|
||||
from core.models.documents import DocumentChunk
|
||||
|
||||
|
||||
class BaseVectorStore(ABC):
|
||||
@abstractmethod
|
||||
def store_embeddings(self, chunks: List[DocumentChunk]) -> bool:
|
||||
def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, Optional[Any]]:
|
||||
"""Store document chunks and their embeddings"""
|
||||
pass
|
||||
|
||||
|
@ -61,8 +61,8 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
|
||||
if documents:
|
||||
# Use ordered=False to continue even if some inserts fail
|
||||
result = await self.collection.insert_many(documents, ordered=False)
|
||||
return len(result.inserted_ids) > 0
|
||||
return False
|
||||
return len(result.inserted_ids) > 0, result
|
||||
return False, None
|
||||
|
||||
except PyMongoError as e:
|
||||
logger.error(f"Error storing embeddings: {str(e)}")
|
||||
@ -79,13 +79,13 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
|
||||
try:
|
||||
# Build access filter based on auth context
|
||||
access_filter = self._build_access_filter(auth)
|
||||
|
||||
# Add metadata filters if provided
|
||||
filter_query = access_filter
|
||||
if filters:
|
||||
metadata_filter = self._build_metadata_filter(filters)
|
||||
if metadata_filter:
|
||||
filter_query = {"$and": [access_filter, metadata_filter]}
|
||||
metadata_filter = self._build_metadata_filter(filters)
|
||||
filter_query = {"$and": [access_filter, metadata_filter]} if metadata_filter else access_filter
|
||||
|
||||
logger.debug(f"Query vector looks like: {query_embedding}")
|
||||
logger.debug(f"Filter query looks like: {filter_query}")
|
||||
logger.debug(f"K is: {k}")
|
||||
logger.debug(f"Index is: {self.index_name}")
|
||||
|
||||
# Vector search pipeline
|
||||
pipeline = [
|
||||
@ -94,7 +94,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
|
||||
"index": self.index_name,
|
||||
"path": "embedding",
|
||||
"queryVector": query_embedding,
|
||||
"numCandidates": k * 10, # Get more candidates for better results
|
||||
"numCandidates": 150, # Get more candidates for better results
|
||||
"limit": k,
|
||||
"filter": filter_query
|
||||
}
|
||||
@ -113,6 +113,9 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
|
||||
|
||||
# Execute search
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
chunk_list = [result async for result in cursor]
|
||||
logger.info(f"Found {len(chunk_list)} similar chunks")
|
||||
logger.info(f"Cursor: {chunk_list}")
|
||||
chunks = []
|
||||
|
||||
async for result in cursor:
|
||||
@ -155,6 +158,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
|
||||
"""Build MongoDB filter for metadata fields."""
|
||||
if not filters:
|
||||
return {}
|
||||
return filters
|
||||
|
||||
metadata_filter = {}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user