Mirror of https://github.com/james-m-jordan/morphik-core.git (synced 2025-05-09 19:32:38 +00:00)
separate text and doc ingestion pathways
This commit is contained in:
parent 1f6551d464
commit 983a4ee854
core/api.py (70 changed lines)
@@ -1,26 +1,35 @@
import json
from datetime import datetime, UTC
from typing import List, Union, Dict, Set
from fastapi import FastAPI, HTTPException, Depends, Header, APIRouter
from fastapi import (
    FastAPI,
    File,
    Form,
    HTTPException,
    Depends,
    Header,
    APIRouter,
    UploadFile
)
from fastapi.middleware.cors import CORSMiddleware
import jwt

from core.models.request import IngestRequest, QueryRequest

from .models.documents import (
from core.models.request import IngestTextRequest, QueryRequest
from core.models.documents import (
    Document,
    DocumentResult,
    ChunkResult,
    EntityType
)
from .models.auth import AuthContext
from .services.document_service import DocumentService
from .config import get_settings
from .database.mongo_database import MongoDatabase
from .vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from .storage.s3_storage import S3Storage
from .parser.unstructured_parser import UnstructuredAPIParser
from .embedding_model.openai_embedding_model import OpenAIEmbeddingModel
from .services.uri_service import get_uri_service
from core.models.auth import AuthContext
from core.services.document_service import DocumentService
from core.config import get_settings
from core.database.mongo_database import MongoDatabase
from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from core.storage.s3_storage import S3Storage
from core.parser.unstructured_parser import UnstructuredAPIParser
from core.embedding_model.openai_embedding_model import OpenAIEmbeddingModel
from core.services.uri_service import get_uri_service


# Initialize FastAPI app
@@ -75,8 +84,14 @@ document_service = DocumentService(
)


async def verify_token(authorization: str = Header(...)) -> AuthContext:
async def verify_token(authorization: str = Header(None)) -> AuthContext:
    """Verify JWT Bearer token."""
    if not authorization:
        raise HTTPException(
            status_code=401,
            detail="Missing authorization header",
            headers={"WWW-Authenticate": "Bearer"}
        )
    try:
        if not authorization.startswith("Bearer "):
            raise HTTPException(
@@ -104,15 +119,30 @@ async def verify_token(authorization: str = Header(...)) -> AuthContext:
        raise HTTPException(status_code=401, detail=str(e))


# API endpoints
@app.post("/documents", response_model=Document)
async def ingest_document(
    request: IngestRequest,
@app.post("/documents/text", response_model=Document)
async def ingest_text(
    request: IngestTextRequest,
    auth: AuthContext = Depends(verify_token)
) -> Document:
    """Ingest a new document."""
    """Ingest a text document."""
    try:
        return await document_service.ingest_document(request, auth)
        return await document_service.ingest_text(request, auth)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.post("/documents/file", response_model=Document)
async def ingest_file(
    file: UploadFile = File(...),
    metadata: str = Form("{}"),  # JSON string of metadata
    auth: AuthContext = Depends(verify_token)
) -> Document:
    """Ingest a file document."""
    try:
        metadata_dict = json.loads(metadata)
        return await document_service.ingest_file(file, metadata_dict, auth)
    except json.JSONDecodeError:
        raise HTTPException(400, "Invalid metadata JSON")
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

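The two ingestion pathways now take different payloads: /documents/text expects a JSON body matching IngestTextRequest, while /documents/file expects multipart form data with the upload plus a JSON-encoded metadata form field. A minimal client sketch, assuming the API runs at http://localhost:8000 and that TOKEN and report.pdf are placeholders (neither is part of this commit):

import json
import httpx

BASE_URL = "http://localhost:8000"                    # assumed local server
TOKEN = "<jwt signed with the server's secret>"       # placeholder bearer token
HEADERS = {"Authorization": f"Bearer {TOKEN}"}

# Text pathway: JSON body validated against IngestTextRequest
text_resp = httpx.post(
    f"{BASE_URL}/documents/text",
    json={
        "content": "Plain text to ingest",
        "content_type": "text/plain",   # still required by the model (see the TODO in request.py)
        "metadata": {"source": "example"},
    },
    headers=HEADERS,
)

# File pathway: multipart upload plus a JSON string in the metadata form field
with open("report.pdf", "rb") as f:
    file_resp = httpx.post(
        f"{BASE_URL}/documents/file",
        files={"file": ("report.pdf", f, "application/pdf")},
        data={"metadata": json.dumps({"source": "example"})},
        headers=HEADERS,
    )

print(text_resp.json()["external_id"], file_resp.json()["external_id"])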
core/database/mongo_database.py

@@ -1,6 +1,6 @@
from typing import List, Optional, Dict, Any
import logging
from datetime import datetime
from datetime import UTC, datetime

from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import ReturnDocument
@@ -20,7 +20,7 @@ class MongoDatabase(BaseDatabase):
        self,
        uri: str,
        db_name: str,
        collection_name: str = "documents"
        collection_name: str,
    ):
        """Initialize MongoDB connection for document storage."""
        self.client = AsyncIOMotorClient(uri)
@@ -32,7 +32,7 @@ class MongoDatabase(BaseDatabase):
        try:
            # Create indexes for common queries
            await self.collection.create_index("external_id", unique=True)
            await self.collection.create_index("access_control.owner.id")
            await self.collection.create_index("owner.id")
            await self.collection.create_index("access_control.readers")
            await self.collection.create_index("access_control.writers")
            await self.collection.create_index("access_control.admins")
@@ -50,8 +50,8 @@ class MongoDatabase(BaseDatabase):
        doc_dict = document.model_dump()

        # Ensure system metadata
        doc_dict["system_metadata"]["created_at"] = datetime.utcnow()
        doc_dict["system_metadata"]["updated_at"] = datetime.utcnow()
        doc_dict["system_metadata"]["created_at"] = datetime.now(UTC)
        doc_dict["system_metadata"]["updated_at"] = datetime.now(UTC)

        result = await self.collection.insert_one(doc_dict)
        return bool(result.inserted_id)
@@ -190,9 +190,8 @@ class MongoDatabase(BaseDatabase):
        access_control = doc.get("access_control", {})

        # Check owner access
        owner = access_control.get("owner", {})
        if (owner.get("type") == auth.entity_type and
                owner.get("id") == auth.entity_id):
        owner = doc.get("owner", {})
        if (owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id):
            return True

        # Check permission-specific access
@@ -216,7 +215,7 @@ class MongoDatabase(BaseDatabase):
        """Build MongoDB filter for access control."""
        base_filter = {
            "$or": [
                {"access_control.owner.id": auth.entity_id},
                {"owner.id": auth.entity_id},
                {"access_control.readers": auth.entity_id},
                {"access_control.writers": auth.entity_id},
                {"access_control.admins": auth.entity_id}
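The switch from datetime.utcnow() to datetime.now(UTC) in the stored system metadata makes the timestamps timezone-aware, which keeps later comparisons and ISO output unambiguous. A quick illustration of the difference:

from datetime import datetime, UTC

naive = datetime.utcnow()   # no tzinfo attached (and deprecated since Python 3.12)
aware = datetime.now(UTC)   # carries tzinfo=UTC

print(naive.tzinfo)         # None
print(aware.isoformat())    # e.g. 2025-05-09T19:32:38.123456+00:00

# Comparing naive and aware datetimes raises TypeError, so writing aware values
# consistently avoids subtle bugs when created_at/updated_at are compared later.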
core/models/documents.py

@@ -1,6 +1,6 @@
from typing import Dict, Any, List, Optional, Literal
from enum import Enum
from datetime import datetime
from datetime import UTC, datetime
from pydantic import BaseModel, Field, field_validator
import uuid

@@ -18,23 +18,23 @@ class QueryReturnType(str, Enum):
class Document(BaseModel):
    """Represents a document stored in MongoDB documents collection"""
    external_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    owner: Dict[str, str]
    content_type: str
    filename: Optional[str] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    storage_info: Dict[str, str] = Field(default_factory=dict)  # s3_bucket, s3_key
    storage_info: Dict[str, str] = Field(default_factory=dict)
    system_metadata: Dict[str, Any] = Field(
        default_factory=lambda: {
            "created_at": datetime.utcnow(),
            "updated_at": datetime.utcnow(),
            "created_at": datetime.now(UTC),
            "updated_at": datetime.now(UTC),
            "version": 1
        }
    )
    access_control: Dict[str, Any] = Field(
    access_control: Dict[str, List[str]] = Field(
        default_factory=lambda: {
            "owner": None,
            "readers": set(),
            "writers": set(),
            "admins": set()
            "readers": [],
            "writers": [],
            "admins": []
        }
    )
    chunk_ids: List[str] = Field(default_factory=list)
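Switching access_control from sets to lists is what lets these records round-trip cleanly: Python sets are not JSON serializable and are rejected by the BSON encoder, so the old defaults could not be dumped for storage or API responses without extra conversion. A small sketch against the model above (field values are illustrative):

import json

from core.models.documents import Document

doc = Document(
    owner={"type": "developer", "id": "dev_123"},  # illustrative owner
    content_type="text/plain",
)

payload = doc.model_dump()
# List-based access_control (plus default=str for the datetime fields) serializes cleanly;
# with the previous set() defaults this json.dumps call raised
# "TypeError: Object of type set is not JSON serializable".
print(json.dumps(payload, default=str)[:120])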
core/models/request.py

@@ -1,18 +1,16 @@
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field

from .documents import QueryReturnType


class IngestRequest(BaseModel):
class IngestTextRequest(BaseModel):
    """Request model for text ingestion"""
    content: str
    # TODO: We should infer this, not request it
    content_type: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
    filename: Optional[str] = None


class QueryRequest(BaseModel):
    """Query request model - remains unchanged"""
    query: str
    return_type: QueryReturnType = QueryReturnType.CHUNKS
    filters: Optional[Dict[str, Any]] = None
core/parser/base_parser.py

@@ -1,9 +1,16 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, List

from typing import List, Union
from fastapi import UploadFile


class BaseParser(ABC):
    """Base class for document parsing"""

    @abstractmethod
    def parse(self, content: str, metadata: Dict[str, Any]) -> List[str]:
        """Parse content into chunks"""
    async def split_text(self, text: str) -> List[str]:
        """Split plain text into chunks"""
        pass

    @abstractmethod
    async def parse_file(self, file: Union[UploadFile, bytes], content_type: str) -> List[str]:
        """Parse file content into text chunks"""
        pass
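The parser contract is now async and file-aware: implementations provide split_text for raw strings and parse_file for uploads or raw bytes. A minimal sketch of a concrete parser against this interface (the NaiveParser name and its splitting rule are illustrative, not part of the commit):

from typing import List, Union

from fastapi import UploadFile

from core.parser.base_parser import BaseParser


class NaiveParser(BaseParser):
    """Toy parser: paragraph-based splitting, UTF-8 decoding for files."""

    async def split_text(self, text: str) -> List[str]:
        # Split on blank lines and drop empty chunks
        return [chunk.strip() for chunk in text.split("\n\n") if chunk.strip()]

    async def parse_file(self, file: Union[UploadFile, bytes], content_type: str) -> List[str]:
        raw = await file.read() if isinstance(file, UploadFile) else file
        # Only sensible for text-like content types; real parsers handle PDFs etc.
        return await self.split_text(raw.decode("utf-8", errors="ignore"))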
core/parser/unstructured_parser.py

@@ -1,13 +1,11 @@
from typing import Dict, Any, List
from .base_parser import BaseParser
from unstructured.partition.auto import partition
from typing import List, Union
from fastapi import UploadFile
from langchain.text_splitter import RecursiveCharacterTextSplitter
import os
import tempfile
import base64
import nltk
from unstructured.partition.auto import partition
import logging

from .base_parser import BaseParser

logger = logging.getLogger(__name__)


@@ -17,20 +15,8 @@ class UnstructuredAPIParser(BaseParser):
        api_key: str,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        api_url: str = "https://api.unstructuredapp.io"
    ):
        self.api_key = api_key
        self.api_url = api_url
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        # Download required NLTK data
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            logger.info("Downloading NLTK punkt tokenizer...")
            nltk.download('punkt', quiet=True)

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
@@ -38,63 +24,38 @@ class UnstructuredAPIParser(BaseParser):
            separators=["\n\n", "\n", ". ", " ", ""]
        )

    def parse(self, content: str, metadata: Dict[str, Any]) -> List[str]:
        """Parse content using Unstructured API and split into chunks."""
    async def split_text(self, text: str) -> List[str]:
        """Split plain text into chunks"""
        try:
            # Create temporary file for content
            with tempfile.NamedTemporaryFile(delete=False, suffix=self._get_file_extension(metadata)) as temp_file:
                if metadata.get("is_base64", False):
                    try:
                        decoded_content = base64.b64decode(content)
                    except Exception as e:
                        logger.error(f"Failed to decode base64 content: {str(e)}")
                        # If base64 decode fails, try treating as plain text
                        decoded_content = content.encode('utf-8')
                else:
                    decoded_content = content.encode('utf-8')

                temp_file.write(decoded_content)
                temp_file_path = temp_file.name
            return self.text_splitter.split_text(text)
        except Exception as e:
            logger.error(f"Failed to split text: {str(e)}")
            raise

            try:
                # Use Unstructured API for parsing
                elements = partition(
                    filename=temp_file_path,
                    api_key=self.api_key,
                    api_url=self.api_url,
                    partition_via_api=True
                )
    async def parse_file(self, file: Union[UploadFile, bytes], content_type: str) -> List[str]:
        """Parse file content using unstructured"""
        try:
            # Handle different file input types
            if isinstance(file, UploadFile):
                file_content = await file.read()
            else:
                file_content = file

                # Combine elements and split into chunks
                full_text = "\n\n".join(str(element) for element in elements)
                chunks = self.text_splitter.split_text(full_text)
            # Parse with unstructured
            elements = partition(
                file=file_content,
                content_type=content_type,
                api_key=self.api_key
            )

            if not chunks:
                # If no chunks were created, use the full text as a single chunk
                logger.warning("No chunks created, using full text as single chunk")
                return [full_text] if full_text.strip() else []
            # Extract text from elements
            chunks = []
            for element in elements:
                if hasattr(element, 'text') and element.text:
                    chunks.append(element.text.strip())

                return chunks
            finally:
                # Clean up temporary file
                try:
                    os.unlink(temp_file_path)
                except Exception as e:
                    logger.error(f"Failed to delete temporary file: {str(e)}")
            return chunks

        except Exception as e:
            logger.error(f"Error parsing document: {str(e)}")
            raise Exception(f"Error parsing document: {str(e)}")

    def _get_file_extension(self, metadata: Dict[str, Any]) -> str:
        """Get appropriate file extension based on content type."""
        content_type_mapping = {
            'application/pdf': '.pdf',
            'application/msword': '.doc',
            'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
            'image/jpeg': '.jpg',
            'image/png': '.png',
            'text/plain': '.txt',
            'text/html': '.html'
        }
        return content_type_mapping.get(metadata.get('content_type'), '.txt')
            logger.error(f"Failed to parse file: {str(e)}")
            raise
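With the split interface the service layer decides which path to call: split_text for the /documents/text route, parse_file for uploads. A usage sketch against the constructor shown above, assuming UNSTRUCTURED_API_KEY and report.pdf are placeholders:

import asyncio

from core.parser.unstructured_parser import UnstructuredAPIParser

UNSTRUCTURED_API_KEY = "..."  # placeholder; supply a real Unstructured API key

parser = UnstructuredAPIParser(api_key=UNSTRUCTURED_API_KEY)

async def demo() -> None:
    # Text pathway: local chunking via the RecursiveCharacterTextSplitter
    text_chunks = await parser.split_text("First paragraph.\n\nSecond paragraph.")

    # File pathway: raw bytes plus a content type, handed to unstructured's partition()
    with open("report.pdf", "rb") as f:  # illustrative file name
        file_chunks = await parser.parse_file(f.read(), "application/pdf")

    print(len(text_chunks), len(file_chunks))

asyncio.run(demo())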
core/services/document_service.py

@@ -1,15 +1,21 @@
import base64
from collections import defaultdict
from typing import Dict, List, Union, Optional

from typing import Any, Dict, List, Union, Optional
import logging
from core.api import IngestRequest, QueryRequest
from fastapi import UploadFile

from core.database.base_database import BaseDatabase
from core.embedding_model.base_embedding_model import BaseEmbeddingModel
from core.models.request import IngestTextRequest, QueryRequest
from core.parser.base_parser import BaseParser
from core.storage.base_storage import BaseStorage
from core.vector_store.base_vector_store import BaseVectorStore
from ..models.documents import (
    Document, DocumentChunk, ChunkResult, DocumentContent, DocumentResult,
    Document,
    DocumentChunk,
    ChunkResult,
    DocumentContent,
    DocumentResult,
    QueryReturnType
)
from ..models.auth import AuthContext
@@ -33,73 +39,128 @@ class DocumentService:
        self.parser = parser
        self.embedding_model = embedding_model

    async def ingest_document(
    async def ingest_text(
        self,
        request: IngestRequest,
        request: IngestTextRequest,
        auth: AuthContext
    ) -> Document:
        """Ingest a new document with chunks."""
        """Ingest a text document."""
        try:
            # 1. Create document record
            doc = Document(
                content_type=request.content_type,
                filename=request.filename,
                content_type="text/plain",
                metadata=request.metadata,
                owner={
                    "type": auth.entity_type,
                    "id": auth.entity_id
                },
                access_control={
                    "owner": {
                        "type": auth.entity_type,
                        "id": auth.entity_id
                    },
                    "readers": {auth.entity_id},
                    "readers": [auth.entity_id],
                    "writers": {auth.entity_id},
                    "admins": {auth.entity_id}
                }
            )
            logger.info(f"Created text document record with ID {doc.external_id}")

            # 2. Store file in storage if it's not text
            if request.content_type != "text/plain":
                storage_info = await self.storage.upload_from_base64(
                    request.content,
                    doc.external_id,
                    request.content_type
                )
                doc.storage_info = {
                    "bucket": storage_info[0],
                    "key": storage_info[1]
                }
            # 2. Parse content into chunks
            chunks = await self.parser.split_text(request.content)
            if not chunks:
                raise ValueError("No content chunks extracted from text")
            logger.info(f"Split text into {len(chunks)} chunks")

            # 3. Parse content into chunks
            chunks = await self.parser.parse(request.content)

            # 4. Generate embeddings for chunks
            # 3. Generate embeddings for chunks
            embeddings = await self.embedding_model.embed_for_ingestion(chunks)
            logger.info(f"Generated {len(embeddings)} embeddings")

            # 5. Create and store chunks with embeddings
            chunk_objects = []
            for i, (content, embedding) in enumerate(zip(chunks, embeddings)):
                chunk = DocumentChunk(
                    document_id=doc.external_id,
                    content=content,
                    embedding=embedding,
                    chunk_number=i,
                    metadata=doc.metadata  # Inherit document metadata
                )
                chunk_objects.append(chunk)
            # 4. Create and store chunk objects
            chunk_objects = self._create_chunk_objects(
                doc.external_id,
                chunks,
                embeddings,
                doc.metadata
            )
            logger.info(f"Created {len(chunk_objects)} chunk objects")

            # 6. Store chunks in vector store
            success = await self.vector_store.store_embeddings(chunk_objects)
            if not success:
                raise Exception("Failed to store chunk embeddings")

            # 7. Store document metadata
            if not await self.db.store_document(doc):
                raise Exception("Failed to store document metadata")
            # 5. Store everything
            await self._store_chunks_and_doc(chunk_objects, doc)
            logger.info(f"Successfully stored text document {doc.external_id}")

            return doc

        except Exception as e:
            logger.error(f"Text document ingestion failed: {str(e)}")
            # TODO: Clean up any stored data on failure
            raise Exception(f"Document ingestion failed: {str(e)}")
            raise

    async def ingest_file(
        self,
        file: UploadFile,
        metadata: Dict[str, Any],
        auth: AuthContext
    ) -> Document:
        """Ingest a file document."""
        try:
            # 1. Create document record
            doc = Document(
                content_type=file.content_type,
                filename=file.filename,
                metadata=metadata,
                owner={
                    "type": auth.entity_type,
                    "id": auth.entity_id
                },
                access_control={
                    "readers": [auth.entity_id],
                    "writers": {auth.entity_id},
                    "admins": {auth.entity_id}
                }
            )
            logger.info(f"Created file document record with ID {doc.external_id}")

            # 2. Read and store file
            file_content = await file.read()
            storage_info = await self.storage.upload_from_base64(
                base64.b64encode(file_content).decode(),
                doc.external_id,
                file.content_type
            )
            doc.storage_info = {
                "bucket": storage_info[0],
                "key": storage_info[1]
            }
            logger.info(
                f"Stored file in bucket {storage_info[0]} with key {storage_info[1]}"
            )

            # 3. Parse content into chunks
            chunks = await self.parser.parse_file(file_content, file.content_type)
            if not chunks:
                raise ValueError("No content chunks extracted from file")
            logger.info(f"Parsed file into {len(chunks)} chunks")

            # 4. Generate embeddings for chunks
            embeddings = await self.embedding_model.embed_for_ingestion(chunks)
            logger.info(f"Generated {len(embeddings)} embeddings")

            # 5. Create and store chunk objects
            chunk_objects = self._create_chunk_objects(
                doc.external_id,
                chunks,
                embeddings,
                doc.metadata
            )
            logger.info(f"Created {len(chunk_objects)} chunk objects")

            # 6. Store everything
            await self._store_chunks_and_doc(chunk_objects, doc)
            logger.info(f"Successfully stored file document {doc.external_id}")

            return doc

        except Exception as e:
            logger.error(f"File document ingestion failed: {str(e)}")
            # TODO: Clean up any stored data on failure
            raise

    async def query(
        self,
@@ -109,12 +170,17 @@ class DocumentService:
        """Query documents with specified return type."""
        try:
            # 1. Get embedding for query
            query_embedding = await self.embedding_model.embed_for_query(request.query)
            query_embedding = await self.embedding_model.embed_for_query(
                request.query
            )
            logger.info("Generated query embedding")

            # 2. Find authorized documents
            doc_ids = await self.db.find_documents(auth, request.filters)
            if not doc_ids:
                logger.info("No authorized documents found")
                return []
            logger.info(f"Found {len(doc_ids)} authorized documents")

            # 3. Search chunks with vector similarity
            chunks = await self.vector_store.query_similar(
@@ -123,25 +189,71 @@ class DocumentService:
                auth=auth,
                filters={"document_id": {"$in": doc_ids}}
            )
            logger.info(f"Found {len(chunks)} similar chunks")

            # 4. Return results in requested format
            if request.return_type == QueryReturnType.CHUNKS:
                return await self._create_chunk_results(auth, chunks)
                results = await self._create_chunk_results(auth, chunks)
                logger.info(f"Returning {len(results)} chunk results")
                return results
            else:
                return await self._create_document_results(auth, chunks)
                results = await self._create_document_results(auth, chunks)
                logger.info(f"Returning {len(results)} document results")
                return results

        except Exception as e:
            logger.error(f"Query failed: {str(e)}")
            raise e
            raise

    async def _create_chunk_results(self, auth: AuthContext, chunks: List[DocumentChunk]) -> List[ChunkResult]:
    def _create_chunk_objects(
        self,
        doc_id: str,
        chunks: List[str],
        embeddings: List[List[float]],
        metadata: Dict[str, Any]
    ) -> List[DocumentChunk]:
        """Helper to create chunk objects"""
        return [
            DocumentChunk(
                document_id=doc_id,
                content=content,
                embedding=embedding,
                chunk_number=i,
                metadata=metadata
            )
            for i, (content, embedding) in enumerate(zip(chunks, embeddings))
        ]

    async def _store_chunks_and_doc(
        self,
        chunk_objects: List[DocumentChunk],
        doc: Document
    ) -> None:
        """Helper to store chunks and document"""
        # Store chunks in vector store
        if not await self.vector_store.store_embeddings(chunk_objects):
            raise Exception("Failed to store chunk embeddings")
        logger.debug("Stored chunk embeddings in vector store")

        # Store document metadata
        if not await self.db.store_document(doc):
            raise Exception("Failed to store document metadata")
        logger.debug("Stored document metadata in database")

    async def _create_chunk_results(
        self,
        auth: AuthContext,
        chunks: List[DocumentChunk]
    ) -> List[ChunkResult]:
        """Create ChunkResult objects with document metadata."""
        results = []
        for chunk in chunks:
            # Get document metadata
            doc = await self.db.get_document(chunk.document_id, auth)
            if not doc:
                logger.warning(f"Document {chunk.document_id} not found")
                continue
            logger.debug(f"Retrieved metadata for document {chunk.document_id}")

            # Generate download URL if needed
            download_url = None
@@ -150,6 +262,9 @@ class DocumentService:
                    doc.storage_info["bucket"],
                    doc.storage_info["key"]
                )
                logger.debug(
                    f"Generated download URL for document {chunk.document_id}"
                )

            results.append(ChunkResult(
                content=chunk.content,
@@ -162,22 +277,31 @@ class DocumentService:
                download_url=download_url
            ))

        logger.info(f"Created {len(results)} chunk results")
        return results

    async def _create_document_results(self, auth: AuthContext, chunks: List[DocumentChunk]) -> List[DocumentResult]:
    async def _create_document_results(
        self,
        auth: AuthContext,
        chunks: List[DocumentChunk]
    ) -> List[DocumentResult]:
        """Group chunks by document and create DocumentResult objects."""
        # Group chunks by document and get highest scoring chunk per doc
        doc_chunks: Dict[str, DocumentChunk] = {}
        for chunk in chunks:
            if chunk.document_id not in doc_chunks or chunk.score > doc_chunks[chunk.document_id].score:
            if (chunk.document_id not in doc_chunks or
                    chunk.score > doc_chunks[chunk.document_id].score):
                doc_chunks[chunk.document_id] = chunk
        logger.info(f"Grouped chunks into {len(doc_chunks)} documents")

        results = []
        for doc_id, chunk in doc_chunks.items():
            # Get document metadata
            doc = await self.db.get_document(doc_id, auth)
            if not doc:
                logger.warning(f"Document {doc_id} not found")
                continue
            logger.debug(f"Retrieved metadata for document {doc_id}")

            # Create DocumentContent based on content type
            if doc.content_type == "text/plain":
@@ -186,6 +310,7 @@ class DocumentService:
                    value=chunk.content,
                    filename=None
                )
                logger.debug(f"Created text content for document {doc_id}")
            else:
                # Generate download URL for file types
                download_url = await self.storage.get_download_url(
@@ -197,6 +322,7 @@ class DocumentService:
                    value=download_url,
                    filename=doc.filename
                )
                logger.debug(f"Created URL content for document {doc_id}")

            results.append(DocumentResult(
                score=chunk.score,
@@ -205,4 +331,5 @@ class DocumentService:
                content=content
            ))

        logger.info(f"Created {len(results)} document results")
        return results
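_create_document_results keeps only the best chunk per document before building results. The grouping logic in isolation, using plain tuples instead of DocumentChunk so the sketch is self-contained:

# Stand-ins for DocumentChunk: (document_id, chunk_number, score)
chunks = [
    ("doc-a", 0, 0.82),
    ("doc-b", 3, 0.40),
    ("doc-a", 2, 0.91),
]

best_per_doc = {}
for doc_id, chunk_number, score in chunks:
    if doc_id not in best_per_doc or score > best_per_doc[doc_id][2]:
        best_per_doc[doc_id] = (doc_id, chunk_number, score)

assert best_per_doc["doc-a"][2] == 0.91  # highest-scoring chunk wins
assert len(best_per_doc) == 2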
@@ -1,158 +1,344 @@
import base64
import pytest
from fastapi.testclient import TestClient
from core.api import app
from pathlib import Path
import jwt
from datetime import datetime, UTC, timedelta
from datetime import datetime, timedelta, UTC
from typing import AsyncGenerator, Dict, Any
from httpx import AsyncClient
from fastapi import FastAPI
from core.api import app, get_settings
from core.models.auth import EntityType
from core.database.mongo_database import MongoDatabase
from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore

# Test configuration
TEST_DATA_DIR = Path(__file__).parent / "test_data"
JWT_SECRET = "your-secret-key-for-signing-tokens"
TEST_USER_ID = "test_user"


def create_auth_header(entity_type: str = "developer", permissions: list = None):
    """Create auth header with test token"""
    token = jwt.encode(
        {
            "type": entity_type,
            "entity_id": "test_user",
            "permissions": permissions or ["read", "write"],
            "exp": datetime.now(UTC) + timedelta(days=1)
        },
        # TODO: Use settings.JWT_SECRET_KEY
        "your-secret-key-for-signing-tokens"
    )
@pytest.fixture(scope="session", autouse=True)
def setup_test_environment():
    """Setup test environment and create test files"""
    # Create test data directory if it doesn't exist
    TEST_DATA_DIR.mkdir(exist_ok=True)

    # Create a test text file
    text_file = TEST_DATA_DIR / "test.txt"
    if not text_file.exists():
        text_file.write_text("This is a test document for DataBridge testing.")

    # Create a small test PDF if it doesn't exist
    pdf_file = TEST_DATA_DIR / "test.pdf"
    if not pdf_file.exists():
        # Create a minimal PDF for testing
        try:
            from reportlab.pdfgen import canvas
            c = canvas.Canvas(str(pdf_file))
            c.drawString(100, 750, "Test PDF Document")
            c.save()
        except ImportError:
            pytest.skip("reportlab not installed, skipping PDF tests")


def create_test_token(
    entity_type: str = "developer",
    entity_id: str = TEST_USER_ID,
    permissions: list = None,
    app_id: str = None,
    expired: bool = False
) -> str:
    """Create a test JWT token"""
    if permissions is None:
        permissions = ["read", "write", "admin"]

    payload = {
        "type": entity_type,
        "entity_id": entity_id,
        "permissions": permissions,
        "exp": datetime.now(UTC) + timedelta(days=-1 if expired else 1)
    }

    if app_id:
        payload["app_id"] = app_id

    return jwt.encode(payload, JWT_SECRET, algorithm="HS256")


def create_auth_header(
    entity_type: str = "developer",
    permissions: list = None,
    expired: bool = False
) -> Dict[str, str]:
    """Create authorization header with test token"""
    token = create_test_token(entity_type, permissions=permissions, expired=expired)
    return {"Authorization": f"Bearer {token}"}


@pytest.fixture
def client():
    return TestClient(app)
async def test_app() -> FastAPI:
    """Create test FastAPI application"""
    # Configure test settings
    settings = get_settings()
    settings.JWT_SECRET_KEY = JWT_SECRET
    return app


def test_ingest_document(client):
    """Test document ingestion endpoint"""
@pytest.fixture
async def client(test_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
    """Create async test client"""
    async with AsyncClient(app=test_app, base_url="http://test") as client:
        yield client


@pytest.mark.asyncio
async def test_ingest_text_document(client: AsyncClient):
    """Test ingesting a text document"""
    headers = create_auth_header()

    # Test text document
    response = client.post(
        "/documents",
    response = await client.post(
        "/documents/text",
        json={
            "content": "Test content",
            "content_type": "text/plain",
            "metadata": {"test": True}
            "content": "Test content for document ingestion",
            "metadata": {"test": True, "type": "text"}
        },
        headers=headers
    )
    print(response.json())

    assert response.status_code == 200
    data = response.json()
    assert data["external_id"]
    assert "external_id" in data
    assert data["content_type"] == "text/plain"
    assert data["metadata"]["test"] is True

    # Test binary document
    with open("test.pdf", "rb") as f:
        content = base64.b64encode(f.read()).decode()
    return data["external_id"]

    response = client.post(
        "/documents",

@pytest.mark.asyncio
async def test_ingest_file_document(client: AsyncClient):
    """Test ingesting a file (PDF) document"""
    headers = create_auth_header()
    pdf_path = TEST_DATA_DIR / "test.pdf"

    if not pdf_path.exists():
        pytest.skip("Test PDF file not available")

    # Create form data with file and metadata
    files = {
        "file": ("test.pdf", open(pdf_path, "rb"), "application/pdf")
    }
    data = {
        "metadata": json.dumps({"test": True, "type": "pdf"})
    }

    response = await client.post(
        "/documents/file",
        files=files,
        data=data,
        headers=headers
    )

    assert response.status_code == 200
    data = response.json()
    assert "external_id" in data
    assert data["content_type"] == "application/pdf"
    assert "storage_info" in data

    return data["external_id"]


@pytest.mark.asyncio
async def test_ingest_error_handling(client: AsyncClient):
    """Test ingestion error cases"""
    headers = create_auth_header()

    # Test invalid text request
    response = await client.post(
        "/documents/text",
        json={
            "content": content,
            "content_type": "application/pdf",
            "filename": "test.pdf"
            "wrong_field": "Test content"  # Missing required content field
        },
        headers=headers
    )
    assert response.status_code == 200
    assert response.status_code == 422  # Validation error

    # Test invalid file request
    response = await client.post(
        "/documents/file",
        files={},  # Missing file
        data={"metadata": "{}"},
        headers=headers
    )
    assert response.status_code == 422  # Validation error

    # Test invalid metadata JSON
    pdf_path = TEST_DATA_DIR / "test.pdf"
    if pdf_path.exists():
        files = {
            "file": ("test.pdf", open(pdf_path, "rb"), "application/pdf")
        }
        response = await client.post(
            "/documents/file",
            files=files,
            data={"metadata": "invalid json"},
            headers=headers
        )
        assert response.status_code == 400  # Bad request

    # Test oversized content
    large_content = "x" * (10 * 1024 * 1024)  # 10MB
    response = await client.post(
        "/documents/text",
        json={
            "content": large_content,
            "metadata": {}
        },
        headers=headers
    )
    assert response.status_code == 400  # Bad request


def test_query_endpoints(client):
    """Test query endpoints"""
@pytest.mark.asyncio
async def test_auth_errors(client: AsyncClient):
    """Test authentication error cases"""
    # Test missing auth header
    response = await client.post("/documents/text")
    assert response.status_code == 401

    # Test invalid token
    headers = {"Authorization": "Bearer invalid_token"}
    response = await client.post("/documents/file", headers=headers)
    assert response.status_code == 401

    # Test expired token
    headers = create_auth_header(expired=True)
    response = await client.post("/documents/text", headers=headers)
    assert response.status_code == 401

    # Test insufficient permissions
    headers = create_auth_header(permissions=["read"])
    response = await client.post(
        "/documents/text",
        json={
            "content": "Test content",
            "metadata": {}
        },
        headers=headers
    )
    assert response.status_code == 403


@pytest.mark.asyncio
async def test_query_chunks(client: AsyncClient):
    """Test querying document chunks"""
    # First ingest a document to query
    doc_id = await test_ingest_text_document(client)

    headers = create_auth_header()

    # Test chunk query
    response = client.post(
    response = await client.post(
        "/query",
        json={
            "query": "test",
            "query": "test document",
            "return_type": "chunks",
            "k": 2
        },
        headers=headers
    )

    assert response.status_code == 200
    results = response.json()
    assert len(results) <= 2
    assert len(results) > 0
    assert all(isinstance(r["score"], (int, float)) for r in results)
    assert all(r["document_id"] == doc_id for r in results)

    # Test document query
    response = client.post(

@pytest.mark.asyncio
async def test_query_documents(client: AsyncClient):
    """Test querying for full documents"""
    # First ingest a document to query
    doc_id = await test_ingest_text_document(client)

    headers = create_auth_header()
    response = await client.post(
        "/query",
        json={
            "query": "test",
            "return_type": "documents"
            "query": "test document",
            "return_type": "documents",
            "filters": {"test": True}
        },
        headers=headers
    )

    assert response.status_code == 200
    results = response.json()
    assert len(results) > 0
    assert results[0]["document_id"] == doc_id
    assert "score" in results[0]
    assert "metadata" in results[0]


def test_document_management(client):
    """Test document management endpoints"""
@pytest.mark.asyncio
async def test_list_documents(client: AsyncClient):
    """Test listing documents"""
    # First ingest some documents
    doc_id1 = await test_ingest_text_document(client)
    doc_id2 = await test_ingest_text_document(client)

    headers = create_auth_header()

    # List documents
    response = client.get("/documents", headers=headers)
    response = await client.get("/documents", headers=headers)

    assert response.status_code == 200
    docs = response.json()

    if docs:
        # Get specific document
        doc_id = docs[0]["external_id"]
        response = client.get(f"/documents/{doc_id}", headers=headers)
        assert response.status_code == 200
    assert len(docs) >= 2
    doc_ids = [doc["external_id"] for doc in docs]
    assert doc_id1 in doc_ids
    assert doc_id2 in doc_ids


def test_auth_errors(client):
    """Test authentication error cases"""
    # No auth header
    response = client.get("/documents")
    assert response.status_code == 401
@pytest.mark.asyncio
async def test_get_document(client: AsyncClient):
    """Test getting a specific document"""
    # First ingest a document
    doc_id = await test_ingest_text_document(client)

    headers = create_auth_header()
    response = await client.get(f"/documents/{doc_id}", headers=headers)

    assert response.status_code == 200
    doc = response.json()
    assert doc["external_id"] == doc_id
    assert "metadata" in doc
    assert "content_type" in doc

    # Invalid token
    headers = {"Authorization": "Bearer invalid"}
    response = client.get("/documents", headers=headers)
    assert response.status_code == 401

    # Expired token
    token = jwt.encode(
        {
            "type": "developer",
            "entity_id": "test",
            "exp": datetime.now(UTC) - timedelta(days=1)
@pytest.mark.asyncio
async def test_error_handling(client: AsyncClient):
    """Test error handling scenarios"""
    headers = create_auth_header()

    # Test invalid document ID
    response = await client.get("/documents/invalid_id", headers=headers)
    assert response.status_code == 404

    # Test invalid query parameters
    response = await client.post(
        "/query",
        json={
            "query": "",  # Empty query
            "k": -1  # Invalid k
        },
        "test-secret"
        headers=headers
    )
    headers = {"Authorization": f"Bearer {token}"}
    response = client.get("/documents", headers=headers)
    assert response.status_code == 401


def main():
    """Run API endpoint tests"""
    client = TestClient(app)

    try:
        test_ingest_document(client)
        print("✓ Document ingestion tests passed")

        test_query_endpoints(client)
        print("✓ Query endpoint tests passed")

        test_document_management(client)
        print("✓ Document management tests passed")

        test_auth_errors(client)
        print("✓ Auth error tests passed")

    except Exception as e:
        print(f"× API test failed: {str(e)}")
        raise


if __name__ == "__main__":
    main()
    assert response.status_code == 400

    # Test oversized content
    large_content = "x" * (10 * 1024 * 1024)  # 10MB
    response = await client.post(
        "/documents",
        json={
            "content": large_content,
            "content_type": "text/plain"
        },
        headers=headers
    )
    assert response.status_code == 400
Binary file not shown.
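A note on running the rewritten suite: the async tests rely on pytest-asyncio (via @pytest.mark.asyncio), httpx's AsyncClient, and optionally reportlab for generating the test PDF. A hedged invocation from Python, assuming the module lives under core/tests/ (the path is not shown in this view):

# Requires: pytest, pytest-asyncio, httpx, reportlab (optional, for the PDF fixture)
import pytest

# The test module path is an assumption; adjust to wherever the file lives in the repo.
pytest.main(["-v", "core/tests/test_api.py"])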
core/vector_store/mongo_vector_store.py

@@ -136,17 +136,17 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
        """Build MongoDB filter for access control."""
        base_filter = {
            "$or": [
                {"metadata.owner.id": auth.entity_id},
                {"metadata.readers": auth.entity_id},
                {"metadata.writers": auth.entity_id},
                {"metadata.admins": auth.entity_id}
                {"owner.id": auth.entity_id},
                {"access_control.readers": auth.entity_id},
                {"access_control.writers": auth.entity_id},
                {"access_control.admins": auth.entity_id}
            ]
        }

        if auth.entity_type == EntityType.DEVELOPER and auth.app_id:
            # Add app-specific access for developers
            base_filter["$or"].append(
                {"metadata.app_access": auth.app_id}
                {"access_control.app_access": auth.app_id}
            )

        return base_filter
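Both stores now build the same shape of access filter, so a chunk search combines it with the document scoping the service adds. A sketch of the resulting filter for a developer with app access, using illustrative IDs (how query_similar merges the pieces is not shown in this hunk, so the $and composition is an assumption):

entity_id = "dev_123"               # illustrative
app_id = "app_42"                   # illustrative
authorized_doc_ids = ["doc-a", "doc-b"]

access_filter = {
    "$or": [
        {"owner.id": entity_id},
        {"access_control.readers": entity_id},
        {"access_control.writers": entity_id},
        {"access_control.admins": entity_id},
        {"access_control.app_access": app_id},
    ]
}
# The service layer additionally restricts the search to authorized documents:
combined = {"$and": [access_filter, {"document_id": {"$in": authorized_doc_ids}}]}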