pass all tests apart from querying

This commit is contained in:
Arnav Agrawal 2024-11-28 19:09:40 -05:00
parent 983a4ee854
commit 000887a4dc
9 changed files with 169 additions and 115 deletions

View File

@ -1,9 +1,8 @@
import json
from datetime import datetime, UTC
from typing import List, Union, Dict, Set
from typing import List, Optional, Union, Dict, Set
from fastapi import (
FastAPI,
File,
Form,
HTTPException,
Depends,
@ -13,7 +12,7 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
import jwt
import logging
from core.models.request import IngestTextRequest, QueryRequest
from core.models.documents import (
Document,
@ -34,6 +33,7 @@ from core.services.uri_service import get_uri_service
# Initialize FastAPI app
app = FastAPI(title="DataBridge API")
logger = logging.getLogger(__name__)
# Add CORS middleware
app.add_middleware(
@ -119,7 +119,7 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext:
raise HTTPException(status_code=401, detail=str(e))
@app.post("/documents/text", response_model=Document)
@app.post("/ingest/text", response_model=Document)
async def ingest_text(
request: IngestTextRequest,
auth: AuthContext = Depends(verify_token)
@ -127,20 +127,25 @@ async def ingest_text(
"""Ingest a text document."""
try:
return await document_service.ingest_text(request, auth)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.post("/documents/file", response_model=Document)
@app.post("/ingest/file", response_model=Document)
async def ingest_file(
file: UploadFile = File(...),
file: UploadFile,
metadata: str = Form("{}"), # JSON string of metadata
auth: AuthContext = Depends(verify_token)
) -> Document:
"""Ingest a file document."""
try:
metadata_dict = json.loads(metadata)
return await document_service.ingest_file(file, metadata_dict, auth)
doc = await document_service.ingest_file(file, metadata_dict, auth)
return doc # Should just send a success response, not sure why we're sending a document #TODO: discuss with bhau
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except json.JSONDecodeError:
raise HTTPException(400, "Invalid metadata JSON")
except Exception as e:
@ -180,9 +185,12 @@ async def get_document(
"""Get document by ID."""
try:
doc = await document_service.db.get_document(document_id, auth)
logger.info(f"Found document: {doc}")
if not doc:
raise HTTPException(status_code=404, detail="Document not found")
return doc
except HTTPException as e:
raise e # Return the HTTPException as is
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@ -193,9 +201,9 @@ auth_router = APIRouter(prefix="/auth", tags=["auth"])
@auth_router.post("/developer-token")
async def create_developer_token(
dev_id: str,
app_id: str = None,
app_id: Optional[str] = None,
expiry_days: int = 30,
permissions: Set[str] = None,
permissions: Optional[Set[str]] = None,
auth: AuthContext = Depends(verify_token)
) -> Dict[str, str]:
"""Create a developer access URI."""
@ -221,7 +229,7 @@ async def create_developer_token(
async def create_user_token(
user_id: str,
expiry_days: int = 30,
permissions: Set[str] = None,
permissions: Optional[Set[str]] = None,
auth: AuthContext = Depends(verify_token)
) -> Dict[str, str]:
"""Create a user access URI."""

View File

@ -73,13 +73,16 @@ class MongoDatabase(BaseDatabase):
access_filter
]
}
logger.debug(f"Querying document with query: {query}")
doc_dict = await self.collection.find_one(query)
logger.debug(f"Found document: {doc_dict}")
return Document(**doc_dict) if doc_dict else None
except PyMongoError as e:
logger.error(f"Error retrieving document metadata: {str(e)}")
return None
raise e
# return None
async def get_documents(
self,

View File

@ -10,9 +10,9 @@ class IngestTextRequest(BaseModel):
class QueryRequest(BaseModel):
    """Query request model with validation.

    Attributes:
        query: The query text. Required; must contain at least one
            character (``min_length=1``).
        return_type: Which kind of result to return. Defaults to
            ``QueryReturnType.CHUNKS``.
        filters: Optional mapping of filter keys to values; ``None``
            when no filters are supplied.
        k: Number of results requested. Must be greater than zero;
            defaults to 4.
        min_score: Score threshold, defaulting to 0.0 (presumably a
            minimum similarity score -- confirm against the query
            service).
    """

    # Only the validated declarations are kept: the plain `query: str`,
    # `k: int = 4` and `min_score: float = 0.0` lines were shadowed by the
    # Field(...) versions below and had no effect.
    query: str = Field(..., min_length=1)
    return_type: QueryReturnType = QueryReturnType.CHUNKS
    filters: Optional[Dict[str, Any]] = None
    k: int = Field(default=4, gt=0)
    min_score: float = Field(default=0.0)

View File

@ -1,5 +1,5 @@
from typing import List, Union
from fastapi import UploadFile
from typing import List
import io
from langchain.text_splitter import RecursiveCharacterTextSplitter
from unstructured.partition.auto import partition
import logging
@ -32,18 +32,12 @@ class UnstructuredAPIParser(BaseParser):
logger.error(f"Failed to split text: {str(e)}")
raise
async def parse_file(self, file: Union[UploadFile, bytes], content_type: str) -> List[str]:
async def parse_file(self, file: bytes, content_type: str) -> List[str]:
"""Parse file content using unstructured"""
try:
# Handle different file input types
if isinstance(file, UploadFile):
file_content = await file.read()
else:
file_content = file
# Parse with unstructured
elements = partition(
file=file_content,
file=io.BytesIO(file),
content_type=content_type,
api_key=self.api_key
)
@ -58,4 +52,4 @@ class UnstructuredAPIParser(BaseParser):
except Exception as e:
logger.error(f"Failed to parse file: {str(e)}")
raise
raise e

View File

@ -45,6 +45,8 @@ class DocumentService:
auth: AuthContext
) -> Document:
"""Ingest a text document."""
if "write" not in auth.permissions:
raise PermissionError("User does not have write permission")
try:
# 1. Create document record
doc = Document(
@ -90,7 +92,7 @@ class DocumentService:
except Exception as e:
logger.error(f"Text document ingestion failed: {str(e)}")
# TODO: Clean up any stored data on failure
raise
raise e
async def ingest_file(
self,
@ -99,6 +101,8 @@ class DocumentService:
auth: AuthContext
) -> Document:
"""Ingest a file document."""
if "write" not in auth.permissions:
raise PermissionError("User does not have write permission")
try:
# 1. Create document record
doc = Document(
@ -129,7 +133,7 @@ class DocumentService:
"key": storage_info[1]
}
logger.info(
f"Stored file in bucket {storage_info[0]} with key {storage_info[1]}"
f"Stored file in bucket `{storage_info[0]}` with key `{storage_info[1]}`"
)
# 3. Parse content into chunks
@ -152,7 +156,7 @@ class DocumentService:
logger.info(f"Created {len(chunk_objects)} chunk objects")
# 6. Store everything
await self._store_chunks_and_doc(chunk_objects, doc)
doc.chunk_ids = await self._store_chunks_and_doc(chunk_objects, doc)
logger.info(f"Successfully stored file document {doc.external_id}")
return doc
@ -228,10 +232,11 @@ class DocumentService:
self,
chunk_objects: List[DocumentChunk],
doc: Document
) -> None:
) -> List[str]:
"""Helper to store chunks and document"""
# Store chunks in vector store
if not await self.vector_store.store_embeddings(chunk_objects):
success, result = await self.vector_store.store_embeddings(chunk_objects)
if not success:
raise Exception("Failed to store chunk embeddings")
logger.debug("Stored chunk embeddings in vector store")
@ -240,6 +245,8 @@ class DocumentService:
raise Exception("Failed to store document metadata")
logger.debug("Stored document metadata in database")
return [str(id) for id in result.inserted_ids]
async def _create_chunk_results(
self,
auth: AuthContext,

View File

@ -97,7 +97,7 @@ class S3Storage(BaseStorage):
except Exception as e:
logger.error(f"Error uploading base64 content to S3: {e}")
raise
raise e
async def download_file(self, bucket: str, key: str) -> bytes:
"""Download file from S3."""

View File

@ -1,15 +1,18 @@
import base64
import asyncio
import json
import pytest
from pathlib import Path
import jwt
from datetime import datetime, timedelta, UTC
from typing import AsyncGenerator, Dict, Any
from typing import AsyncGenerator, Dict
from httpx import AsyncClient
from fastapi import FastAPI
from core.api import app, get_settings
from core.models.auth import EntityType
from core.database.mongo_database import MongoDatabase
from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
import mimetypes
import logging
logger = logging.getLogger(__name__)
# Test configuration
TEST_DATA_DIR = Path(__file__).parent / "test_data"
@ -17,8 +20,18 @@ JWT_SECRET = "your-secret-key-for-signing-tokens"
TEST_USER_ID = "test_user"
@pytest.fixture(scope="session")
def event_loop():
    """Provide one event loop shared by the whole test session.

    A fresh loop is created, installed as the current event loop,
    yielded to the tests, and closed once the session is over.
    """
    # asyncio.new_event_loop() delegates to the active event loop policy,
    # so this builds the same loop as policy.new_event_loop().
    session_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(session_loop)
    yield session_loop
    session_loop.close()
@pytest.fixture(scope="session", autouse=True)
def setup_test_environment():
def setup_test_environment(event_loop):
"""Setup test environment and create test files"""
# Create test data directory if it doesn't exist
TEST_DATA_DIR.mkdir(exist_ok=True)
@ -49,7 +62,7 @@ def create_test_token(
expired: bool = False
) -> str:
"""Create a test JWT token"""
if permissions is None:
if not permissions:
permissions = ["read", "write", "admin"]
payload = {
@ -76,7 +89,7 @@ def create_auth_header(
@pytest.fixture
async def test_app() -> FastAPI:
async def test_app(event_loop: asyncio.AbstractEventLoop) -> FastAPI:
"""Create test FastAPI application"""
# Configure test settings
settings = get_settings()
@ -85,21 +98,21 @@ async def test_app() -> FastAPI:
@pytest.fixture
async def client(test_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
async def client(test_app: FastAPI, event_loop: asyncio.AbstractEventLoop) -> AsyncGenerator[AsyncClient, None]:
"""Create async test client"""
async with AsyncClient(app=test_app, base_url="http://test") as client:
yield client
@pytest.mark.asyncio
async def test_ingest_text_document(client: AsyncClient):
async def test_ingest_text_document(client: AsyncClient, content: str = "Test content for document ingestion"):
"""Test ingesting a text document"""
headers = create_auth_header()
response = await client.post(
"/documents/text",
"/ingest/text",
json={
"content": "Test content for document ingestion",
"content": content,
"metadata": {"test": True, "type": "text"}
},
headers=headers
@ -115,28 +128,25 @@ async def test_ingest_text_document(client: AsyncClient):
@pytest.mark.asyncio
async def test_ingest_file_document(client: AsyncClient):
"""Test ingesting a file (PDF) document"""
async def test_ingest_pdf(client: AsyncClient):
"""Test ingesting a pdf"""
headers = create_auth_header()
pdf_path = TEST_DATA_DIR / "test.pdf"
if not pdf_path.exists():
pytest.skip("Test PDF file not available")
content_type, _ = mimetypes.guess_type(pdf_path)
if not content_type:
content_type = "application/octet-stream"
# Create form data with file and metadata
files = {
"file": ("test.pdf", open(pdf_path, "rb"), "application/pdf")
}
data = {
"metadata": json.dumps({"test": True, "type": "pdf"})
}
response = await client.post(
"/documents/file",
files=files,
data=data,
headers=headers
)
with open(pdf_path, "rb") as f:
response = await client.post(
"/ingest/file",
files={"file": (pdf_path.name, f, content_type)},
data={"metadata": json.dumps({"test": True, "type": "pdf"})},
headers=headers
)
assert response.status_code == 200
data = response.json()
@ -148,47 +158,61 @@ async def test_ingest_file_document(client: AsyncClient):
@pytest.mark.asyncio
async def test_ingest_error_handling(client: AsyncClient):
"""Test ingestion error cases"""
async def test_ingest_invalid_text_request(client: AsyncClient):
"""Test ingestion with invalid text request missing required content field"""
headers = create_auth_header()
# Test invalid text request
response = await client.post(
"/documents/text",
"/ingest/text",
json={
"wrong_field": "Test content" # Missing required content field
},
headers=headers
)
assert response.status_code == 422 # Validation error
@pytest.mark.asyncio
async def test_ingest_invalid_file_request(client: AsyncClient):
"""Test ingestion with invalid file request missing file"""
headers = create_auth_header()
# Test invalid file request
response = await client.post(
"/documents/file",
"/ingest/file",
files={}, # Missing file
data={"metadata": "{}"},
headers=headers
)
assert response.status_code == 422 # Validation error
@pytest.mark.asyncio
async def test_ingest_invalid_metadata(client: AsyncClient):
"""Test ingestion with invalid metadata JSON"""
headers = create_auth_header()
# Test invalid metadata JSON
pdf_path = TEST_DATA_DIR / "test.pdf"
if pdf_path.exists():
files = {
"file": ("test.pdf", open(pdf_path, "rb"), "application/pdf")
}
response = await client.post(
"/documents/file",
"/ingest/file",
files=files,
data={"metadata": "invalid json"},
headers=headers
)
assert response.status_code == 400 # Bad request
@pytest.mark.asyncio
async def test_ingest_oversized_content(client: AsyncClient):
"""Test ingestion with oversized content"""
headers = create_auth_header()
# Test oversized content
large_content = "x" * (10 * 1024 * 1024) # 10MB
response = await client.post(
"/documents/text",
"/ingest/text",
json={
"content": large_content,
"metadata": {}
@ -199,26 +223,34 @@ async def test_ingest_error_handling(client: AsyncClient):
@pytest.mark.asyncio
async def test_auth_errors(client: AsyncClient):
"""Test authentication error cases"""
# Test missing auth header
response = await client.post("/documents/text")
async def test_auth_missing_header(client: AsyncClient):
"""Test authentication with missing auth header"""
response = await client.post("/ingest/text")
assert response.status_code == 401
# Test invalid token
@pytest.mark.asyncio
async def test_auth_invalid_token(client: AsyncClient):
"""Test authentication with invalid token"""
headers = {"Authorization": "Bearer invalid_token"}
response = await client.post("/documents/file", headers=headers)
response = await client.post("/ingest/file", headers=headers)
assert response.status_code == 401
# Test expired token
@pytest.mark.asyncio
async def test_auth_expired_token(client: AsyncClient):
"""Test authentication with expired token"""
headers = create_auth_header(expired=True)
response = await client.post("/documents/text", headers=headers)
response = await client.post("/ingest/text", headers=headers)
assert response.status_code == 401
# Test insufficient permissions
@pytest.mark.asyncio
async def test_auth_insufficient_permissions(client: AsyncClient):
"""Test authentication with insufficient permissions"""
headers = create_auth_header(permissions=["read"])
response = await client.post(
"/documents/text",
"/ingest/text",
json={
"content": "Test content",
"metadata": {}
@ -227,7 +259,6 @@ async def test_auth_errors(client: AsyncClient):
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_query_chunks(client: AsyncClient):
"""Test querying document chunks"""
@ -244,6 +275,7 @@ async def test_query_chunks(client: AsyncClient):
},
headers=headers
)
logger.info(f"Query response: {response.json()}")
assert response.status_code == 200
results = response.json()
@ -312,15 +344,17 @@ async def test_get_document(client: AsyncClient):
@pytest.mark.asyncio
async def test_error_handling(client: AsyncClient):
"""Test error handling scenarios"""
async def test_invalid_document_id(client: AsyncClient):
"""Test error handling for invalid document ID"""
headers = create_auth_header()
# Test invalid document ID
response = await client.get("/documents/invalid_id", headers=headers)
assert response.status_code == 404
# Test invalid query parameters
@pytest.mark.asyncio
async def test_invalid_query_params(client: AsyncClient):
"""Test error handling for invalid query parameters"""
headers = create_auth_header()
response = await client.post(
"/query",
json={
@ -329,16 +363,20 @@ async def test_error_handling(client: AsyncClient):
},
headers=headers
)
assert response.status_code == 400
# Test oversized content
large_content = "x" * (10 * 1024 * 1024) # 10MB
response = await client.post(
"/documents",
json={
"content": large_content,
"content_type": "text/plain"
},
headers=headers
)
assert response.status_code == 400
assert response.status_code == 422
# @pytest.mark.asyncio
# async def test_query_oversized_content(client: AsyncClient):
# """Test error handling for oversized content"""
# headers = create_auth_header()
# large_content = "x" * (10 * 1024 * 1024) # 10MB
# response = await client.post(
# "/documents",
# json={
# "content": large_content,
# "content_type": "text/plain"
# },
# headers=headers
# )
# assert response.status_code == 400

View File

@ -1,12 +1,12 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from core.models.auth import AuthContext
from core.models.documents import DocumentChunk
class BaseVectorStore(ABC):
@abstractmethod
def store_embeddings(self, chunks: List[DocumentChunk]) -> bool:
def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, Optional[Any]]:
"""Store document chunks and their embeddings"""
pass

View File

@ -61,8 +61,8 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
if documents:
# Use ordered=False to continue even if some inserts fail
result = await self.collection.insert_many(documents, ordered=False)
return len(result.inserted_ids) > 0
return False
return len(result.inserted_ids) > 0, result
return False, None
except PyMongoError as e:
logger.error(f"Error storing embeddings: {str(e)}")
@ -79,13 +79,13 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
try:
# Build access filter based on auth context
access_filter = self._build_access_filter(auth)
# Add metadata filters if provided
filter_query = access_filter
if filters:
metadata_filter = self._build_metadata_filter(filters)
if metadata_filter:
filter_query = {"$and": [access_filter, metadata_filter]}
metadata_filter = self._build_metadata_filter(filters)
filter_query = {"$and": [access_filter, metadata_filter]} if metadata_filter else access_filter
logger.debug(f"Query vector looks like: {query_embedding}")
logger.debug(f"Filter query looks like: {filter_query}")
logger.debug(f"K is: {k}")
logger.debug(f"Index is: {self.index_name}")
# Vector search pipeline
pipeline = [
@ -94,7 +94,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
"index": self.index_name,
"path": "embedding",
"queryVector": query_embedding,
"numCandidates": k * 10, # Get more candidates for better results
"numCandidates": 150, # Get more candidates for better results
"limit": k,
"filter": filter_query
}
@ -113,6 +113,9 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
# Execute search
cursor = self.collection.aggregate(pipeline)
chunk_list = [result async for result in cursor]
logger.info(f"Found {len(chunk_list)} similar chunks")
logger.info(f"Cursor: {chunk_list}")
chunks = []
async for result in cursor:
@ -155,6 +158,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
"""Build MongoDB filter for metadata fields."""
if not filters:
return {}
return filters
metadata_filter = {}