Add reranking (#14)

Arnav Agrawal 2025-01-02 03:42:47 -05:00 committed by GitHub
parent c936aa91a4
commit 20faae8903
12 changed files with 728 additions and 61 deletions

View File

@ -11,6 +11,7 @@ vector_store = "mongodb"
embedding = "ollama" # "openai", "ollama"
completion = "ollama" # "openai", "ollama"
parser = "combined" # "combined", "unstructured", "contextual"
reranker = "bge" # "bge"
# Storage Configuration
[storage.local]
@ -45,12 +46,20 @@ default_temperature = 0.7
[models.ollama]
base_url = "http://localhost:11434"
[models.reranker]
model_name = "BAAI/bge-reranker-large" # "BAAI/bge-reranker-v2-gemma", "BAAI/bge-reranker-large"
device = "mps" # "cuda:0" # Optional: Set to null or remove for CPU
use_fp16 = true
query_max_length = 256
passage_max_length = 512
# Document Processing
[processing]
[processing.text]
chunk_size = 1000
chunk_overlap = 200
default_k = 4
use_reranking = true # Whether to use reranking by default
[processing.video]
frame_sample_rate = 120
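For reference, a minimal sketch of reading the new [models.reranker] table directly (the filename databridge.toml is an assumption, and get_settings() in core/config.py below does the real mapping; this only illustrates the key/default pairs):

import tomllib  # Python 3.11+

with open("databridge.toml", "rb") as f:  # hypothetical filename
    config = tomllib.load(f)

reranker_cfg = config["models"]["reranker"]
model_name = reranker_cfg["model_name"]                       # e.g. "BAAI/bge-reranker-large"
device = reranker_cfg.get("device")                           # None -> run on CPU
use_fp16 = reranker_cfg.get("use_fp16", True)
query_max_length = reranker_cfg.get("query_max_length", 256)
passage_max_length = reranker_cfg.get("passage_max_length", 512)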

View File

@ -28,6 +28,7 @@ from core.storage.local_storage import LocalStorage
from core.embedding.openai_embedding_model import OpenAIEmbeddingModel
from core.completion.ollama_completion import OllamaCompletionModel
from core.parser.contextual_parser import ContextualParser
from core.reranker.bge_reranker import BGEReranker
# Initialize FastAPI app
app = FastAPI(title="DataBridge API")
@ -158,14 +159,28 @@ match settings.COMPLETION_PROVIDER:
case _:
raise ValueError(f"Unsupported completion provider: {settings.COMPLETION_PROVIDER}")
# Initialize reranker
match settings.RERANKER_PROVIDER:
case "bge":
reranker = BGEReranker(
model_name=settings.RERANKER_MODEL,
device=settings.RERANKER_DEVICE,
use_fp16=settings.RERANKER_USE_FP16,
query_max_length=settings.RERANKER_QUERY_MAX_LENGTH,
passage_max_length=settings.RERANKER_PASSAGE_MAX_LENGTH,
)
case _:
raise ValueError(f"Unsupported reranker provider: {settings.RERANKER_PROVIDER}")
# Initialize document service with configured components
document_service = DocumentService(
    database=database,
    vector_store=vector_store,
    storage=storage,
    embedding_model=embedding_model,
    completion_model=completion_model,
    parser=parser,
    reranker=reranker,
)
@ -243,27 +258,51 @@ async def ingest_file(
@app.post("/retrieve/chunks", response_model=List[ChunkResult])
async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
"""Retrieve relevant chunks."""
async with telemetry.track_operation(
operation_type="retrieve_chunks",
user_id=auth.entity_id,
metadata=request.model_dump(),
):
return await document_service.retrieve_chunks(
request.query, auth, request.filters, request.k, request.min_score
)
try:
async with telemetry.track_operation(
operation_type="retrieve_chunks",
user_id=auth.entity_id,
metadata={
"k": request.k,
"min_score": request.min_score,
"use_reranking": request.use_reranking,
},
):
return await document_service.retrieve_chunks(
request.query,
auth,
request.filters,
request.k,
request.min_score,
request.use_reranking,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/retrieve/docs", response_model=List[DocumentResult])
async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
"""Retrieve relevant documents."""
async with telemetry.track_operation(
operation_type="retrieve_docs",
user_id=auth.entity_id,
metadata=request.model_dump(),
):
return await document_service.retrieve_docs(
request.query, auth, request.filters, request.k, request.min_score
)
try:
async with telemetry.track_operation(
operation_type="retrieve_docs",
user_id=auth.entity_id,
metadata={
"k": request.k,
"min_score": request.min_score,
"use_reranking": request.use_reranking,
},
):
return await document_service.retrieve_docs(
request.query,
auth,
request.filters,
request.k,
request.min_score,
request.use_reranking,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/query", response_model=CompletionResponse)
@ -271,27 +310,30 @@ async def query_completion(
request: CompletionQueryRequest, auth: AuthContext = Depends(verify_token)
):
"""Generate completion using relevant chunks as context."""
async with telemetry.track_operation(
operation_type="query",
user_id=auth.entity_id,
metadata=request.model_dump(),
) as span:
response = await document_service.query(
request.query,
auth,
request.filters,
request.k,
request.min_score,
request.max_tokens,
request.temperature,
)
if isinstance(response, dict) and "usage" in response:
usage = response["usage"]
if isinstance(usage, dict):
span.set_attribute("tokens.completion", usage.get("completion_tokens", 0))
span.set_attribute("tokens.prompt", usage.get("prompt_tokens", 0))
span.set_attribute("tokens.total", usage.get("total_tokens", 0))
return response
try:
async with telemetry.track_operation(
operation_type="query",
user_id=auth.entity_id,
metadata={
"k": request.k,
"min_score": request.min_score,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
"use_reranking": request.use_reranking,
},
):
return await document_service.query(
request.query,
auth,
request.filters,
request.k,
request.min_score,
request.max_tokens,
request.temperature,
request.use_reranking,
)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.get("/documents", response_model=List[Document])

View File

@ -32,6 +32,7 @@ class Settings(BaseSettings):
EMBEDDING_PROVIDER: str = "openai"
COMPLETION_PROVIDER: str = "ollama"
PARSER_PROVIDER: str = "combined"
RERANKER_PROVIDER: str = "bge"
# Storage settings
STORAGE_PATH: str = "./storage"
@ -53,6 +54,11 @@ class Settings(BaseSettings):
COMPLETION_MAX_TOKENS: int = 1000
COMPLETION_TEMPERATURE: float = 0.7
OLLAMA_BASE_URL: str = "http://localhost:11434"
RERANKER_MODEL: str = "BAAI/bge-reranker-v2-gemma"
RERANKER_DEVICE: Optional[str] = None
RERANKER_USE_FP16: bool = True
RERANKER_QUERY_MAX_LENGTH: int = 256
RERANKER_PASSAGE_MAX_LENGTH: int = 512
# Processing settings
CHUNK_SIZE: int = 1000
@ -60,6 +66,7 @@ class Settings(BaseSettings):
DEFAULT_K: int = 4
FRAME_SAMPLE_RATE: int = 120
USE_UNSTRUCTURED_API: bool = False
USE_RERANKING: bool = True
# Auth settings
JWT_ALGORITHM: str = "HS256"
@ -87,6 +94,7 @@ def get_settings() -> Settings:
"EMBEDDING_PROVIDER": config["service"]["components"]["embedding"],
"COMPLETION_PROVIDER": config["service"]["components"]["completion"],
"PARSER_PROVIDER": config["service"]["components"]["parser"],
"RERANKER_PROVIDER": config["service"]["components"]["reranker"],
# Storage settings
"STORAGE_PATH": config["storage"]["local"]["path"],
"AWS_REGION": config["storage"]["aws"]["region"],
@ -104,10 +112,16 @@ def get_settings() -> Settings:
"COMPLETION_MAX_TOKENS": config["models"]["completion"]["default_max_tokens"],
"COMPLETION_TEMPERATURE": config["models"]["completion"]["default_temperature"],
"OLLAMA_BASE_URL": config["models"]["ollama"]["base_url"],
"RERANKER_MODEL": config["models"]["reranker"]["model_name"],
"RERANKER_DEVICE": config["models"]["reranker"].get("device"),
"RERANKER_USE_FP16": config["models"]["reranker"].get("use_fp16", True),
"RERANKER_QUERY_MAX_LENGTH": config["models"]["reranker"].get("query_max_length", 256),
"RERANKER_PASSAGE_MAX_LENGTH": config["models"]["reranker"].get("passage_max_length", 512),
# Processing settings
"CHUNK_SIZE": config["processing"]["text"]["chunk_size"],
"CHUNK_OVERLAP": config["processing"]["text"]["chunk_overlap"],
"DEFAULT_K": config["processing"]["text"]["default_k"],
"USE_RERANKING": config["processing"]["text"]["use_reranking"],
"FRAME_SAMPLE_RATE": config["processing"]["video"]["frame_sample_rate"],
"USE_UNSTRUCTURED_API": config["processing"]["unstructured"]["use_api"],
# Auth settings

View File

@ -52,6 +52,7 @@ class MongoDatabase(BaseDatabase):
# Ensure system metadata
doc_dict["system_metadata"]["created_at"] = datetime.now(UTC)
doc_dict["system_metadata"]["updated_at"] = datetime.now(UTC)
doc_dict["metadata"]["external_id"] = doc_dict["external_id"]
result = await self.collection.insert_one(doc_dict)
return bool(result.inserted_id)

View File

@ -16,6 +16,7 @@ class RetrieveRequest(BaseModel):
filters: Optional[Dict[str, Any]] = None
k: int = Field(default=4, gt=0)
min_score: float = Field(default=0.0)
use_reranking: Optional[bool] = None # If None, use default from config
class CompletionQueryRequest(RetrieveRequest):
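Note that use_reranking is deliberately tri-state: True forces reranking, False disables it, and None (the default) defers to USE_RERANKING from the config. The service resolves it in one line, as in DocumentService.retrieve_chunks below:

should_rerank = request.use_reranking if request.use_reranking is not None else settings.USE_RERANKING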

View File

@ -0,0 +1 @@
"""Reranker package for reranking search results."""

View File

@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from typing import List, Union
from core.models.chunk import DocumentChunk
class BaseReranker(ABC):
"""Base class for reranking search results"""
@abstractmethod
async def rerank(
self,
query: str,
chunks: List[DocumentChunk],
) -> List[DocumentChunk]:
"""Rerank chunks based on their relevance to the query"""
pass
@abstractmethod
async def compute_score(
self,
query: str,
text: Union[str, List[str]],
) -> Union[float, List[float]]:
"""Compute relevance scores between query and text"""
pass
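Any alternative reranker only needs to implement these two coroutines. As a minimal sketch, a hypothetical keyword-overlap reranker (illustration only, not part of this commit; wiring it in would also require a new case in the provider match in core/api.py):

from typing import List, Union

from core.models.chunk import DocumentChunk
from core.reranker.base_reranker import BaseReranker


class KeywordOverlapReranker(BaseReranker):
    """Toy reranker: scores a passage by the fraction of query terms it contains."""

    async def compute_score(
        self, query: str, text: Union[str, List[str]]
    ) -> Union[float, List[float]]:
        texts = [text] if isinstance(text, str) else text
        terms = set(query.lower().split())
        scores = [len(terms & set(t.lower().split())) / max(len(terms), 1) for t in texts]
        return scores[0] if isinstance(text, str) else scores

    async def rerank(self, query: str, chunks: List[DocumentChunk]) -> List[DocumentChunk]:
        if not chunks:
            return []
        scores = await self.compute_score(query, [c.content for c in chunks])
        for chunk, score in zip(chunks, scores):
            chunk.score = float(score)
        return sorted(chunks, key=lambda c: c.score, reverse=True)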

View File

@ -0,0 +1,59 @@
from typing import List, Union, Optional
from FlagEmbedding import FlagAutoReranker
from core.models.chunk import DocumentChunk
from core.reranker.base_reranker import BaseReranker
class BGEReranker(BaseReranker):
"""BGE reranker implementation using FlagEmbedding"""
def __init__(
self,
model_name: str = "BAAI/bge-reranker-v2-gemma",
query_max_length: int = 256,
passage_max_length: int = 512,
use_fp16: bool = True,
device: Optional[str] = None,
):
"""Initialize BGE reranker"""
devices = [device] if device else None
self.reranker = FlagAutoReranker.from_finetuned(
model_name_or_path=model_name,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
use_fp16=use_fp16,
devices=devices,
)
async def rerank(
self,
query: str,
chunks: List[DocumentChunk],
) -> List[DocumentChunk]:
"""Rerank chunks based on their relevance to the query"""
if not chunks:
return []
# Get scores for all chunks
passages = [chunk.content for chunk in chunks]
scores = await self.compute_score(query, passages)
# Update scores and sort chunks
for chunk, score in zip(chunks, scores):
chunk.score = float(score)
return sorted(chunks, key=lambda x: x.score, reverse=True)
async def compute_score(
self,
query: str,
text: Union[str, List[str]],
) -> Union[float, List[float]]:
"""Compute relevance scores between query and text"""
if isinstance(text, str):
    text = [text]
    scores = self.reranker.compute_score([[query, t] for t in text], normalize=True)
    return scores[0] if len(scores) == 1 else scores
else:
    return self.reranker.compute_score([[query, t] for t in text], normalize=True)
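A minimal usage sketch, assuming the model weights can be fetched from Hugging Face and using the DocumentChunk fields seen in the tests below:

import asyncio

from core.models.chunk import DocumentChunk
from core.reranker.bge_reranker import BGEReranker


async def demo():
    reranker = BGEReranker(model_name="BAAI/bge-reranker-large", device=None)  # None -> CPU
    chunks = [
        DocumentChunk(document_id="1", content="Python is a programming language",
                      embedding=[0.0] * 10, chunk_number=1, score=0.0),
        DocumentChunk(document_id="2", content="The quick brown fox jumps",
                      embedding=[0.0] * 10, chunk_number=1, score=0.0),
    ]
    ranked = await reranker.rerank("What is Python?", chunks)
    for c in ranked:
        print(f"{c.score:.3f} {c.content}")


asyncio.run(demo())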

View File

@ -19,6 +19,8 @@ from core.parser.base_parser import BaseParser
from core.completion.base_completion import BaseCompletionModel
from core.completion.base_completion import CompletionRequest, CompletionResponse
import logging
from core.reranker.base_reranker import BaseReranker
from core.config import get_settings
logger = logging.getLogger(__name__)
@ -32,6 +34,7 @@ class DocumentService:
parser: BaseParser,
embedding_model: BaseEmbeddingModel,
completion_model: BaseCompletionModel,
reranker: BaseReranker,
):
self.db = database
self.vector_store = vector_store
@ -39,16 +42,21 @@ class DocumentService:
self.parser = parser
self.embedding_model = embedding_model
self.completion_model = completion_model
self.reranker = reranker
async def retrieve_chunks(
self,
query: str,
auth: AuthContext,
filters: Optional[Dict[str, Any]] = None,
k: int = 5,
min_score: float = 0.0,
use_reranking: Optional[bool] = None,
) -> List[ChunkResult]:
"""Retrieve relevant chunks."""
settings = get_settings()
should_rerank = use_reranking if use_reranking is not None else settings.USE_RERANKING
# Get embedding for query
query_embedding = await self.embedding_model.embed_for_query(query)
logger.info("Generated query embedding")
@ -61,9 +69,18 @@ class DocumentService:
logger.info(f"Found {len(doc_ids)} authorized documents")
# Search chunks with vector similarity
chunks = await self.vector_store.query_similar(
    query_embedding, k=10 * k if should_rerank else k, doc_ids=doc_ids
)
logger.info(f"Found {len(chunks)} similar chunks")
# Rerank chunks using the reranker if enabled
if chunks and should_rerank:
    chunks = await self.reranker.rerank(query, chunks)
    chunks.sort(key=lambda x: x.score, reverse=True)
    logger.info(f"Reranked {len(chunks)} candidates and selected the top {k}")
    chunks = chunks[:k]
# Create and return chunk results
results = await self._create_chunk_results(auth, chunks)
logger.info(f"Returning {len(results)} chunk results")
@ -74,12 +91,13 @@ class DocumentService:
query: str,
auth: AuthContext,
filters: Optional[Dict[str, Any]] = None,
k: int = 5,
min_score: float = 0.0,
use_reranking: Optional[bool] = None,
) -> List[DocumentResult]:
"""Retrieve relevant documents."""
# Get chunks first
chunks = await self.retrieve_chunks(query, auth, filters, k, min_score, use_reranking)
# Convert to document results
results = await self._create_document_results(auth, chunks)
documents = list(results.values())
@ -95,10 +113,11 @@ class DocumentService:
min_score: float = 0.0,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
use_reranking: Optional[bool] = None,
) -> CompletionResponse:
"""Generate completion using relevant chunks as context."""
# Get relevant chunks
chunks = await self.retrieve_chunks(query, auth, filters, k, min_score, use_reranking)
documents = await self._create_document_results(auth, chunks)
chunk_contents = [chunk.augmented_content(documents[chunk.document_id]) for chunk in chunks]

View File

@ -194,6 +194,10 @@ async def test_ingest_invalid_metadata(client: AsyncClient):
@pytest.mark.asyncio
@pytest.mark.skipif(
get_settings().EMBEDDING_PROVIDER == "ollama",
reason="local embedding models do not have size limits",
)
async def test_ingest_oversized_content(client: AsyncClient):
"""Test ingestion with oversized content"""
headers = create_auth_header()
@ -285,26 +289,28 @@ async def test_invalid_document_id(client: AsyncClient):
@pytest.mark.asyncio
async def test_retrieve_chunks(client: AsyncClient):
"""Test retrieving document chunks"""
upload_string = "The quick brown fox jumps over the lazy dog"
# First ingest a document to search
doc_id = await test_ingest_text_document(client, content=upload_string)
headers = create_auth_header()
# Sleep to allow time for document to be indexed
await asyncio.sleep(1)
response = await client.post(
"/retrieve/chunks",
json={"query": "jumping fox", "k": 1, "min_score": 0.0},
json={
"query": "jumping fox",
"k": 1,
"min_score": 0.0,
"filters": {"external_id": doc_id}, # Add filter for specific document
},
headers=headers,
)
assert response.status_code == 200
results = list(response.json())
assert len(results) > 0
assert results[0]["score"] > 0.5
assert results[0]["document_id"] == doc_id
assert results[0]["content"] == upload_string
@pytest.mark.asyncio
@ -324,7 +330,10 @@ async def test_retrieve_docs(client: AsyncClient):
headers = create_auth_header()
response = await client.post(
"/retrieve/docs",
json={"query": "Headaches, dehydration", "filters": {"test": True}},
json={
"query": "Headaches, dehydration",
"filters": {"test": True, "external_id": doc_id}, # Add filter for specific document
},
headers=headers,
)
@ -398,3 +407,287 @@ async def test_invalid_completion_params(client: AsyncClient):
headers=headers,
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_retrieve_chunks_default_reranking(client: AsyncClient):
"""Test retrieving chunks with default reranking behavior"""
# First ingest some test documents
_ = await test_ingest_text_document(
client, "The quick brown fox jumps over the lazy dog. This is a test document."
)
_ = await test_ingest_text_document(
client, "The lazy dog sleeps while the quick brown fox runs. Another test document."
)
headers = create_auth_header()
response = await client.post(
"/retrieve/chunks",
json={
"query": "What does the fox do?",
"k": 2,
"min_score": 0.0,
# Not specifying use_reranking - should use default from config
},
headers=headers,
)
assert response.status_code == 200
chunks = response.json()
assert len(chunks) > 0
# Verify chunks are ordered by score
scores = [chunk["score"] for chunk in chunks]
assert all(scores[i] >= scores[i + 1] for i in range(len(scores) - 1))
@pytest.mark.asyncio
async def test_retrieve_chunks_explicit_reranking(client: AsyncClient):
"""Test retrieving chunks with explicitly enabled reranking"""
# First ingest some test documents
_ = await test_ingest_text_document(
client, "The quick brown fox jumps over the lazy dog. This is a test document."
)
_ = await test_ingest_text_document(
client, "The lazy dog sleeps while the quick brown fox runs. Another test document."
)
headers = create_auth_header()
response = await client.post(
"/retrieve/chunks",
json={
"query": "What does the fox do?",
"k": 2,
"min_score": 0.0,
"use_reranking": True,
},
headers=headers,
)
assert response.status_code == 200
chunks = response.json()
assert len(chunks) > 0
# Verify chunks are ordered by score
scores = [chunk["score"] for chunk in chunks]
assert all(scores[i] >= scores[i + 1] for i in range(len(scores) - 1))
@pytest.mark.asyncio
async def test_retrieve_chunks_disabled_reranking(client: AsyncClient):
"""Test retrieving chunks with explicitly disabled reranking"""
# First ingest some test documents
await test_ingest_text_document(
client, "The quick brown fox jumps over the lazy dog. This is a test document."
)
await test_ingest_text_document(
client, "The lazy dog sleeps while the quick brown fox runs. Another test document."
)
headers = create_auth_header()
response = await client.post(
"/retrieve/chunks",
json={
"query": "What does the fox do?",
"k": 2,
"min_score": 0.0,
"use_reranking": False,
},
headers=headers,
)
assert response.status_code == 200
chunks = response.json()
assert len(chunks) > 0
@pytest.mark.asyncio
async def test_reranking_affects_results(client: AsyncClient):
"""Test that reranking actually changes the order of results"""
# First ingest documents with clearly different semantic relevance
await test_ingest_text_document(
client, "The capital of France is Paris. The city is known for the Eiffel Tower."
)
await test_ingest_text_document(
client, "Paris is a city in France. It has many famous landmarks and museums."
)
await test_ingest_text_document(
client, "Paris Hilton is a celebrity and businesswoman. She has nothing to do with France."
)
headers = create_auth_header()
# Get results without reranking
response_no_rerank = await client.post(
"/retrieve/chunks",
json={
"query": "Tell me about the capital city of France",
"k": 3,
"min_score": 0.0,
"use_reranking": False,
},
headers=headers,
)
# Get results with reranking
response_with_rerank = await client.post(
"/retrieve/chunks",
json={
"query": "Tell me about the capital city of France",
"k": 3,
"min_score": 0.0,
"use_reranking": True,
},
headers=headers,
)
assert response_no_rerank.status_code == 200
assert response_with_rerank.status_code == 200
chunks_no_rerank = response_no_rerank.json()
chunks_with_rerank = response_with_rerank.json()
# Verify we got results in both cases
assert len(chunks_no_rerank) > 0
assert len(chunks_with_rerank) > 0
# The order or scores should differ between reranked and non-reranked results.
# This check can be flaky depending on exact scores, but the test data is
# crafted so that reranking changes the outcome in the common case.
scores_no_rerank = [c["score"] for c in chunks_no_rerank]
scores_with_rerank = [c["score"] for c in chunks_with_rerank]
assert scores_no_rerank != scores_with_rerank, "Reranking should affect the scores"
@pytest.mark.asyncio
async def test_retrieve_docs_with_reranking(client: AsyncClient):
"""Test document retrieval with reranking options"""
# First ingest documents with clearly different semantic relevance
await test_ingest_text_document(
client, "The capital of France is Paris. The city is known for the Eiffel Tower."
)
await test_ingest_text_document(
client, "Paris is a city in France. It has many famous landmarks and museums."
)
await test_ingest_text_document(
client, "Paris Hilton is a celebrity and businesswoman. She has nothing to do with France."
)
headers = create_auth_header()
# Test with default reranking (from config)
response_default = await client.post(
"/retrieve/docs",
json={
"query": "Tell me about the capital city of France",
"k": 3,
"min_score": 0.0,
},
headers=headers,
)
assert response_default.status_code == 200
docs_default = response_default.json()
assert len(docs_default) > 0
# Test with explicit reranking enabled
response_rerank = await client.post(
"/retrieve/docs",
json={
"query": "Tell me about the capital city of France",
"k": 3,
"min_score": 0.0,
"use_reranking": True,
},
headers=headers,
)
assert response_rerank.status_code == 200
docs_rerank = response_rerank.json()
assert len(docs_rerank) > 0
# Test with reranking disabled
response_no_rerank = await client.post(
"/retrieve/docs",
json={
"query": "Tell me about the capital city of France",
"k": 3,
"min_score": 0.0,
"use_reranking": False,
},
headers=headers,
)
assert response_no_rerank.status_code == 200
docs_no_rerank = response_no_rerank.json()
assert len(docs_no_rerank) > 0
# Verify that reranking affects the order
scores_rerank = [doc["score"] for doc in docs_rerank]
scores_no_rerank = [doc["score"] for doc in docs_no_rerank]
assert scores_rerank != scores_no_rerank, "Reranking should affect document scores"
@pytest.mark.asyncio
async def test_query_with_reranking(client: AsyncClient):
"""Test query completion with reranking options"""
# First ingest documents with clearly different semantic relevance
await test_ingest_text_document(
client, "The capital of France is Paris. The city is known for the Eiffel Tower."
)
await test_ingest_text_document(
client, "Paris is a city in France. It has many famous landmarks and museums."
)
await test_ingest_text_document(
client, "Paris Hilton is a celebrity and businesswoman. She has nothing to do with France."
)
headers = create_auth_header()
# Test with default reranking (from config)
response_default = await client.post(
"/query",
json={
"query": "What is the capital of France?",
"k": 3,
"min_score": 0.0,
"max_tokens": 50,
},
headers=headers,
)
assert response_default.status_code == 200
completion_default = response_default.json()
assert "completion" in completion_default
# Test with explicit reranking enabled
response_rerank = await client.post(
"/query",
json={
"query": "What is the capital of France?",
"k": 3,
"min_score": 0.0,
"max_tokens": 50,
"use_reranking": True,
},
headers=headers,
)
assert response_rerank.status_code == 200
completion_rerank = response_rerank.json()
assert "completion" in completion_rerank
# Test with reranking disabled
response_no_rerank = await client.post(
"/query",
json={
"query": "What is the capital of France?",
"k": 3,
"min_score": 0.0,
"max_tokens": 50,
"use_reranking": False,
},
headers=headers,
)
assert response_no_rerank.status_code == 200
completion_no_rerank = response_no_rerank.json()
assert "completion" in completion_no_rerank
# The actual responses might be different due to different chunk ordering,
# but all should mention Paris as the capital
assert "Paris" in completion_default["completion"]
assert "Paris" in completion_rerank["completion"]
assert "Paris" in completion_no_rerank["completion"]

View File

@ -0,0 +1,166 @@
import pytest
from typing import List
from core.models.chunk import DocumentChunk
from core.reranker.bge_reranker import BGEReranker
from core.config import get_settings
@pytest.fixture
def sample_chunks() -> List[DocumentChunk]:
return [
DocumentChunk(
document_id="1",
content="The quick brown fox jumps over the lazy dog",
embedding=[0.1] * 10,
chunk_number=1,
score=0.5,
),
DocumentChunk(
document_id="2",
content="Python is a popular programming language",
embedding=[0.2] * 10,
chunk_number=1,
score=0.7,
),
DocumentChunk(
document_id="3",
content="Machine learning models help analyze data",
embedding=[0.3] * 10,
chunk_number=1,
score=0.3,
),
]
@pytest.fixture
def reranker():
"""Fixture to create and reuse a BGE reranker instance"""
settings = get_settings()
return BGEReranker(
model_name=settings.RERANKER_MODEL,
device=settings.RERANKER_DEVICE,
use_fp16=settings.RERANKER_USE_FP16,
query_max_length=settings.RERANKER_QUERY_MAX_LENGTH,
passage_max_length=settings.RERANKER_PASSAGE_MAX_LENGTH,
)
@pytest.mark.asyncio
async def test_reranker_relevance(reranker, sample_chunks):
"""Test that reranker improves relevance for programming-related query"""
print("\n=== Testing Reranker Relevance ===")
query = "What is Python programming language?"
# Get reranked results
reranked_chunks = await reranker.rerank(query, sample_chunks)
print(f"\nQuery: {query}")
for i, chunk in enumerate(reranked_chunks):
print(f"{i+1}. Score: {chunk.score:.3f} - {chunk.content}")
# The most relevant chunks should be about Python
assert "Python" in reranked_chunks[0].content
assert reranked_chunks[0].score > reranked_chunks[-1].score
# Check that irrelevant content (fox/dog) is ranked lower
fox_chunk_idx = next(i for i, c in enumerate(reranked_chunks) if "fox" in c.content.lower())
assert fox_chunk_idx > 0 # Should not be first
@pytest.mark.asyncio
async def test_reranker_score_distribution(reranker, sample_chunks):
"""Test that reranker produces reasonable score distribution"""
print("\n=== Testing Score Distribution ===")
query = "Tell me about machine learning and data science"
# Get reranked results
reranked_chunks = await reranker.rerank(query, sample_chunks)
print(f"\nQuery: {query}")
for i, chunk in enumerate(reranked_chunks):
print(f"{i+1}. Score: {chunk.score:.3f} - {chunk.content}")
# Check score properties
scores = [c.score for c in reranked_chunks]
assert all(0 <= s <= 1 for s in scores) # Scores should be between 0 and 1
assert len(set(scores)) > 1 # Should have different scores (not all same)
# Verify ordering
assert scores == sorted(scores, reverse=True) # Should be in descending order
# Most relevant chunk should be about ML/data science
top_chunk = reranked_chunks[0]
assert any(term in top_chunk.content.lower() for term in ["machine learning", "data science"])
@pytest.mark.asyncio
async def test_reranker_batch_scoring(reranker):
"""Test that reranker can handle multiple queries/passages efficiently"""
print("\n=== Testing Batch Scoring ===")
texts = [
"Python is a programming language",
"Machine learning is a field of AI",
"The quick brown fox jumps",
"Data science uses statistical methods",
]
queries = ["What is Python?", "Explain artificial intelligence", "Tell me about data analysis"]
# Test multiple queries against multiple texts
for query in queries:
scores = await reranker.compute_score(query, texts)
print(f"\nQuery: {query}")
for text, score in zip(texts, scores):
print(f"Score: {score:.3f} - {text}")
assert len(scores) == len(texts)
assert all(isinstance(s, float) for s in scores)
assert all(0 <= s <= 1 for s in scores)
@pytest.mark.asyncio
async def test_reranker_empty_and_edge_cases(reranker, sample_chunks):
"""Test reranker behavior with empty or edge case inputs"""
print("\n=== Testing Edge Cases ===")
# Empty chunks list
result = await reranker.rerank("test query", [])
assert result == []
print("Empty chunks test passed")
# Single chunk
single_chunk = DocumentChunk(
document_id="1",
content="Test content",
embedding=[0.1] * 768,
chunk_number=1,
score=0.5,
)
result = await reranker.rerank("test query", [single_chunk])
assert len(result) == 1
assert isinstance(result[0].score, float)
print(f"Single chunk test passed - Score: {result[0].score:.3f}")
# Empty query
result = await reranker.rerank("", sample_chunks)
assert len(result) == len(sample_chunks)
print("Empty query test passed")
@pytest.mark.asyncio
async def test_reranker_consistency(reranker, sample_chunks):
"""Test that reranker produces consistent results for same input"""
print("\n=== Testing Consistency ===")
query = "What is Python programming?"
# Run reranking multiple times
results1 = await reranker.rerank(query, sample_chunks)
results2 = await reranker.rerank(query, sample_chunks)
# Scores should be the same across runs
scores1 = [c.score for c in results1]
scores2 = [c.score for c in results2]
print("\nScores from first run:", [f"{s:.3f}" for s in scores1])
print("Scores from second run:", [f"{s:.3f}" for s in scores2])
assert scores1 == scores2
# Order should be preserved
assert [c.document_id for c in results1] == [c.document_id for c in results2]
print("Order consistency test passed")

View File

@ -1,10 +1,13 @@
accelerate==1.2.1
aiofiles==24.1.0
aiohttp==3.9.5
aiosignal==1.3.1
annotated-types==0.6.0
anthropic==0.42.0
antlr4-python3-runtime==4.9.3
anyio==4.3.0
appnope==0.1.4
asgiref==3.8.1
assemblyai==0.36.0
asttokens==2.4.1
attrs==23.2.0
@ -18,6 +21,7 @@ botocore==1.34.103
botocore-stubs==1.34.150
build==1.2.2.post1
cachetools==5.3.3
cbor==1.0.0
certifi==2024.2.2
cffi==1.17.0
cfgv==3.4.0
@ -30,7 +34,7 @@ contourpy==1.2.1
cryptography==43.0.0
cycler==0.12.1
dataclasses-json==0.6.7
datasets==2.21.0
datasets==2.19.0
debugpy==1.8.5
decorator==5.1.1
deepdiff==7.0.1
@ -51,10 +55,11 @@ fastapi-cli==0.0.2
ffmpeg-python==0.2.0
filelock==3.15.4
filetype==1.2.0
FlagEmbedding==1.3.3
flatbuffers==24.3.25
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.6.1
fsspec==2024.3.1
future==1.0.0
google-api-core==2.19.1
google-auth==2.29.0
@ -67,14 +72,18 @@ h11==0.14.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.24.5
huggingface-hub==0.27.0
humanfriendly==10.0
identify==2.6.3
idna==3.7
ijson==3.3.0
importlib_metadata==8.5.0
iniconfig==2.0.0
inscriptis==2.5.0
iopath==0.1.10
ipykernel==6.29.5
ipython==8.26.0
ir_datasets==0.5.9
jaraco.classes==3.4.0
jaraco.context==6.0.1
jaraco.functools==4.1.0
@ -99,9 +108,11 @@ langchain-text-splitters==0.2.2
langchain-unstructured==0.1.1
langdetect==1.0.9
langsmith==0.1.98
lap==0.5.12
layoutparser==0.3.4
llvmlite==0.43.0
lxml==5.2.2
lz4==4.3.3
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
@ -131,6 +142,18 @@ onnxruntime==1.18.1
openai==1.43.0
opencv-python==4.10.0.84
openpyxl==3.1.5
opentelemetry-api==1.29.0
opentelemetry-exporter-otlp==1.29.0
opentelemetry-exporter-otlp-proto-common==1.29.0
opentelemetry-exporter-otlp-proto-grpc==1.29.0
opentelemetry-exporter-otlp-proto-http==1.29.0
opentelemetry-instrumentation==0.50b0
opentelemetry-instrumentation-asgi==0.50b0
opentelemetry-instrumentation-fastapi==0.50b0
opentelemetry-proto==1.29.0
opentelemetry-sdk==1.29.0
opentelemetry-semantic-conventions==0.50b0
opentelemetry-util-http==0.50b0
ordered-set==4.1.0
orjson==3.10.3
packaging==24.0
@ -141,6 +164,7 @@ pathspec==0.12.1
pdf2image==1.17.0
pdfminer.six==20231228
pdfplumber==0.11.3
peft==0.14.0
pexpect==4.9.0
pikepdf==9.1.1
pillow==10.4.0
@ -156,7 +180,9 @@ protobuf==5.27.3
psutil==6.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
py-cpuinfo==9.0.0
pyarrow==17.0.0
pyarrow-hotfix==0.6
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycocotools==2.0.8
@ -184,6 +210,7 @@ python-magic==0.4.27
python-multipart==0.0.9
python-oxmsg==0.0.1
python-pptx==0.6.23
pytubefix==8.8.5
pytz==2024.1
PyYAML==6.0.1
pyzmq==26.2.0
@ -198,7 +225,11 @@ rich==13.7.1
rsa==4.9
s3transfer==0.10.1
safetensors==0.4.4
scikit-learn==1.6.0
scipy==1.14.0
seaborn==0.13.2
sentence-transformers==3.3.1
sentencepiece==0.2.0
shellingham==1.5.4
six==1.16.0
sniffio==1.3.1
@ -209,6 +240,7 @@ starlette==0.37.2
sympy==1.13.2
tabulate==0.9.0
tenacity==8.5.0
threadpoolctl==3.5.0
tiktoken==0.7.0
timm==1.0.8
tokenizers==0.19.1
@ -218,7 +250,8 @@ torchvision==0.17.2
tornado==6.4.1
tqdm==4.66.4
traitlets==5.14.3
transformers==4.44.0
transformers==4.44.2
trec-car-tools==2.6
twine==6.0.1
typer==0.12.3
types-awscrt==0.21.2
@ -227,6 +260,9 @@ typing-inspect==0.9.0
typing_extensions==4.11.0
tzdata==2024.1
ujson==5.9.0
ultralytics==8.3.55
ultralytics-thop==2.0.13
unlzw3==0.2.3
unstructured==0.15.0
unstructured-client==0.24.1
unstructured-inference==0.7.36
@ -235,6 +271,8 @@ urllib3==2.2.1
uvicorn==0.29.0
uvloop==0.19.0
virtualenv==20.28.0
warc3-wet==0.2.5
warc3-wet-clueweb09==0.2.5
watchfiles==0.21.0
wcwidth==0.2.13
websockets==12.0
@ -243,7 +281,5 @@ xlrd==2.0.1
XlsxWriter==3.2.0
xxhash==3.4.1
yarl==1.9.4
opentelemetry-api>=1.21.0
opentelemetry-sdk>=1.21.0
opentelemetry-instrumentation-fastapi>=0.42b0
opentelemetry-exporter-otlp>=1.21.0
zipp==3.21.0
zlib-state==0.1.9