mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
Add reranking (#14)
This commit is contained in:
parent
c936aa91a4
commit
20faae8903
@ -11,6 +11,7 @@ vector_store = "mongodb"
|
|||||||
embedding = "ollama" # "openai", "ollama"
|
embedding = "ollama" # "openai", "ollama"
|
||||||
completion = "ollama" # "openai", "ollama"
|
completion = "ollama" # "openai", "ollama"
|
||||||
parser = "combined" # "combined", "unstructured", "contextual"
|
parser = "combined" # "combined", "unstructured", "contextual"
|
||||||
|
reranker = "bge" # "bge"
|
||||||
|
|
||||||
# Storage Configuration
|
# Storage Configuration
|
||||||
[storage.local]
|
[storage.local]
|
||||||
@ -45,12 +46,20 @@ default_temperature = 0.7
|
|||||||
[models.ollama]
|
[models.ollama]
|
||||||
base_url = "http://localhost:11434"
|
base_url = "http://localhost:11434"
|
||||||
|
|
||||||
|
[models.reranker]
|
||||||
|
model_name = "BAAI/bge-reranker-large" # "BAAI/bge-reranker-v2-gemma", "BAAI/bge-reranker-large"
|
||||||
|
device = "mps" # "mps", "cuda:0" # Optional: remove or comment out this key for CPU (TOML has no null value)
|
||||||
|
use_fp16 = true
|
||||||
|
query_max_length = 256
|
||||||
|
passage_max_length = 512
|
||||||
|
|
||||||
# Document Processing
|
# Document Processing
|
||||||
[processing]
|
[processing]
|
||||||
[processing.text]
|
[processing.text]
|
||||||
chunk_size = 1000
|
chunk_size = 1000
|
||||||
chunk_overlap = 200
|
chunk_overlap = 200
|
||||||
default_k = 4
|
default_k = 4
|
||||||
|
use_reranking = true # Whether to use reranking by default
|
||||||
|
|
||||||
[processing.video]
|
[processing.video]
|
||||||
frame_sample_rate = 120
|
frame_sample_rate = 120
|
||||||
|
120
core/api.py
120
core/api.py
@ -28,6 +28,7 @@ from core.storage.local_storage import LocalStorage
|
|||||||
from core.embedding.openai_embedding_model import OpenAIEmbeddingModel
|
from core.embedding.openai_embedding_model import OpenAIEmbeddingModel
|
||||||
from core.completion.ollama_completion import OllamaCompletionModel
|
from core.completion.ollama_completion import OllamaCompletionModel
|
||||||
from core.parser.contextual_parser import ContextualParser
|
from core.parser.contextual_parser import ContextualParser
|
||||||
|
from core.reranker.bge_reranker import BGEReranker
|
||||||
|
|
||||||
# Initialize FastAPI app
|
# Initialize FastAPI app
|
||||||
app = FastAPI(title="DataBridge API")
|
app = FastAPI(title="DataBridge API")
|
||||||
@ -158,14 +159,28 @@ match settings.COMPLETION_PROVIDER:
|
|||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unsupported completion provider: {settings.COMPLETION_PROVIDER}")
|
raise ValueError(f"Unsupported completion provider: {settings.COMPLETION_PROVIDER}")
|
||||||
|
|
||||||
|
# Initialize reranker
|
||||||
|
match settings.RERANKER_PROVIDER:
|
||||||
|
case "bge":
|
||||||
|
reranker = BGEReranker(
|
||||||
|
model_name=settings.RERANKER_MODEL,
|
||||||
|
device=settings.RERANKER_DEVICE,
|
||||||
|
use_fp16=settings.RERANKER_USE_FP16,
|
||||||
|
query_max_length=settings.RERANKER_QUERY_MAX_LENGTH,
|
||||||
|
passage_max_length=settings.RERANKER_PASSAGE_MAX_LENGTH,
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported reranker provider: {settings.RERANKER_PROVIDER}")
|
||||||
|
|
||||||
# Initialize document service with configured components
|
# Initialize document service with configured components
|
||||||
document_service = DocumentService(
|
document_service = DocumentService(
|
||||||
|
storage=storage,
|
||||||
database=database,
|
database=database,
|
||||||
vector_store=vector_store,
|
vector_store=vector_store,
|
||||||
storage=storage,
|
|
||||||
parser=parser,
|
|
||||||
embedding_model=embedding_model,
|
embedding_model=embedding_model,
|
||||||
completion_model=completion_model,
|
completion_model=completion_model,
|
||||||
|
parser=parser,
|
||||||
|
reranker=reranker,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -243,27 +258,51 @@ async def ingest_file(
|
|||||||
@app.post("/retrieve/chunks", response_model=List[ChunkResult])
|
@app.post("/retrieve/chunks", response_model=List[ChunkResult])
|
||||||
async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
|
async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
|
||||||
"""Retrieve relevant chunks."""
|
"""Retrieve relevant chunks."""
|
||||||
async with telemetry.track_operation(
|
try:
|
||||||
operation_type="retrieve_chunks",
|
async with telemetry.track_operation(
|
||||||
user_id=auth.entity_id,
|
operation_type="retrieve_chunks",
|
||||||
metadata=request.model_dump(),
|
user_id=auth.entity_id,
|
||||||
):
|
metadata={
|
||||||
return await document_service.retrieve_chunks(
|
"k": request.k,
|
||||||
request.query, auth, request.filters, request.k, request.min_score
|
"min_score": request.min_score,
|
||||||
)
|
"use_reranking": request.use_reranking,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
return await document_service.retrieve_chunks(
|
||||||
|
request.query,
|
||||||
|
auth,
|
||||||
|
request.filters,
|
||||||
|
request.k,
|
||||||
|
request.min_score,
|
||||||
|
request.use_reranking,
|
||||||
|
)
|
||||||
|
except PermissionError as e:
|
||||||
|
raise HTTPException(status_code=403, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@app.post("/retrieve/docs", response_model=List[DocumentResult])
|
@app.post("/retrieve/docs", response_model=List[DocumentResult])
|
||||||
async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
|
async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
|
||||||
"""Retrieve relevant documents."""
|
"""Retrieve relevant documents."""
|
||||||
async with telemetry.track_operation(
|
try:
|
||||||
operation_type="retrieve_docs",
|
async with telemetry.track_operation(
|
||||||
user_id=auth.entity_id,
|
operation_type="retrieve_docs",
|
||||||
metadata=request.model_dump(),
|
user_id=auth.entity_id,
|
||||||
):
|
metadata={
|
||||||
return await document_service.retrieve_docs(
|
"k": request.k,
|
||||||
request.query, auth, request.filters, request.k, request.min_score
|
"min_score": request.min_score,
|
||||||
)
|
"use_reranking": request.use_reranking,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
return await document_service.retrieve_docs(
|
||||||
|
request.query,
|
||||||
|
auth,
|
||||||
|
request.filters,
|
||||||
|
request.k,
|
||||||
|
request.min_score,
|
||||||
|
request.use_reranking,
|
||||||
|
)
|
||||||
|
except PermissionError as e:
|
||||||
|
raise HTTPException(status_code=403, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@app.post("/query", response_model=CompletionResponse)
|
@app.post("/query", response_model=CompletionResponse)
|
||||||
@ -271,27 +310,30 @@ async def query_completion(
|
|||||||
request: CompletionQueryRequest, auth: AuthContext = Depends(verify_token)
|
request: CompletionQueryRequest, auth: AuthContext = Depends(verify_token)
|
||||||
):
|
):
|
||||||
"""Generate completion using relevant chunks as context."""
|
"""Generate completion using relevant chunks as context."""
|
||||||
async with telemetry.track_operation(
|
try:
|
||||||
operation_type="query",
|
async with telemetry.track_operation(
|
||||||
user_id=auth.entity_id,
|
operation_type="query",
|
||||||
metadata=request.model_dump(),
|
user_id=auth.entity_id,
|
||||||
) as span:
|
metadata={
|
||||||
response = await document_service.query(
|
"k": request.k,
|
||||||
request.query,
|
"min_score": request.min_score,
|
||||||
auth,
|
"max_tokens": request.max_tokens,
|
||||||
request.filters,
|
"temperature": request.temperature,
|
||||||
request.k,
|
"use_reranking": request.use_reranking,
|
||||||
request.min_score,
|
},
|
||||||
request.max_tokens,
|
):
|
||||||
request.temperature,
|
return await document_service.query(
|
||||||
)
|
request.query,
|
||||||
if isinstance(response, dict) and "usage" in response:
|
auth,
|
||||||
usage = response["usage"]
|
request.filters,
|
||||||
if isinstance(usage, dict):
|
request.k,
|
||||||
span.set_attribute("tokens.completion", usage.get("completion_tokens", 0))
|
request.min_score,
|
||||||
span.set_attribute("tokens.prompt", usage.get("prompt_tokens", 0))
|
request.max_tokens,
|
||||||
span.set_attribute("tokens.total", usage.get("total_tokens", 0))
|
request.temperature,
|
||||||
return response
|
request.use_reranking,
|
||||||
|
)
|
||||||
|
except PermissionError as e:
|
||||||
|
raise HTTPException(status_code=403, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@app.get("/documents", response_model=List[Document])
|
@app.get("/documents", response_model=List[Document])
|
||||||
|
@ -32,6 +32,7 @@ class Settings(BaseSettings):
|
|||||||
EMBEDDING_PROVIDER: str = "openai"
|
EMBEDDING_PROVIDER: str = "openai"
|
||||||
COMPLETION_PROVIDER: str = "ollama"
|
COMPLETION_PROVIDER: str = "ollama"
|
||||||
PARSER_PROVIDER: str = "combined"
|
PARSER_PROVIDER: str = "combined"
|
||||||
|
RERANKER_PROVIDER: str = "bge"
|
||||||
|
|
||||||
# Storage settings
|
# Storage settings
|
||||||
STORAGE_PATH: str = "./storage"
|
STORAGE_PATH: str = "./storage"
|
||||||
@ -53,6 +54,11 @@ class Settings(BaseSettings):
|
|||||||
COMPLETION_MAX_TOKENS: int = 1000
|
COMPLETION_MAX_TOKENS: int = 1000
|
||||||
COMPLETION_TEMPERATURE: float = 0.7
|
COMPLETION_TEMPERATURE: float = 0.7
|
||||||
OLLAMA_BASE_URL: str = "http://localhost:11434"
|
OLLAMA_BASE_URL: str = "http://localhost:11434"
|
||||||
|
RERANKER_MODEL: str = "BAAI/bge-reranker-v2-gemma"
|
||||||
|
RERANKER_DEVICE: Optional[str] = None
|
||||||
|
RERANKER_USE_FP16: bool = True
|
||||||
|
RERANKER_QUERY_MAX_LENGTH: int = 256
|
||||||
|
RERANKER_PASSAGE_MAX_LENGTH: int = 512
|
||||||
|
|
||||||
# Processing settings
|
# Processing settings
|
||||||
CHUNK_SIZE: int = 1000
|
CHUNK_SIZE: int = 1000
|
||||||
@ -60,6 +66,7 @@ class Settings(BaseSettings):
|
|||||||
DEFAULT_K: int = 4
|
DEFAULT_K: int = 4
|
||||||
FRAME_SAMPLE_RATE: int = 120
|
FRAME_SAMPLE_RATE: int = 120
|
||||||
USE_UNSTRUCTURED_API: bool = False
|
USE_UNSTRUCTURED_API: bool = False
|
||||||
|
USE_RERANKING: bool = True
|
||||||
|
|
||||||
# Auth settings
|
# Auth settings
|
||||||
JWT_ALGORITHM: str = "HS256"
|
JWT_ALGORITHM: str = "HS256"
|
||||||
@ -87,6 +94,7 @@ def get_settings() -> Settings:
|
|||||||
"EMBEDDING_PROVIDER": config["service"]["components"]["embedding"],
|
"EMBEDDING_PROVIDER": config["service"]["components"]["embedding"],
|
||||||
"COMPLETION_PROVIDER": config["service"]["components"]["completion"],
|
"COMPLETION_PROVIDER": config["service"]["components"]["completion"],
|
||||||
"PARSER_PROVIDER": config["service"]["components"]["parser"],
|
"PARSER_PROVIDER": config["service"]["components"]["parser"],
|
||||||
|
"RERANKER_PROVIDER": config["service"]["components"]["reranker"],
|
||||||
# Storage settings
|
# Storage settings
|
||||||
"STORAGE_PATH": config["storage"]["local"]["path"],
|
"STORAGE_PATH": config["storage"]["local"]["path"],
|
||||||
"AWS_REGION": config["storage"]["aws"]["region"],
|
"AWS_REGION": config["storage"]["aws"]["region"],
|
||||||
@ -104,10 +112,16 @@ def get_settings() -> Settings:
|
|||||||
"COMPLETION_MAX_TOKENS": config["models"]["completion"]["default_max_tokens"],
|
"COMPLETION_MAX_TOKENS": config["models"]["completion"]["default_max_tokens"],
|
||||||
"COMPLETION_TEMPERATURE": config["models"]["completion"]["default_temperature"],
|
"COMPLETION_TEMPERATURE": config["models"]["completion"]["default_temperature"],
|
||||||
"OLLAMA_BASE_URL": config["models"]["ollama"]["base_url"],
|
"OLLAMA_BASE_URL": config["models"]["ollama"]["base_url"],
|
||||||
|
"RERANKER_MODEL": config["models"]["reranker"]["model_name"],
|
||||||
|
"RERANKER_DEVICE": config["models"]["reranker"].get("device"),
|
||||||
|
"RERANKER_USE_FP16": config["models"]["reranker"].get("use_fp16", True),
|
||||||
|
"RERANKER_QUERY_MAX_LENGTH": config["models"]["reranker"].get("query_max_length", 256),
|
||||||
|
"RERANKER_PASSAGE_MAX_LENGTH": config["models"]["reranker"].get("passage_max_length", 512),
|
||||||
# Processing settings
|
# Processing settings
|
||||||
"CHUNK_SIZE": config["processing"]["text"]["chunk_size"],
|
"CHUNK_SIZE": config["processing"]["text"]["chunk_size"],
|
||||||
"CHUNK_OVERLAP": config["processing"]["text"]["chunk_overlap"],
|
"CHUNK_OVERLAP": config["processing"]["text"]["chunk_overlap"],
|
||||||
"DEFAULT_K": config["processing"]["text"]["default_k"],
|
"DEFAULT_K": config["processing"]["text"]["default_k"],
|
||||||
|
"USE_RERANKING": config["processing"]["text"]["use_reranking"],
|
||||||
"FRAME_SAMPLE_RATE": config["processing"]["video"]["frame_sample_rate"],
|
"FRAME_SAMPLE_RATE": config["processing"]["video"]["frame_sample_rate"],
|
||||||
"USE_UNSTRUCTURED_API": config["processing"]["unstructured"]["use_api"],
|
"USE_UNSTRUCTURED_API": config["processing"]["unstructured"]["use_api"],
|
||||||
# Auth settings
|
# Auth settings
|
||||||
|
@ -52,6 +52,7 @@ class MongoDatabase(BaseDatabase):
|
|||||||
# Ensure system metadata
|
# Ensure system metadata
|
||||||
doc_dict["system_metadata"]["created_at"] = datetime.now(UTC)
|
doc_dict["system_metadata"]["created_at"] = datetime.now(UTC)
|
||||||
doc_dict["system_metadata"]["updated_at"] = datetime.now(UTC)
|
doc_dict["system_metadata"]["updated_at"] = datetime.now(UTC)
|
||||||
|
doc_dict["metadata"]["external_id"] = doc_dict["external_id"]
|
||||||
|
|
||||||
result = await self.collection.insert_one(doc_dict)
|
result = await self.collection.insert_one(doc_dict)
|
||||||
return bool(result.inserted_id)
|
return bool(result.inserted_id)
|
||||||
|
@ -16,6 +16,7 @@ class RetrieveRequest(BaseModel):
|
|||||||
filters: Optional[Dict[str, Any]] = None
|
filters: Optional[Dict[str, Any]] = None
|
||||||
k: int = Field(default=4, gt=0)
|
k: int = Field(default=4, gt=0)
|
||||||
min_score: float = Field(default=0.0)
|
min_score: float = Field(default=0.0)
|
||||||
|
use_reranking: Optional[bool] = None # If None, use default from config
|
||||||
|
|
||||||
|
|
||||||
class CompletionQueryRequest(RetrieveRequest):
|
class CompletionQueryRequest(RetrieveRequest):
|
||||||
|
1
core/reranker/__init__.py
Normal file
1
core/reranker/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Reranker package for reranking search results."""
|
26
core/reranker/base_reranker.py
Normal file
26
core/reranker/base_reranker.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from core.models.chunk import DocumentChunk
|
||||||
|
|
||||||
|
|
||||||
|
class BaseReranker(ABC):
    """Abstract interface for components that rerank search results."""

    @abstractmethod
    async def rerank(self, query: str, chunks: List[DocumentChunk]) -> List[DocumentChunk]:
        """Rerank ``chunks`` based on their relevance to ``query``."""

    @abstractmethod
    async def compute_score(
        self, query: str, text: Union[str, List[str]]
    ) -> Union[float, List[float]]:
        """Compute relevance scores between ``query`` and ``text``.

        ``text`` may be a single string or a list of strings; the return type
        is correspondingly a single float or a list of floats.
        """
|
59
core/reranker/bge_reranker.py
Normal file
59
core/reranker/bge_reranker.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
from typing import List, Union, Optional
|
||||||
|
from FlagEmbedding import FlagAutoReranker
|
||||||
|
|
||||||
|
from core.models.chunk import DocumentChunk
|
||||||
|
from core.reranker.base_reranker import BaseReranker
|
||||||
|
|
||||||
|
|
||||||
|
class BGEReranker(BaseReranker):
    """BGE reranker implementation using FlagEmbedding.

    Wraps a ``FlagAutoReranker`` and exposes it through the ``BaseReranker``
    interface.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-v2-gemma",
        query_max_length: int = 256,
        passage_max_length: int = 512,
        use_fp16: bool = True,
        device: Optional[str] = None,
    ):
        """Initialize BGE reranker.

        Args:
            model_name: Model name or path handed to ``from_finetuned``.
            query_max_length: Forwarded to the model as ``query_max_length``.
            passage_max_length: Forwarded to the model as ``passage_max_length``.
            use_fp16: Forwarded to the model as ``use_fp16``.
            device: Optional single device string. It is wrapped in a
                one-element list for the ``devices`` argument; a falsy value
                passes ``devices=None`` and leaves the choice to the library.
        """
        devices = [device] if device else None
        self.reranker = FlagAutoReranker.from_finetuned(
            model_name_or_path=model_name,
            query_max_length=query_max_length,
            passage_max_length=passage_max_length,
            use_fp16=use_fp16,
            devices=devices,
        )

    async def rerank(
        self,
        query: str,
        chunks: List[DocumentChunk],
    ) -> List[DocumentChunk]:
        """Rerank chunks based on their relevance to the query.

        Each chunk's ``score`` attribute is overwritten in place with the
        reranker score; the returned list is a new list sorted by that score,
        highest first. An empty input yields an empty list.
        """
        if not chunks:
            return []

        # Score every chunk's content against the query in one batch.
        passages = [chunk.content for chunk in chunks]
        scores = await self.compute_score(query, passages)

        for chunk, score in zip(chunks, scores):
            chunk.score = float(score)

        return sorted(chunks, key=lambda c: c.score, reverse=True)

    async def compute_score(
        self,
        query: str,
        text: Union[str, List[str]],
    ) -> Union[float, List[float]]:
        """Compute relevance scores between query and text.

        Args:
            query: The query string.
            text: One passage, or a list of passages.

        Returns:
            A single float when ``text`` is a string, otherwise a list with
            one float per passage (empty when ``text`` is empty).
        """
        # NOTE(review): the underlying model call is not awaited, so it runs
        # inline on the calling coroutine's event loop.
        if isinstance(text, str):
            return self._score_passages(query, [text])[0]
        return self._score_passages(query, list(text))

    def _score_passages(self, query: str, passages: List[str]) -> List[float]:
        """Score ``passages`` against ``query`` and always return a list of floats."""
        if not passages:
            # Nothing to score; avoid handing the model an empty batch.
            return []

        raw = self.reranker.compute_score(
            [[query, passage] for passage in passages], normalize=True
        )
        # Tolerate a bare scalar being returned for a single pair.
        if not hasattr(raw, "__iter__"):
            raw = [raw]
        return [float(score) for score in raw]
|
@ -19,6 +19,8 @@ from core.parser.base_parser import BaseParser
|
|||||||
from core.completion.base_completion import BaseCompletionModel
|
from core.completion.base_completion import BaseCompletionModel
|
||||||
from core.completion.base_completion import CompletionRequest, CompletionResponse
|
from core.completion.base_completion import CompletionRequest, CompletionResponse
|
||||||
import logging
|
import logging
|
||||||
|
from core.reranker.base_reranker import BaseReranker
|
||||||
|
from core.config import get_settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -32,6 +34,7 @@ class DocumentService:
|
|||||||
parser: BaseParser,
|
parser: BaseParser,
|
||||||
embedding_model: BaseEmbeddingModel,
|
embedding_model: BaseEmbeddingModel,
|
||||||
completion_model: BaseCompletionModel,
|
completion_model: BaseCompletionModel,
|
||||||
|
reranker: BaseReranker,
|
||||||
):
|
):
|
||||||
self.db = database
|
self.db = database
|
||||||
self.vector_store = vector_store
|
self.vector_store = vector_store
|
||||||
@ -39,16 +42,21 @@ class DocumentService:
|
|||||||
self.parser = parser
|
self.parser = parser
|
||||||
self.embedding_model = embedding_model
|
self.embedding_model = embedding_model
|
||||||
self.completion_model = completion_model
|
self.completion_model = completion_model
|
||||||
|
self.reranker = reranker
|
||||||
|
|
||||||
async def retrieve_chunks(
|
async def retrieve_chunks(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
auth: AuthContext,
|
auth: AuthContext,
|
||||||
filters: Optional[Dict[str, Any]] = None,
|
filters: Optional[Dict[str, Any]] = None,
|
||||||
k: int = 4,
|
k: int = 5,
|
||||||
min_score: float = 0.0,
|
min_score: float = 0.0,
|
||||||
|
use_reranking: Optional[bool] = None,
|
||||||
) -> List[ChunkResult]:
|
) -> List[ChunkResult]:
|
||||||
"""Retrieve relevant chunks."""
|
"""Retrieve relevant chunks."""
|
||||||
|
settings = get_settings()
|
||||||
|
should_rerank = use_reranking if use_reranking is not None else settings.USE_RERANKING
|
||||||
|
|
||||||
# Get embedding for query
|
# Get embedding for query
|
||||||
query_embedding = await self.embedding_model.embed_for_query(query)
|
query_embedding = await self.embedding_model.embed_for_query(query)
|
||||||
logger.info("Generated query embedding")
|
logger.info("Generated query embedding")
|
||||||
@ -61,9 +69,18 @@ class DocumentService:
|
|||||||
logger.info(f"Found {len(doc_ids)} authorized documents")
|
logger.info(f"Found {len(doc_ids)} authorized documents")
|
||||||
|
|
||||||
# Search chunks with vector similarity
|
# Search chunks with vector similarity
|
||||||
chunks = await self.vector_store.query_similar(query_embedding, k=k, doc_ids=doc_ids)
|
chunks = await self.vector_store.query_similar(
|
||||||
|
query_embedding, k=10 * k if should_rerank else k, doc_ids=doc_ids
|
||||||
|
)
|
||||||
logger.info(f"Found {len(chunks)} similar chunks")
|
logger.info(f"Found {len(chunks)} similar chunks")
|
||||||
|
|
||||||
|
# Rerank chunks using the reranker if enabled
|
||||||
|
if chunks and should_rerank:
|
||||||
|
chunks = await self.reranker.rerank(query, chunks)
|
||||||
|
chunks.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
chunks = chunks[:k]
|
||||||
|
logger.info(f"Reranked {k*10} chunks and selected the top {k}")
|
||||||
|
|
||||||
# Create and return chunk results
|
# Create and return chunk results
|
||||||
results = await self._create_chunk_results(auth, chunks)
|
results = await self._create_chunk_results(auth, chunks)
|
||||||
logger.info(f"Returning {len(results)} chunk results")
|
logger.info(f"Returning {len(results)} chunk results")
|
||||||
@ -74,12 +91,13 @@ class DocumentService:
|
|||||||
query: str,
|
query: str,
|
||||||
auth: AuthContext,
|
auth: AuthContext,
|
||||||
filters: Optional[Dict[str, Any]] = None,
|
filters: Optional[Dict[str, Any]] = None,
|
||||||
k: int = 4,
|
k: int = 5,
|
||||||
min_score: float = 0.0,
|
min_score: float = 0.0,
|
||||||
|
use_reranking: Optional[bool] = None,
|
||||||
) -> List[DocumentResult]:
|
) -> List[DocumentResult]:
|
||||||
"""Retrieve relevant documents."""
|
"""Retrieve relevant documents."""
|
||||||
# Get chunks first
|
# Get chunks first
|
||||||
chunks = await self.retrieve_chunks(query, auth, filters, k, min_score)
|
chunks = await self.retrieve_chunks(query, auth, filters, k, min_score, use_reranking)
|
||||||
# Convert to document results
|
# Convert to document results
|
||||||
results = await self._create_document_results(auth, chunks)
|
results = await self._create_document_results(auth, chunks)
|
||||||
documents = list(results.values())
|
documents = list(results.values())
|
||||||
@ -95,10 +113,11 @@ class DocumentService:
|
|||||||
min_score: float = 0.0,
|
min_score: float = 0.0,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
|
use_reranking: Optional[bool] = None,
|
||||||
) -> CompletionResponse:
|
) -> CompletionResponse:
|
||||||
"""Generate completion using relevant chunks as context."""
|
"""Generate completion using relevant chunks as context."""
|
||||||
# Get relevant chunks
|
# Get relevant chunks
|
||||||
chunks = await self.retrieve_chunks(query, auth, filters, k, min_score)
|
chunks = await self.retrieve_chunks(query, auth, filters, k, min_score, use_reranking)
|
||||||
documents = await self._create_document_results(auth, chunks)
|
documents = await self._create_document_results(auth, chunks)
|
||||||
|
|
||||||
chunk_contents = [chunk.augmented_content(documents[chunk.document_id]) for chunk in chunks]
|
chunk_contents = [chunk.augmented_content(documents[chunk.document_id]) for chunk in chunks]
|
||||||
|
@ -194,6 +194,10 @@ async def test_ingest_invalid_metadata(client: AsyncClient):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
get_settings().EMBEDDING_PROVIDER == "ollama",
|
||||||
|
reason="local embedding models do not have size limits",
|
||||||
|
)
|
||||||
async def test_ingest_oversized_content(client: AsyncClient):
|
async def test_ingest_oversized_content(client: AsyncClient):
|
||||||
"""Test ingestion with oversized content"""
|
"""Test ingestion with oversized content"""
|
||||||
headers = create_auth_header()
|
headers = create_auth_header()
|
||||||
@ -285,26 +289,28 @@ async def test_invalid_document_id(client: AsyncClient):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_retrieve_chunks(client: AsyncClient):
|
async def test_retrieve_chunks(client: AsyncClient):
|
||||||
"""Test retrieving document chunks"""
|
"""Test retrieving document chunks"""
|
||||||
|
upload_string = "The quick brown fox jumps over the lazy dog"
|
||||||
# First ingest a document to search
|
# First ingest a document to search
|
||||||
doc_id = await test_ingest_text_document(
|
doc_id = await test_ingest_text_document(client, content=upload_string)
|
||||||
client, content="The quick brown fox jumps over the lazy dog"
|
|
||||||
)
|
|
||||||
|
|
||||||
headers = create_auth_header()
|
headers = create_auth_header()
|
||||||
# Sleep to allow time for document to be indexed
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/retrieve/chunks",
|
"/retrieve/chunks",
|
||||||
json={"query": "jumping fox", "k": 1, "min_score": 0.0},
|
json={
|
||||||
|
"query": "jumping fox",
|
||||||
|
"k": 1,
|
||||||
|
"min_score": 0.0,
|
||||||
|
"filters": {"external_id": doc_id}, # Add filter for specific document
|
||||||
|
},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
results = list(response.json())
|
results = list(response.json())
|
||||||
assert len(results) == 1
|
assert len(results) > 0
|
||||||
assert results[0]["score"] > 0.5
|
assert results[0]["score"] > 0.5
|
||||||
assert results[0]["document_id"] == doc_id
|
assert results[0]["content"] == upload_string
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -324,7 +330,10 @@ async def test_retrieve_docs(client: AsyncClient):
|
|||||||
headers = create_auth_header()
|
headers = create_auth_header()
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/retrieve/docs",
|
"/retrieve/docs",
|
||||||
json={"query": "Headaches, dehydration", "filters": {"test": True}},
|
json={
|
||||||
|
"query": "Headaches, dehydration",
|
||||||
|
"filters": {"test": True, "external_id": doc_id}, # Add filter for specific document
|
||||||
|
},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -398,3 +407,287 @@ async def test_invalid_completion_params(client: AsyncClient):
|
|||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_retrieve_chunks_default_reranking(client: AsyncClient):
    """Test retrieving chunks when the request does not set ``use_reranking``."""
    # Seed two documents so the query has something to match.
    for text in (
        "The quick brown fox jumps over the lazy dog. This is a test document.",
        "The lazy dog sleeps while the quick brown fox runs. Another test document.",
    ):
        await test_ingest_text_document(client, text)

    headers = create_auth_header()
    # use_reranking is deliberately left out of the payload.
    payload = {
        "query": "What does the fox do?",
        "k": 2,
        "min_score": 0.0,
    }
    response = await client.post("/retrieve/chunks", json=payload, headers=headers)

    assert response.status_code == 200
    chunks = response.json()
    assert len(chunks) > 0
    # Scores must be non-increasing from the first result to the last.
    scores = [chunk["score"] for chunk in chunks]
    assert all(earlier >= later for earlier, later in zip(scores, scores[1:]))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_retrieve_chunks_explicit_reranking(client: AsyncClient):
    """Test retrieving chunks with ``use_reranking`` explicitly set to True."""
    # Seed two documents so the query has something to match.
    for text in (
        "The quick brown fox jumps over the lazy dog. This is a test document.",
        "The lazy dog sleeps while the quick brown fox runs. Another test document.",
    ):
        await test_ingest_text_document(client, text)

    headers = create_auth_header()
    payload = {
        "query": "What does the fox do?",
        "k": 2,
        "min_score": 0.0,
        "use_reranking": True,
    }
    response = await client.post("/retrieve/chunks", json=payload, headers=headers)

    assert response.status_code == 200
    chunks = response.json()
    assert len(chunks) > 0
    # Scores must be non-increasing from the first result to the last.
    scores = [chunk["score"] for chunk in chunks]
    assert all(earlier >= later for earlier, later in zip(scores, scores[1:]))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_retrieve_chunks_disabled_reranking(client: AsyncClient):
    """Test retrieving chunks with ``use_reranking`` explicitly set to False."""
    # Seed two documents so the query has something to match.
    for text in (
        "The quick brown fox jumps over the lazy dog. This is a test document.",
        "The lazy dog sleeps while the quick brown fox runs. Another test document.",
    ):
        await test_ingest_text_document(client, text)

    headers = create_auth_header()
    payload = {
        "query": "What does the fox do?",
        "k": 2,
        "min_score": 0.0,
        "use_reranking": False,
    }
    response = await client.post("/retrieve/chunks", json=payload, headers=headers)

    assert response.status_code == 200
    chunks = response.json()
    assert len(chunks) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_reranking_affects_results(client: AsyncClient):
    """Test that turning reranking on changes the scores returned for a query."""
    # Documents with clearly different semantic relevance to the query below.
    corpus = (
        "The capital of France is Paris. The city is known for the Eiffel Tower.",
        "Paris is a city in France. It has many famous landmarks and museums.",
        "Paris Hilton is a celebrity and businesswoman. She has nothing to do with France.",
    )
    for text in corpus:
        await test_ingest_text_document(client, text)

    headers = create_auth_header()

    async def search(use_reranking: bool):
        # Same query each time; only the reranking flag differs.
        return await client.post(
            "/retrieve/chunks",
            json={
                "query": "Tell me about the capital city of France",
                "k": 3,
                "min_score": 0.0,
                "use_reranking": use_reranking,
            },
            headers=headers,
        )

    response_no_rerank = await search(False)
    response_with_rerank = await search(True)

    assert response_no_rerank.status_code == 200
    assert response_with_rerank.status_code == 200

    chunks_no_rerank = response_no_rerank.json()
    chunks_with_rerank = response_with_rerank.json()

    # Both modes must return at least one result.
    assert len(chunks_no_rerank) > 0
    assert len(chunks_with_rerank) > 0

    # The score lists are expected to differ between the two modes. This can
    # be flaky because it depends on the exact scores the models produce.
    scores_no_rerank = [item["score"] for item in chunks_no_rerank]
    scores_with_rerank = [item["score"] for item in chunks_with_rerank]
    assert scores_no_rerank != scores_with_rerank, "Reranking should affect the scores"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_retrieve_docs_with_reranking(client: AsyncClient):
    """Hit /retrieve/docs with reranking at its default, forced on, and forced off.

    Each response must be a 200 with at least one document, and the scores
    returned with reranking on must differ from those with reranking off.
    """
    # Seed three documents whose relevance to the query differs clearly
    corpus = (
        "The capital of France is Paris. The city is known for the Eiffel Tower.",
        "Paris is a city in France. It has many famous landmarks and museums.",
        "Paris Hilton is a celebrity and businesswoman. She has nothing to do with France.",
    )
    for text in corpus:
        await test_ingest_text_document(client, text)

    headers = create_auth_header()
    base_payload = {
        "query": "Tell me about the capital city of France",
        "k": 3,
        "min_score": 0.0,
    }

    async def fetch_docs(extra):
        # One request/validation round trip; `extra` carries the reranking override
        response = await client.post(
            "/retrieve/docs",
            json={**base_payload, **extra},
            headers=headers,
        )
        assert response.status_code == 200
        docs = response.json()
        assert len(docs) > 0
        return docs

    # Default behaviour (no explicit flag in the payload)
    await fetch_docs({})
    # Explicitly enabled, then explicitly disabled
    docs_rerank = await fetch_docs({"use_reranking": True})
    docs_no_rerank = await fetch_docs({"use_reranking": False})

    # Reranking is expected to change the scores attached to the documents
    reranked_scores = [entry["score"] for entry in docs_rerank]
    plain_scores = [entry["score"] for entry in docs_no_rerank]
    assert reranked_scores != plain_scores, "Reranking should affect document scores"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_with_reranking(client: AsyncClient):
    """Hit /query with reranking at its default, forced on, and forced off.

    Every response must be a 200 carrying a "completion" field, and each
    completion is expected to mention Paris.
    """
    # Seed three documents whose relevance to the query differs clearly
    corpus = (
        "The capital of France is Paris. The city is known for the Eiffel Tower.",
        "Paris is a city in France. It has many famous landmarks and museums.",
        "Paris Hilton is a celebrity and businesswoman. She has nothing to do with France.",
    )
    for text in corpus:
        await test_ingest_text_document(client, text)

    headers = create_auth_header()
    base_payload = {
        "query": "What is the capital of France?",
        "k": 3,
        "min_score": 0.0,
        "max_tokens": 50,
    }

    async def ask(extra):
        # One request/validation round trip; `extra` carries the reranking override
        response = await client.post(
            "/query",
            json={**base_payload, **extra},
            headers=headers,
        )
        assert response.status_code == 200
        body = response.json()
        assert "completion" in body
        return body

    # Order matters only for readability: default, enabled, disabled
    completions = [
        await ask({}),
        await ask({"use_reranking": True}),
        await ask({"use_reranking": False}),
    ]

    # Wording may vary with chunk ordering, but every answer should name Paris
    for body in completions:
        assert "Paris" in body["completion"]
|
||||||
|
166
core/tests/unit/test_reranker.py
Normal file
166
core/tests/unit/test_reranker.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
import pytest
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from core.models.chunk import DocumentChunk
|
||||||
|
from core.reranker.bge_reranker import BGEReranker
|
||||||
|
from core.config import get_settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_chunks() -> List[DocumentChunk]:
    """Return three small chunks on unrelated topics, each with a preset score."""
    # (document_id, content, embedding fill value, initial score)
    specs = [
        ("1", "The quick brown fox jumps over the lazy dog", 0.1, 0.5),
        ("2", "Python is a popular programming language", 0.2, 0.7),
        ("3", "Machine learning models help analyze data", 0.3, 0.3),
    ]
    return [
        DocumentChunk(
            document_id=doc_id,
            content=text,
            embedding=[fill] * 10,
            chunk_number=1,
            score=initial_score,
        )
        for doc_id, text, fill, initial_score in specs
    ]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def reranker():
    """Create one BGE reranker and share it across every test in this module.

    The fixture is module-scoped so the instance is actually reused; with
    pytest's default function scope a new ``BGEReranker`` would be built
    for each test (presumably reloading the model each time — the
    constructor is not visible here).

    Returns:
        A ``BGEReranker`` configured from the application settings
        (model name, device, fp16 flag, and query/passage max lengths).
    """
    settings = get_settings()
    return BGEReranker(
        model_name=settings.RERANKER_MODEL,
        device=settings.RERANKER_DEVICE,
        use_fp16=settings.RERANKER_USE_FP16,
        query_max_length=settings.RERANKER_QUERY_MAX_LENGTH,
        passage_max_length=settings.RERANKER_PASSAGE_MAX_LENGTH,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_reranker_relevance(reranker, sample_chunks):
    """Check that a Python-related query puts the Python chunk first."""
    print("\n=== Testing Reranker Relevance ===")
    query = "What is Python programming language?"

    reranked_chunks = await reranker.rerank(query, sample_chunks)
    print(f"\nQuery: {query}")
    for rank, chunk in enumerate(reranked_chunks, start=1):
        print(f"{rank}. Score: {chunk.score:.3f} - {chunk.content}")

    # The top hit must be the Python chunk and must outscore the last one
    best, worst = reranked_chunks[0], reranked_chunks[-1]
    assert "Python" in best.content
    assert best.score > worst.score

    # The unrelated fox/dog sentence must not take the first position
    fox_chunk_idx = next(
        position
        for position, candidate in enumerate(reranked_chunks)
        if "fox" in candidate.content.lower()
    )
    assert fox_chunk_idx > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_reranker_score_distribution(reranker, sample_chunks):
    """Check score range, variety, descending order, and the top hit's topic."""
    print("\n=== Testing Score Distribution ===")
    query = "Tell me about machine learning and data science"

    reranked_chunks = await reranker.rerank(query, sample_chunks)
    print(f"\nQuery: {query}")
    for rank, chunk in enumerate(reranked_chunks, start=1):
        print(f"{rank}. Score: {chunk.score:.3f} - {chunk.content}")

    scores = [chunk.score for chunk in reranked_chunks]

    # Every score lies in [0, 1] and the scores are not all identical
    assert all(0 <= value <= 1 for value in scores)
    assert len(set(scores)) > 1

    # Results come back sorted from highest to lowest score
    assert scores == sorted(scores, reverse=True)

    # The best match should talk about ML / data science
    top_text = reranked_chunks[0].content.lower()
    expected_terms = ["machine learning", "data science"]
    assert any(term in top_text for term in expected_terms)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_reranker_batch_scoring(reranker):
    """Score several passages against several queries and validate each batch."""
    print("\n=== Testing Batch Scoring ===")
    texts = [
        "Python is a programming language",
        "Machine learning is a field of AI",
        "The quick brown fox jumps",
        "Data science uses statistical methods",
    ]
    queries = [
        "What is Python?",
        "Explain artificial intelligence",
        "Tell me about data analysis",
    ]

    for query in queries:
        scores = await reranker.compute_score(query, texts)
        print(f"\nQuery: {query}")
        for passage, value in zip(texts, scores):
            print(f"Score: {value:.3f} - {passage}")

        # One float score in [0, 1] per passage
        assert len(scores) == len(texts)
        assert all(isinstance(value, float) for value in scores)
        assert all(0 <= value <= 1 for value in scores)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_reranker_empty_and_edge_cases(reranker, sample_chunks):
    """Cover an empty candidate list, a single candidate, and an empty query."""
    print("\n=== Testing Edge Cases ===")

    # No candidates: the reranker should hand back an empty list
    empty_result = await reranker.rerank("test query", [])
    assert empty_result == []
    print("Empty chunks test passed")

    # Exactly one candidate: it comes back alone with a float score
    lone_chunk = DocumentChunk(
        document_id="1",
        content="Test content",
        embedding=[0.1] * 768,
        chunk_number=1,
        score=0.5,
    )
    single_result = await reranker.rerank("test query", [lone_chunk])
    assert len(single_result) == 1
    only = single_result[0]
    assert isinstance(only.score, float)
    print(f"Single chunk test passed - Score: {only.score:.3f}")

    # Empty query string: every input chunk is still returned
    blank_query_result = await reranker.rerank("", sample_chunks)
    assert len(blank_query_result) == len(sample_chunks)
    print("Empty query test passed")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_reranker_consistency(reranker, sample_chunks):
    """Check that reranking the same input twice yields the same scores and order.

    The first run's scores and document ids are captured *before* the second
    run. If ``rerank`` returns the input chunk objects with their scores
    updated in place (its implementation is not visible here), reading both
    result lists only after the second call would compare the objects with
    themselves and the test could never fail.
    """
    print("\n=== Testing Consistency ===")
    query = "What is Python programming?"

    # First run: snapshot plain values immediately so later mutation cannot alias them
    results1 = await reranker.rerank(query, sample_chunks)
    scores1 = [c.score for c in results1]
    order1 = [c.document_id for c in results1]

    # Second run on the same query and chunks
    results2 = await reranker.rerank(query, sample_chunks)
    scores2 = [c.score for c in results2]
    order2 = [c.document_id for c in results2]

    # Scores should be the same across runs
    print("\nScores from first run:", [f"{s:.3f}" for s in scores1])
    print("Scores from second run:", [f"{s:.3f}" for s in scores2])
    assert scores1 == scores2

    # Order should be preserved
    assert order1 == order2
    print("Order consistency test passed")
|
@ -1,10 +1,13 @@
|
|||||||
|
accelerate==1.2.1
|
||||||
aiofiles==24.1.0
|
aiofiles==24.1.0
|
||||||
aiohttp==3.9.5
|
aiohttp==3.9.5
|
||||||
aiosignal==1.3.1
|
aiosignal==1.3.1
|
||||||
annotated-types==0.6.0
|
annotated-types==0.6.0
|
||||||
|
anthropic==0.42.0
|
||||||
antlr4-python3-runtime==4.9.3
|
antlr4-python3-runtime==4.9.3
|
||||||
anyio==4.3.0
|
anyio==4.3.0
|
||||||
appnope==0.1.4
|
appnope==0.1.4
|
||||||
|
asgiref==3.8.1
|
||||||
assemblyai==0.36.0
|
assemblyai==0.36.0
|
||||||
asttokens==2.4.1
|
asttokens==2.4.1
|
||||||
attrs==23.2.0
|
attrs==23.2.0
|
||||||
@ -18,6 +21,7 @@ botocore==1.34.103
|
|||||||
botocore-stubs==1.34.150
|
botocore-stubs==1.34.150
|
||||||
build==1.2.2.post1
|
build==1.2.2.post1
|
||||||
cachetools==5.3.3
|
cachetools==5.3.3
|
||||||
|
cbor==1.0.0
|
||||||
certifi==2024.2.2
|
certifi==2024.2.2
|
||||||
cffi==1.17.0
|
cffi==1.17.0
|
||||||
cfgv==3.4.0
|
cfgv==3.4.0
|
||||||
@ -30,7 +34,7 @@ contourpy==1.2.1
|
|||||||
cryptography==43.0.0
|
cryptography==43.0.0
|
||||||
cycler==0.12.1
|
cycler==0.12.1
|
||||||
dataclasses-json==0.6.7
|
dataclasses-json==0.6.7
|
||||||
datasets==2.21.0
|
datasets==2.19.0
|
||||||
debugpy==1.8.5
|
debugpy==1.8.5
|
||||||
decorator==5.1.1
|
decorator==5.1.1
|
||||||
deepdiff==7.0.1
|
deepdiff==7.0.1
|
||||||
@ -51,10 +55,11 @@ fastapi-cli==0.0.2
|
|||||||
ffmpeg-python==0.2.0
|
ffmpeg-python==0.2.0
|
||||||
filelock==3.15.4
|
filelock==3.15.4
|
||||||
filetype==1.2.0
|
filetype==1.2.0
|
||||||
|
FlagEmbedding==1.3.3
|
||||||
flatbuffers==24.3.25
|
flatbuffers==24.3.25
|
||||||
fonttools==4.53.1
|
fonttools==4.53.1
|
||||||
frozenlist==1.4.1
|
frozenlist==1.4.1
|
||||||
fsspec==2024.6.1
|
fsspec==2024.3.1
|
||||||
future==1.0.0
|
future==1.0.0
|
||||||
google-api-core==2.19.1
|
google-api-core==2.19.1
|
||||||
google-auth==2.29.0
|
google-auth==2.29.0
|
||||||
@ -67,14 +72,18 @@ h11==0.14.0
|
|||||||
httpcore==1.0.5
|
httpcore==1.0.5
|
||||||
httptools==0.6.1
|
httptools==0.6.1
|
||||||
httpx==0.27.0
|
httpx==0.27.0
|
||||||
huggingface-hub==0.24.5
|
huggingface-hub==0.27.0
|
||||||
humanfriendly==10.0
|
humanfriendly==10.0
|
||||||
identify==2.6.3
|
identify==2.6.3
|
||||||
idna==3.7
|
idna==3.7
|
||||||
|
ijson==3.3.0
|
||||||
|
importlib_metadata==8.5.0
|
||||||
iniconfig==2.0.0
|
iniconfig==2.0.0
|
||||||
|
inscriptis==2.5.0
|
||||||
iopath==0.1.10
|
iopath==0.1.10
|
||||||
ipykernel==6.29.5
|
ipykernel==6.29.5
|
||||||
ipython==8.26.0
|
ipython==8.26.0
|
||||||
|
ir_datasets==0.5.9
|
||||||
jaraco.classes==3.4.0
|
jaraco.classes==3.4.0
|
||||||
jaraco.context==6.0.1
|
jaraco.context==6.0.1
|
||||||
jaraco.functools==4.1.0
|
jaraco.functools==4.1.0
|
||||||
@ -99,9 +108,11 @@ langchain-text-splitters==0.2.2
|
|||||||
langchain-unstructured==0.1.1
|
langchain-unstructured==0.1.1
|
||||||
langdetect==1.0.9
|
langdetect==1.0.9
|
||||||
langsmith==0.1.98
|
langsmith==0.1.98
|
||||||
|
lap==0.5.12
|
||||||
layoutparser==0.3.4
|
layoutparser==0.3.4
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
lxml==5.2.2
|
lxml==5.2.2
|
||||||
|
lz4==4.3.3
|
||||||
Markdown==3.6
|
Markdown==3.6
|
||||||
markdown-it-py==3.0.0
|
markdown-it-py==3.0.0
|
||||||
MarkupSafe==2.1.5
|
MarkupSafe==2.1.5
|
||||||
@ -131,6 +142,18 @@ onnxruntime==1.18.1
|
|||||||
openai==1.43.0
|
openai==1.43.0
|
||||||
opencv-python==4.10.0.84
|
opencv-python==4.10.0.84
|
||||||
openpyxl==3.1.5
|
openpyxl==3.1.5
|
||||||
|
opentelemetry-api==1.29.0
|
||||||
|
opentelemetry-exporter-otlp==1.29.0
|
||||||
|
opentelemetry-exporter-otlp-proto-common==1.29.0
|
||||||
|
opentelemetry-exporter-otlp-proto-grpc==1.29.0
|
||||||
|
opentelemetry-exporter-otlp-proto-http==1.29.0
|
||||||
|
opentelemetry-instrumentation==0.50b0
|
||||||
|
opentelemetry-instrumentation-asgi==0.50b0
|
||||||
|
opentelemetry-instrumentation-fastapi==0.50b0
|
||||||
|
opentelemetry-proto==1.29.0
|
||||||
|
opentelemetry-sdk==1.29.0
|
||||||
|
opentelemetry-semantic-conventions==0.50b0
|
||||||
|
opentelemetry-util-http==0.50b0
|
||||||
ordered-set==4.1.0
|
ordered-set==4.1.0
|
||||||
orjson==3.10.3
|
orjson==3.10.3
|
||||||
packaging==24.0
|
packaging==24.0
|
||||||
@ -141,6 +164,7 @@ pathspec==0.12.1
|
|||||||
pdf2image==1.17.0
|
pdf2image==1.17.0
|
||||||
pdfminer.six==20231228
|
pdfminer.six==20231228
|
||||||
pdfplumber==0.11.3
|
pdfplumber==0.11.3
|
||||||
|
peft==0.14.0
|
||||||
pexpect==4.9.0
|
pexpect==4.9.0
|
||||||
pikepdf==9.1.1
|
pikepdf==9.1.1
|
||||||
pillow==10.4.0
|
pillow==10.4.0
|
||||||
@ -156,7 +180,9 @@ protobuf==5.27.3
|
|||||||
psutil==6.0.0
|
psutil==6.0.0
|
||||||
ptyprocess==0.7.0
|
ptyprocess==0.7.0
|
||||||
pure_eval==0.2.3
|
pure_eval==0.2.3
|
||||||
|
py-cpuinfo==9.0.0
|
||||||
pyarrow==17.0.0
|
pyarrow==17.0.0
|
||||||
|
pyarrow-hotfix==0.6
|
||||||
pyasn1==0.6.0
|
pyasn1==0.6.0
|
||||||
pyasn1_modules==0.4.0
|
pyasn1_modules==0.4.0
|
||||||
pycocotools==2.0.8
|
pycocotools==2.0.8
|
||||||
@ -184,6 +210,7 @@ python-magic==0.4.27
|
|||||||
python-multipart==0.0.9
|
python-multipart==0.0.9
|
||||||
python-oxmsg==0.0.1
|
python-oxmsg==0.0.1
|
||||||
python-pptx==0.6.23
|
python-pptx==0.6.23
|
||||||
|
pytubefix==8.8.5
|
||||||
pytz==2024.1
|
pytz==2024.1
|
||||||
PyYAML==6.0.1
|
PyYAML==6.0.1
|
||||||
pyzmq==26.2.0
|
pyzmq==26.2.0
|
||||||
@ -198,7 +225,11 @@ rich==13.7.1
|
|||||||
rsa==4.9
|
rsa==4.9
|
||||||
s3transfer==0.10.1
|
s3transfer==0.10.1
|
||||||
safetensors==0.4.4
|
safetensors==0.4.4
|
||||||
|
scikit-learn==1.6.0
|
||||||
scipy==1.14.0
|
scipy==1.14.0
|
||||||
|
seaborn==0.13.2
|
||||||
|
sentence-transformers==3.3.1
|
||||||
|
sentencepiece==0.2.0
|
||||||
shellingham==1.5.4
|
shellingham==1.5.4
|
||||||
six==1.16.0
|
six==1.16.0
|
||||||
sniffio==1.3.1
|
sniffio==1.3.1
|
||||||
@ -209,6 +240,7 @@ starlette==0.37.2
|
|||||||
sympy==1.13.2
|
sympy==1.13.2
|
||||||
tabulate==0.9.0
|
tabulate==0.9.0
|
||||||
tenacity==8.5.0
|
tenacity==8.5.0
|
||||||
|
threadpoolctl==3.5.0
|
||||||
tiktoken==0.7.0
|
tiktoken==0.7.0
|
||||||
timm==1.0.8
|
timm==1.0.8
|
||||||
tokenizers==0.19.1
|
tokenizers==0.19.1
|
||||||
@ -218,7 +250,8 @@ torchvision==0.17.2
|
|||||||
tornado==6.4.1
|
tornado==6.4.1
|
||||||
tqdm==4.66.4
|
tqdm==4.66.4
|
||||||
traitlets==5.14.3
|
traitlets==5.14.3
|
||||||
transformers==4.44.0
|
transformers==4.44.2
|
||||||
|
trec-car-tools==2.6
|
||||||
twine==6.0.1
|
twine==6.0.1
|
||||||
typer==0.12.3
|
typer==0.12.3
|
||||||
types-awscrt==0.21.2
|
types-awscrt==0.21.2
|
||||||
@ -227,6 +260,9 @@ typing-inspect==0.9.0
|
|||||||
typing_extensions==4.11.0
|
typing_extensions==4.11.0
|
||||||
tzdata==2024.1
|
tzdata==2024.1
|
||||||
ujson==5.9.0
|
ujson==5.9.0
|
||||||
|
ultralytics==8.3.55
|
||||||
|
ultralytics-thop==2.0.13
|
||||||
|
unlzw3==0.2.3
|
||||||
unstructured==0.15.0
|
unstructured==0.15.0
|
||||||
unstructured-client==0.24.1
|
unstructured-client==0.24.1
|
||||||
unstructured-inference==0.7.36
|
unstructured-inference==0.7.36
|
||||||
@ -235,6 +271,8 @@ urllib3==2.2.1
|
|||||||
uvicorn==0.29.0
|
uvicorn==0.29.0
|
||||||
uvloop==0.19.0
|
uvloop==0.19.0
|
||||||
virtualenv==20.28.0
|
virtualenv==20.28.0
|
||||||
|
warc3-wet==0.2.5
|
||||||
|
warc3-wet-clueweb09==0.2.5
|
||||||
watchfiles==0.21.0
|
watchfiles==0.21.0
|
||||||
wcwidth==0.2.13
|
wcwidth==0.2.13
|
||||||
websockets==12.0
|
websockets==12.0
|
||||||
@ -243,7 +281,5 @@ xlrd==2.0.1
|
|||||||
XlsxWriter==3.2.0
|
XlsxWriter==3.2.0
|
||||||
xxhash==3.4.1
|
xxhash==3.4.1
|
||||||
yarl==1.9.4
|
yarl==1.9.4
|
||||||
opentelemetry-api>=1.21.0
|
zipp==3.21.0
|
||||||
opentelemetry-sdk>=1.21.0
|
zlib-state==0.1.9
|
||||||
opentelemetry-instrumentation-fastapi>=0.42b0
|
|
||||||
opentelemetry-exporter-otlp>=1.21.0
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user