separate text and doc ingestion pathways

Adityavardhan Agrawal 2024-11-24 14:29:25 -05:00
parent 1f6551d464
commit 983a4ee854
10 changed files with 591 additions and 283 deletions

View File

@ -1,26 +1,35 @@
import json
from datetime import datetime, UTC
from typing import List, Union, Dict, Set
from fastapi import FastAPI, HTTPException, Depends, Header, APIRouter
from fastapi import (
FastAPI,
File,
Form,
HTTPException,
Depends,
Header,
APIRouter,
UploadFile
)
from fastapi.middleware.cors import CORSMiddleware
import jwt
from core.models.request import IngestRequest, QueryRequest
from .models.documents import (
from core.models.request import IngestTextRequest, QueryRequest
from core.models.documents import (
Document,
DocumentResult,
ChunkResult,
EntityType
)
from .models.auth import AuthContext
from .services.document_service import DocumentService
from .config import get_settings
from .database.mongo_database import MongoDatabase
from .vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from .storage.s3_storage import S3Storage
from .parser.unstructured_parser import UnstructuredAPIParser
from .embedding_model.openai_embedding_model import OpenAIEmbeddingModel
from .services.uri_service import get_uri_service
from core.models.auth import AuthContext
from core.services.document_service import DocumentService
from core.config import get_settings
from core.database.mongo_database import MongoDatabase
from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from core.storage.s3_storage import S3Storage
from core.parser.unstructured_parser import UnstructuredAPIParser
from core.embedding_model.openai_embedding_model import OpenAIEmbeddingModel
from core.services.uri_service import get_uri_service
# Initialize FastAPI app
@ -75,8 +84,14 @@ document_service = DocumentService(
)
async def verify_token(authorization: str = Header(...)) -> AuthContext:
async def verify_token(authorization: str = Header(None)) -> AuthContext:
"""Verify JWT Bearer token."""
if not authorization:
raise HTTPException(
status_code=401,
detail="Missing authorization header",
headers={"WWW-Authenticate": "Bearer"}
)
try:
if not authorization.startswith("Bearer "):
raise HTTPException(
@ -104,15 +119,30 @@ async def verify_token(authorization: str = Header(...)) -> AuthContext:
raise HTTPException(status_code=401, detail=str(e))
# API endpoints
@app.post("/documents", response_model=Document)
async def ingest_document(
request: IngestRequest,
@app.post("/documents/text", response_model=Document)
async def ingest_text(
request: IngestTextRequest,
auth: AuthContext = Depends(verify_token)
) -> Document:
"""Ingest a new document."""
"""Ingest a text document."""
try:
return await document_service.ingest_document(request, auth)
return await document_service.ingest_text(request, auth)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.post("/documents/file", response_model=Document)
async def ingest_file(
file: UploadFile = File(...),
metadata: str = Form("{}"), # JSON string of metadata
auth: AuthContext = Depends(verify_token)
) -> Document:
"""Ingest a file document."""
try:
metadata_dict = json.loads(metadata)
return await document_service.ingest_file(file, metadata_dict, auth)
except json.JSONDecodeError:
raise HTTPException(400, "Invalid metadata JSON")
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
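A minimal client sketch of the two new ingestion pathways, assuming the API is reachable at http://localhost:8000 and TOKEN holds a valid JWT (both placeholders); routes and field names follow the endpoints above:

import json
import httpx

BASE_URL = "http://localhost:8000"  # placeholder deployment URL
TOKEN = "<jwt-for-your-entity>"  # placeholder bearer token
HEADERS = {"Authorization": f"Bearer {TOKEN}"}

# Text ingestion: plain JSON body
resp = httpx.post(
    f"{BASE_URL}/documents/text",
    json={"content": "Example text content", "metadata": {"source": "example"}},
    headers=HEADERS,
)
resp.raise_for_status()

# File ingestion: multipart upload, metadata passed as a JSON string form field
with open("report.pdf", "rb") as f:
    resp = httpx.post(
        f"{BASE_URL}/documents/file",
        files={"file": ("report.pdf", f, "application/pdf")},
        data={"metadata": json.dumps({"source": "example"})},
        headers=HEADERS,
    )
resp.raise_for_status()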

View File

@ -1,6 +1,6 @@
from typing import List, Optional, Dict, Any
import logging
from datetime import datetime
from datetime import UTC, datetime
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import ReturnDocument
@ -20,7 +20,7 @@ class MongoDatabase(BaseDatabase):
self,
uri: str,
db_name: str,
collection_name: str = "documents"
collection_name: str,
):
"""Initialize MongoDB connection for document storage."""
self.client = AsyncIOMotorClient(uri)
@ -32,7 +32,7 @@ class MongoDatabase(BaseDatabase):
try:
# Create indexes for common queries
await self.collection.create_index("external_id", unique=True)
await self.collection.create_index("access_control.owner.id")
await self.collection.create_index("owner.id")
await self.collection.create_index("access_control.readers")
await self.collection.create_index("access_control.writers")
await self.collection.create_index("access_control.admins")
@ -50,8 +50,8 @@ class MongoDatabase(BaseDatabase):
doc_dict = document.model_dump()
# Ensure system metadata
doc_dict["system_metadata"]["created_at"] = datetime.utcnow()
doc_dict["system_metadata"]["updated_at"] = datetime.utcnow()
doc_dict["system_metadata"]["created_at"] = datetime.now(UTC)
doc_dict["system_metadata"]["updated_at"] = datetime.now(UTC)
result = await self.collection.insert_one(doc_dict)
return bool(result.inserted_id)
@ -190,9 +190,8 @@ class MongoDatabase(BaseDatabase):
access_control = doc.get("access_control", {})
# Check owner access
owner = access_control.get("owner", {})
if (owner.get("type") == auth.entity_type and
owner.get("id") == auth.entity_id):
owner = doc.get("owner", {})
if (owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id):
return True
# Check permission-specific access
@ -216,7 +215,7 @@ class MongoDatabase(BaseDatabase):
"""Build MongoDB filter for access control."""
base_filter = {
"$or": [
{"access_control.owner.id": auth.entity_id},
{"owner.id": auth.entity_id},
{"access_control.readers": auth.entity_id},
{"access_control.writers": auth.entity_id},
{"access_control.admins": auth.entity_id}

View File

@ -1,6 +1,6 @@
from typing import Dict, Any, List, Optional, Literal
from enum import Enum
from datetime import datetime
from datetime import UTC, datetime
from pydantic import BaseModel, Field, field_validator
import uuid
@ -18,23 +18,23 @@ class QueryReturnType(str, Enum):
class Document(BaseModel):
"""Represents a document stored in MongoDB documents collection"""
external_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
owner: Dict[str, str]
content_type: str
filename: Optional[str] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
storage_info: Dict[str, str] = Field(default_factory=dict) # s3_bucket, s3_key
storage_info: Dict[str, str] = Field(default_factory=dict)
system_metadata: Dict[str, Any] = Field(
default_factory=lambda: {
"created_at": datetime.utcnow(),
"updated_at": datetime.utcnow(),
"created_at": datetime.now(UTC),
"updated_at": datetime.now(UTC),
"version": 1
}
)
access_control: Dict[str, Any] = Field(
access_control: Dict[str, List[str]] = Field(
default_factory=lambda: {
"owner": None,
"readers": set(),
"writers": set(),
"admins": set()
"readers": [],
"writers": [],
"admins": []
}
)
chunk_ids: List[str] = Field(default_factory=list)
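With this change, owner lives at the top level of the document and access_control holds JSON-friendly lists instead of sets; a small sketch of constructing a record the way the service layer now does (entity values are illustrative):

from core.models.documents import Document

doc = Document(
    content_type="text/plain",
    metadata={"source": "example"},
    owner={"type": "developer", "id": "dev_123"},  # promoted out of access_control
    access_control={
        "readers": ["dev_123"],
        "writers": ["dev_123"],
        "admins": ["dev_123"],
    },
)
print(doc.external_id)  # auto-generated UUID string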

View File

@ -1,18 +1,16 @@
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from .documents import QueryReturnType
class IngestRequest(BaseModel):
class IngestTextRequest(BaseModel):
"""Request model for text ingestion"""
content: str
# TODO: We should infer this, not request it
content_type: str
metadata: Dict[str, Any] = Field(default_factory=dict)
filename: Optional[str] = None
class QueryRequest(BaseModel):
"""Query request model - remains unchanged"""
query: str
return_type: QueryReturnType = QueryReturnType.CHUNKS
filters: Optional[Dict[str, Any]] = None

View File

@ -1,9 +1,16 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, List
from typing import List, Union
from fastapi import UploadFile
class BaseParser(ABC):
"""Base class for document parsing"""
@abstractmethod
def parse(self, content: str, metadata: Dict[str, Any]) -> List[str]:
"""Parse content into chunks"""
async def split_text(self, text: str) -> List[str]:
"""Split plain text into chunks"""
pass
@abstractmethod
async def parse_file(self, file: Union[UploadFile, bytes], content_type: str) -> List[str]:
"""Parse file content into text chunks"""
pass

View File

@ -1,13 +1,11 @@
from typing import Dict, Any, List
from .base_parser import BaseParser
from unstructured.partition.auto import partition
import io
from typing import List, Union
from fastapi import UploadFile
from langchain.text_splitter import RecursiveCharacterTextSplitter
import os
import tempfile
import base64
import nltk
from unstructured.partition.auto import partition
import logging
from .base_parser import BaseParser
logger = logging.getLogger(__name__)
@ -17,20 +15,8 @@ class UnstructuredAPIParser(BaseParser):
api_key: str,
chunk_size: int = 1000,
chunk_overlap: int = 200,
api_url: str = "https://api.unstructuredapp.io"
):
self.api_key = api_key
self.api_url = api_url
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
# Download required NLTK data
try:
nltk.data.find('tokenizers/punkt')
except LookupError:
logger.info("Downloading NLTK punkt tokenizer...")
nltk.download('punkt', quiet=True)
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
@ -38,63 +24,38 @@ class UnstructuredAPIParser(BaseParser):
separators=["\n\n", "\n", ". ", " ", ""]
)
def parse(self, content: str, metadata: Dict[str, Any]) -> List[str]:
"""Parse content using Unstructured API and split into chunks."""
async def split_text(self, text: str) -> List[str]:
"""Split plain text into chunks"""
try:
# Create temporary file for content
with tempfile.NamedTemporaryFile(delete=False, suffix=self._get_file_extension(metadata)) as temp_file:
if metadata.get("is_base64", False):
try:
decoded_content = base64.b64decode(content)
except Exception as e:
logger.error(f"Failed to decode base64 content: {str(e)}")
# If base64 decode fails, try treating as plain text
decoded_content = content.encode('utf-8')
else:
decoded_content = content.encode('utf-8')
temp_file.write(decoded_content)
temp_file_path = temp_file.name
return self.text_splitter.split_text(text)
except Exception as e:
logger.error(f"Failed to split text: {str(e)}")
raise
try:
# Use Unstructured API for parsing
elements = partition(
filename=temp_file_path,
api_key=self.api_key,
api_url=self.api_url,
partition_via_api=True
)
async def parse_file(self, file: Union[UploadFile, bytes], content_type: str) -> List[str]:
"""Parse file content using unstructured"""
try:
# Handle different file input types
if isinstance(file, UploadFile):
file_content = await file.read()
else:
file_content = file
# Combine elements and split into chunks
full_text = "\n\n".join(str(element) for element in elements)
chunks = self.text_splitter.split_text(full_text)
# Parse with unstructured; partition() expects a file-like object, so wrap the raw bytes
elements = partition(
file=io.BytesIO(file_content),
content_type=content_type,
api_key=self.api_key
)
if not chunks:
# If no chunks were created, use the full text as a single chunk
logger.warning("No chunks created, using full text as single chunk")
return [full_text] if full_text.strip() else []
# Extract text from elements
chunks = []
for element in elements:
if hasattr(element, 'text') and element.text:
chunks.append(element.text.strip())
return chunks
finally:
# Clean up temporary file
try:
os.unlink(temp_file_path)
except Exception as e:
logger.error(f"Failed to delete temporary file: {str(e)}")
return chunks
except Exception as e:
logger.error(f"Error parsing document: {str(e)}")
raise Exception(f"Error parsing document: {str(e)}")
def _get_file_extension(self, metadata: Dict[str, Any]) -> str:
"""Get appropriate file extension based on content type."""
content_type_mapping = {
'application/pdf': '.pdf',
'application/msword': '.doc',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
'image/jpeg': '.jpg',
'image/png': '.png',
'text/plain': '.txt',
'text/html': '.html'
}
return content_type_mapping.get(metadata.get('content_type'), '.txt')
logger.error(f"Failed to parse file: {str(e)}")
raise
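A usage sketch for the reworked parser interface, assuming an Unstructured API key and a local sample.pdf are available (both placeholders): split_text() handles plain text via the RecursiveCharacterTextSplitter, while parse_file() accepts an UploadFile or raw bytes and runs them through partition().

import asyncio
from core.parser.unstructured_parser import UnstructuredAPIParser

UNSTRUCTURED_API_KEY = "<unstructured-api-key>"  # placeholder

async def main() -> None:
    parser = UnstructuredAPIParser(api_key=UNSTRUCTURED_API_KEY, chunk_size=1000, chunk_overlap=200)

    # Plain text path
    text_chunks = await parser.split_text("First paragraph.\n\nSecond paragraph.")

    # File path: raw bytes plus a content type
    with open("sample.pdf", "rb") as f:
        file_chunks = await parser.parse_file(f.read(), "application/pdf")

    print(len(text_chunks), len(file_chunks))

asyncio.run(main())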

View File

@ -1,15 +1,21 @@
import base64
from collections import defaultdict
from typing import Dict, List, Union, Optional
from typing import Any, Dict, List, Union, Optional
import logging
from core.api import IngestRequest, QueryRequest
from fastapi import UploadFile
from core.database.base_database import BaseDatabase
from core.embedding_model.base_embedding_model import BaseEmbeddingModel
from core.models.request import IngestTextRequest, QueryRequest
from core.parser.base_parser import BaseParser
from core.storage.base_storage import BaseStorage
from core.vector_store.base_vector_store import BaseVectorStore
from ..models.documents import (
Document, DocumentChunk, ChunkResult, DocumentContent, DocumentResult,
Document,
DocumentChunk,
ChunkResult,
DocumentContent,
DocumentResult,
QueryReturnType
)
from ..models.auth import AuthContext
@ -33,73 +39,128 @@ class DocumentService:
self.parser = parser
self.embedding_model = embedding_model
async def ingest_document(
async def ingest_text(
self,
request: IngestRequest,
request: IngestTextRequest,
auth: AuthContext
) -> Document:
"""Ingest a new document with chunks."""
"""Ingest a text document."""
try:
# 1. Create document record
doc = Document(
content_type=request.content_type,
filename=request.filename,
content_type="text/plain",
metadata=request.metadata,
owner={
"type": auth.entity_type,
"id": auth.entity_id
},
access_control={
"owner": {
"type": auth.entity_type,
"id": auth.entity_id
},
"readers": {auth.entity_id},
"readers": [auth.entity_id],
"writers": {auth.entity_id},
"admins": {auth.entity_id}
}
)
logger.info(f"Created text document record with ID {doc.external_id}")
# 2. Store file in storage if it's not text
if request.content_type != "text/plain":
storage_info = await self.storage.upload_from_base64(
request.content,
doc.external_id,
request.content_type
)
doc.storage_info = {
"bucket": storage_info[0],
"key": storage_info[1]
}
# 2. Parse content into chunks
chunks = await self.parser.split_text(request.content)
if not chunks:
raise ValueError("No content chunks extracted from text")
logger.info(f"Split text into {len(chunks)} chunks")
# 3. Parse content into chunks
chunks = await self.parser.parse(request.content)
# 4. Generate embeddings for chunks
# 3. Generate embeddings for chunks
embeddings = await self.embedding_model.embed_for_ingestion(chunks)
logger.info(f"Generated {len(embeddings)} embeddings")
# 5. Create and store chunks with embeddings
chunk_objects = []
for i, (content, embedding) in enumerate(zip(chunks, embeddings)):
chunk = DocumentChunk(
document_id=doc.external_id,
content=content,
embedding=embedding,
chunk_number=i,
metadata=doc.metadata # Inherit document metadata
)
chunk_objects.append(chunk)
# 4. Create and store chunk objects
chunk_objects = self._create_chunk_objects(
doc.external_id,
chunks,
embeddings,
doc.metadata
)
logger.info(f"Created {len(chunk_objects)} chunk objects")
# 6. Store chunks in vector store
success = await self.vector_store.store_embeddings(chunk_objects)
if not success:
raise Exception("Failed to store chunk embeddings")
# 7. Store document metadata
if not await self.db.store_document(doc):
raise Exception("Failed to store document metadata")
# 5. Store everything
await self._store_chunks_and_doc(chunk_objects, doc)
logger.info(f"Successfully stored text document {doc.external_id}")
return doc
except Exception as e:
logger.error(f"Text document ingestion failed: {str(e)}")
# TODO: Clean up any stored data on failure
raise Exception(f"Document ingestion failed: {str(e)}")
raise
async def ingest_file(
self,
file: UploadFile,
metadata: Dict[str, Any],
auth: AuthContext
) -> Document:
"""Ingest a file document."""
try:
# 1. Create document record
doc = Document(
content_type=file.content_type,
filename=file.filename,
metadata=metadata,
owner={
"type": auth.entity_type,
"id": auth.entity_id
},
access_control={
"readers": [auth.entity_id],
"writers": {auth.entity_id},
"admins": {auth.entity_id}
}
)
logger.info(f"Created file document record with ID {doc.external_id}")
# 2. Read and store file
file_content = await file.read()
storage_info = await self.storage.upload_from_base64(
base64.b64encode(file_content).decode(),
doc.external_id,
file.content_type
)
doc.storage_info = {
"bucket": storage_info[0],
"key": storage_info[1]
}
logger.info(
f"Stored file in bucket {storage_info[0]} with key {storage_info[1]}"
)
# 3. Parse content into chunks
chunks = await self.parser.parse_file(file_content, file.content_type)
if not chunks:
raise ValueError("No content chunks extracted from file")
logger.info(f"Parsed file into {len(chunks)} chunks")
# 4. Generate embeddings for chunks
embeddings = await self.embedding_model.embed_for_ingestion(chunks)
logger.info(f"Generated {len(embeddings)} embeddings")
# 5. Create and store chunk objects
chunk_objects = self._create_chunk_objects(
doc.external_id,
chunks,
embeddings,
doc.metadata
)
logger.info(f"Created {len(chunk_objects)} chunk objects")
# 6. Store everything
await self._store_chunks_and_doc(chunk_objects, doc)
logger.info(f"Successfully stored file document {doc.external_id}")
return doc
except Exception as e:
logger.error(f"File document ingestion failed: {str(e)}")
# TODO: Clean up any stored data on failure
raise
async def query(
self,
@ -109,12 +170,17 @@ class DocumentService:
"""Query documents with specified return type."""
try:
# 1. Get embedding for query
query_embedding = await self.embedding_model.embed_for_query(request.query)
query_embedding = await self.embedding_model.embed_for_query(
request.query
)
logger.info("Generated query embedding")
# 2. Find authorized documents
doc_ids = await self.db.find_documents(auth, request.filters)
if not doc_ids:
logger.info("No authorized documents found")
return []
logger.info(f"Found {len(doc_ids)} authorized documents")
# 3. Search chunks with vector similarity
chunks = await self.vector_store.query_similar(
@ -123,25 +189,71 @@ class DocumentService:
auth=auth,
filters={"document_id": {"$in": doc_ids}}
)
logger.info(f"Found {len(chunks)} similar chunks")
# 4. Return results in requested format
if request.return_type == QueryReturnType.CHUNKS:
return await self._create_chunk_results(auth, chunks)
results = await self._create_chunk_results(auth, chunks)
logger.info(f"Returning {len(results)} chunk results")
return results
else:
return await self._create_document_results(auth, chunks)
results = await self._create_document_results(auth, chunks)
logger.info(f"Returning {len(results)} document results")
return results
except Exception as e:
logger.error(f"Query failed: {str(e)}")
raise e
raise
async def _create_chunk_results(self, auth: AuthContext, chunks: List[DocumentChunk]) -> List[ChunkResult]:
def _create_chunk_objects(
self,
doc_id: str,
chunks: List[str],
embeddings: List[List[float]],
metadata: Dict[str, Any]
) -> List[DocumentChunk]:
"""Helper to create chunk objects"""
return [
DocumentChunk(
document_id=doc_id,
content=content,
embedding=embedding,
chunk_number=i,
metadata=metadata
)
for i, (content, embedding) in enumerate(zip(chunks, embeddings))
]
async def _store_chunks_and_doc(
self,
chunk_objects: List[DocumentChunk],
doc: Document
) -> None:
"""Helper to store chunks and document"""
# Store chunks in vector store
if not await self.vector_store.store_embeddings(chunk_objects):
raise Exception("Failed to store chunk embeddings")
logger.debug("Stored chunk embeddings in vector store")
# Store document metadata
if not await self.db.store_document(doc):
raise Exception("Failed to store document metadata")
logger.debug("Stored document metadata in database")
async def _create_chunk_results(
self,
auth: AuthContext,
chunks: List[DocumentChunk]
) -> List[ChunkResult]:
"""Create ChunkResult objects with document metadata."""
results = []
for chunk in chunks:
# Get document metadata
doc = await self.db.get_document(chunk.document_id, auth)
if not doc:
logger.warning(f"Document {chunk.document_id} not found")
continue
logger.debug(f"Retrieved metadata for document {chunk.document_id}")
# Generate download URL if needed
download_url = None
@ -150,6 +262,9 @@ class DocumentService:
doc.storage_info["bucket"],
doc.storage_info["key"]
)
logger.debug(
f"Generated download URL for document {chunk.document_id}"
)
results.append(ChunkResult(
content=chunk.content,
@ -162,22 +277,31 @@ class DocumentService:
download_url=download_url
))
logger.info(f"Created {len(results)} chunk results")
return results
async def _create_document_results(self, auth: AuthContext, chunks: List[DocumentChunk]) -> List[DocumentResult]:
async def _create_document_results(
self,
auth: AuthContext,
chunks: List[DocumentChunk]
) -> List[DocumentResult]:
"""Group chunks by document and create DocumentResult objects."""
# Group chunks by document and get highest scoring chunk per doc
doc_chunks: Dict[str, DocumentChunk] = {}
for chunk in chunks:
if chunk.document_id not in doc_chunks or chunk.score > doc_chunks[chunk.document_id].score:
if (chunk.document_id not in doc_chunks or
chunk.score > doc_chunks[chunk.document_id].score):
doc_chunks[chunk.document_id] = chunk
logger.info(f"Grouped chunks into {len(doc_chunks)} documents")
results = []
for doc_id, chunk in doc_chunks.items():
# Get document metadata
doc = await self.db.get_document(doc_id, auth)
if not doc:
logger.warning(f"Document {doc_id} not found")
continue
logger.debug(f"Retrieved metadata for document {doc_id}")
# Create DocumentContent based on content type
if doc.content_type == "text/plain":
@ -186,6 +310,7 @@ class DocumentService:
value=chunk.content,
filename=None
)
logger.debug(f"Created text content for document {doc_id}")
else:
# Generate download URL for file types
download_url = await self.storage.get_download_url(
@ -197,6 +322,7 @@ class DocumentService:
value=download_url,
filename=doc.filename
)
logger.debug(f"Created URL content for document {doc_id}")
results.append(DocumentResult(
score=chunk.score,
@ -205,4 +331,5 @@ class DocumentService:
content=content
))
logger.info(f"Created {len(results)} document results")
return results

View File

@ -1,158 +1,344 @@
import base64
import json
import pytest
from fastapi.testclient import TestClient
from core.api import app
from pathlib import Path
import jwt
from datetime import datetime, UTC, timedelta
from datetime import datetime, timedelta, UTC
from typing import AsyncGenerator, Dict, Any
from httpx import AsyncClient
from fastapi import FastAPI
from core.api import app, get_settings
from core.models.auth import EntityType
from core.database.mongo_database import MongoDatabase
from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
# Test configuration
TEST_DATA_DIR = Path(__file__).parent / "test_data"
JWT_SECRET = "your-secret-key-for-signing-tokens"
TEST_USER_ID = "test_user"
def create_auth_header(entity_type: str = "developer", permissions: list = None):
"""Create auth header with test token"""
token = jwt.encode(
{
"type": entity_type,
"entity_id": "test_user",
"permissions": permissions or ["read", "write"],
"exp": datetime.now(UTC) + timedelta(days=1)
},
# TODO: Use settings.JWT_SECRET_KEY
"your-secret-key-for-signing-tokens"
)
@pytest.fixture(scope="session", autouse=True)
def setup_test_environment():
"""Setup test environment and create test files"""
# Create test data directory if it doesn't exist
TEST_DATA_DIR.mkdir(exist_ok=True)
# Create a test text file
text_file = TEST_DATA_DIR / "test.txt"
if not text_file.exists():
text_file.write_text("This is a test document for DataBridge testing.")
# Create a small test PDF if it doesn't exist
pdf_file = TEST_DATA_DIR / "test.pdf"
if not pdf_file.exists():
# Create a minimal PDF for testing
try:
from reportlab.pdfgen import canvas
c = canvas.Canvas(str(pdf_file))
c.drawString(100, 750, "Test PDF Document")
c.save()
except ImportError:
pytest.skip("reportlab not installed, skipping PDF tests")
def create_test_token(
entity_type: str = "developer",
entity_id: str = TEST_USER_ID,
permissions: list = None,
app_id: str = None,
expired: bool = False
) -> str:
"""Create a test JWT token"""
if permissions is None:
permissions = ["read", "write", "admin"]
payload = {
"type": entity_type,
"entity_id": entity_id,
"permissions": permissions,
"exp": datetime.now(UTC) + timedelta(days=-1 if expired else 1)
}
if app_id:
payload["app_id"] = app_id
return jwt.encode(payload, JWT_SECRET, algorithm="HS256")
def create_auth_header(
entity_type: str = "developer",
permissions: list = None,
expired: bool = False
) -> Dict[str, str]:
"""Create authorization header with test token"""
token = create_test_token(entity_type, permissions=permissions, expired=expired)
return {"Authorization": f"Bearer {token}"}
@pytest.fixture
def client():
return TestClient(app)
async def test_app() -> FastAPI:
"""Create test FastAPI application"""
# Configure test settings
settings = get_settings()
settings.JWT_SECRET_KEY = JWT_SECRET
return app
def test_ingest_document(client):
"""Test document ingestion endpoint"""
@pytest.fixture
async def client(test_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
"""Create async test client"""
async with AsyncClient(app=test_app, base_url="http://test") as client:
yield client
@pytest.mark.asyncio
async def test_ingest_text_document(client: AsyncClient):
"""Test ingesting a text document"""
headers = create_auth_header()
# Test text document
response = client.post(
"/documents",
response = await client.post(
"/documents/text",
json={
"content": "Test content",
"content_type": "text/plain",
"metadata": {"test": True}
"content": "Test content for document ingestion",
"metadata": {"test": True, "type": "text"}
},
headers=headers
)
print(response.json())
assert response.status_code == 200
data = response.json()
assert data["external_id"]
assert "external_id" in data
assert data["content_type"] == "text/plain"
assert data["metadata"]["test"] is True
# Test binary document
with open("test.pdf", "rb") as f:
content = base64.b64encode(f.read()).decode()
return data["external_id"]
response = client.post(
"/documents",
@pytest.mark.asyncio
async def test_ingest_file_document(client: AsyncClient):
"""Test ingesting a file (PDF) document"""
headers = create_auth_header()
pdf_path = TEST_DATA_DIR / "test.pdf"
if not pdf_path.exists():
pytest.skip("Test PDF file not available")
# Create form data with file and metadata
files = {
"file": ("test.pdf", open(pdf_path, "rb"), "application/pdf")
}
data = {
"metadata": json.dumps({"test": True, "type": "pdf"})
}
response = await client.post(
"/documents/file",
files=files,
data=data,
headers=headers
)
assert response.status_code == 200
data = response.json()
assert "external_id" in data
assert data["content_type"] == "application/pdf"
assert "storage_info" in data
return data["external_id"]
@pytest.mark.asyncio
async def test_ingest_error_handling(client: AsyncClient):
"""Test ingestion error cases"""
headers = create_auth_header()
# Test invalid text request
response = await client.post(
"/documents/text",
json={
"content": content,
"content_type": "application/pdf",
"filename": "test.pdf"
"wrong_field": "Test content" # Missing required content field
},
headers=headers
)
assert response.status_code == 200
assert response.status_code == 422 # Validation error
# Test invalid file request
response = await client.post(
"/documents/file",
files={}, # Missing file
data={"metadata": "{}"},
headers=headers
)
assert response.status_code == 422 # Validation error
# Test invalid metadata JSON
pdf_path = TEST_DATA_DIR / "test.pdf"
if pdf_path.exists():
files = {
"file": ("test.pdf", open(pdf_path, "rb"), "application/pdf")
}
response = await client.post(
"/documents/file",
files=files,
data={"metadata": "invalid json"},
headers=headers
)
assert response.status_code == 400 # Bad request
# Test oversized content
large_content = "x" * (10 * 1024 * 1024) # 10MB
response = await client.post(
"/documents/text",
json={
"content": large_content,
"metadata": {}
},
headers=headers
)
assert response.status_code == 400 # Bad request
def test_query_endpoints(client):
"""Test query endpoints"""
@pytest.mark.asyncio
async def test_auth_errors(client: AsyncClient):
"""Test authentication error cases"""
# Test missing auth header
response = await client.post("/documents/text")
assert response.status_code == 401
# Test invalid token
headers = {"Authorization": "Bearer invalid_token"}
response = await client.post("/documents/file", headers=headers)
assert response.status_code == 401
# Test expired token
headers = create_auth_header(expired=True)
response = await client.post("/documents/text", headers=headers)
assert response.status_code == 401
# Test insufficient permissions
headers = create_auth_header(permissions=["read"])
response = await client.post(
"/documents/text",
json={
"content": "Test content",
"metadata": {}
},
headers=headers
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_query_chunks(client: AsyncClient):
"""Test querying document chunks"""
# First ingest a document to query
doc_id = await test_ingest_text_document(client)
headers = create_auth_header()
# Test chunk query
response = client.post(
response = await client.post(
"/query",
json={
"query": "test",
"query": "test document",
"return_type": "chunks",
"k": 2
},
headers=headers
)
assert response.status_code == 200
results = response.json()
assert len(results) <= 2
assert len(results) > 0
assert all(isinstance(r["score"], (int, float)) for r in results)
assert all(r["document_id"] == doc_id for r in results)
# Test document query
response = client.post(
@pytest.mark.asyncio
async def test_query_documents(client: AsyncClient):
"""Test querying for full documents"""
# First ingest a document to query
doc_id = await test_ingest_text_document(client)
headers = create_auth_header()
response = await client.post(
"/query",
json={
"query": "test",
"return_type": "documents"
"query": "test document",
"return_type": "documents",
"filters": {"test": True}
},
headers=headers
)
assert response.status_code == 200
results = response.json()
assert len(results) > 0
assert results[0]["document_id"] == doc_id
assert "score" in results[0]
assert "metadata" in results[0]
def test_document_management(client):
"""Test document management endpoints"""
@pytest.mark.asyncio
async def test_list_documents(client: AsyncClient):
"""Test listing documents"""
# First ingest some documents
doc_id1 = await test_ingest_text_document(client)
doc_id2 = await test_ingest_text_document(client)
headers = create_auth_header()
# List documents
response = client.get("/documents", headers=headers)
response = await client.get("/documents", headers=headers)
assert response.status_code == 200
docs = response.json()
if docs:
# Get specific document
doc_id = docs[0]["external_id"]
response = client.get(f"/documents/{doc_id}", headers=headers)
assert response.status_code == 200
assert len(docs) >= 2
doc_ids = [doc["external_id"] for doc in docs]
assert doc_id1 in doc_ids
assert doc_id2 in doc_ids
def test_auth_errors(client):
"""Test authentication error cases"""
# No auth header
response = client.get("/documents")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_get_document(client: AsyncClient):
"""Test getting a specific document"""
# First ingest a document
doc_id = await test_ingest_text_document(client)
headers = create_auth_header()
response = await client.get(f"/documents/{doc_id}", headers=headers)
assert response.status_code == 200
doc = response.json()
assert doc["external_id"] == doc_id
assert "metadata" in doc
assert "content_type" in doc
# Invalid token
headers = {"Authorization": "Bearer invalid"}
response = client.get("/documents", headers=headers)
assert response.status_code == 401
# Expired token
token = jwt.encode(
{
"type": "developer",
"entity_id": "test",
"exp": datetime.now(UTC) - timedelta(days=1)
@pytest.mark.asyncio
async def test_error_handling(client: AsyncClient):
"""Test error handling scenarios"""
headers = create_auth_header()
# Test invalid document ID
response = await client.get("/documents/invalid_id", headers=headers)
assert response.status_code == 404
# Test invalid query parameters
response = await client.post(
"/query",
json={
"query": "", # Empty query
"k": -1 # Invalid k
},
"test-secret"
headers=headers
)
headers = {"Authorization": f"Bearer {token}"}
response = client.get("/documents", headers=headers)
assert response.status_code == 401
def main():
"""Run API endpoint tests"""
client = TestClient(app)
try:
test_ingest_document(client)
print("✓ Document ingestion tests passed")
test_query_endpoints(client)
print("✓ Query endpoint tests passed")
test_document_management(client)
print("✓ Document management tests passed")
test_auth_errors(client)
print("✓ Auth error tests passed")
except Exception as e:
print(f"× API test failed: {str(e)}")
raise
if __name__ == "__main__":
main()
assert response.status_code == 400
# Test oversized content
large_content = "x" * (10 * 1024 * 1024) # 10MB
response = await client.post(
"/documents",
json={
"content": large_content,
"content_type": "text/plain"
},
headers=headers
)
assert response.status_code == 400

View File

@ -136,17 +136,17 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
"""Build MongoDB filter for access control."""
base_filter = {
"$or": [
{"metadata.owner.id": auth.entity_id},
{"metadata.readers": auth.entity_id},
{"metadata.writers": auth.entity_id},
{"metadata.admins": auth.entity_id}
{"owner.id": auth.entity_id},
{"access_control.readers": auth.entity_id},
{"access_control.writers": auth.entity_id},
{"access_control.admins": auth.entity_id}
]
}
if auth.entity_type == EntityType.DEVELOPER and auth.app_id:
# Add app-specific access for developers
base_filter["$or"].append(
{"metadata.app_access": auth.app_id}
{"access_control.app_access": auth.app_id}
)
return base_filter
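For a developer token carrying an app_id, the filter built above now keys off the top-level owner field and the access_control lists rather than metadata.*; roughly (entity and app IDs are illustrative):

# Shape of the access filter for entity_id "dev_123" with app_id "app_42"
access_filter = {
    "$or": [
        {"owner.id": "dev_123"},
        {"access_control.readers": "dev_123"},
        {"access_control.writers": "dev_123"},
        {"access_control.admins": "dev_123"},
        {"access_control.app_access": "app_42"},
    ]
}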