mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
61 lines
1.9 KiB
Python
61 lines
1.9 KiB
Python
from typing import List, Optional, Union
|
|
|
|
from FlagEmbedding import FlagAutoReranker
|
|
|
|
from core.models.chunk import DocumentChunk
|
|
from core.reranker.base_reranker import BaseReranker
|
|
|
|
|
|
class FlagReranker(BaseReranker):
    """Reranker implementation using FlagEmbedding.

    Wraps a ``FlagAutoReranker`` model and uses the scores it produces
    (requested with ``normalize=True``) to order document chunks by
    relevance to a query.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-v2-gemma",
        query_max_length: int = 256,
        passage_max_length: int = 512,
        use_fp16: bool = True,
        device: Optional[str] = None,
    ):
        """Initialize flag reranker.

        Args:
            model_name: Model name or path forwarded to
                ``FlagAutoReranker.from_finetuned`` as ``model_name_or_path``.
            query_max_length: Forwarded unchanged to the underlying model.
            passage_max_length: Forwarded unchanged to the underlying model.
            use_fp16: Forwarded unchanged to the underlying model.
            device: Optional single device identifier. It is wrapped in a
                one-element list for the ``devices`` argument; when falsy,
                ``devices=None`` is passed instead.
        """
        devices = [device] if device else None
        self.reranker = FlagAutoReranker.from_finetuned(
            model_name_or_path=model_name,
            query_max_length=query_max_length,
            passage_max_length=passage_max_length,
            use_fp16=use_fp16,
            devices=devices,
        )

    @staticmethod
    def _as_score_list(raw: object) -> List:
        """Return *raw* as a plain list, wrapping a lone scalar in a list."""
        try:
            return list(raw)
        except TypeError:
            # Not iterable: treat it as the score of a single pair.
            return [raw]

    async def rerank(
        self,
        query: str,
        chunks: List[DocumentChunk],
    ) -> List[DocumentChunk]:
        """Rerank chunks based on their relevance to the query.

        Each chunk's ``score`` attribute is overwritten in place with its new
        relevance score.

        Args:
            query: The query to score the chunks against.
            chunks: Chunks to score; their ``content`` is used as the passage.

        Returns:
            A new list holding the same chunk objects sorted by descending
            score, or an empty list when ``chunks`` is empty.
        """
        if not chunks:
            return []

        # Passing a list (never a bare string) makes compute_score return a
        # list of scores, even when there is only one chunk.
        passages = [chunk.content for chunk in chunks]
        scores = await self.compute_score(query, passages)

        # Update scores and sort chunks
        for chunk, score in zip(chunks, scores):
            chunk.score = float(score)

        return sorted(chunks, key=lambda x: x.score, reverse=True)

    async def compute_score(
        self,
        query: str,
        text: Union[str, List[str]],
    ) -> Union[float, List[float]]:
        """Compute relevance scores between query and text.

        Args:
            query: The query string.
            text: A single passage, or a list of passages.

        Returns:
            A single score when ``text`` is a string; otherwise a list with
            one score per passage (an empty list for an empty ``text``).
        """
        single = isinstance(text, str)
        passages = [text] if single else list(text)
        if not passages:
            # Nothing to score; avoid handing the model an empty batch.
            return []

        # NOTE(review): this call is synchronous, so the event loop is blocked
        # while the model scores the batch -- consider offloading to a worker
        # thread if that becomes a latency problem.
        raw = self.reranker.compute_score(
            [[query, passage] for passage in passages],
            normalize=True,
        )

        # NOTE(review): guards against the library returning a bare scalar for
        # a single pair instead of a one-element list -- confirm against the
        # installed FlagEmbedding version.
        scores = self._as_score_list(raw)

        if single and len(scores) == 1:
            return scores[0]
        return scores
|