hotfix for document ranking in ColPali (#64)

* hotfix for document ranking in ColPali

* add re-ranking clause
This commit is contained in:
Arnav Agrawal 2025-03-30 00:26:05 -07:00 committed by GitHub
parent 2f8fb50797
commit a19ff3cc5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,7 @@ from io import BytesIO
from typing import Dict, Any, List, Optional
from fastapi import UploadFile
from datetime import datetime, UTC
import torch
from core.models.chunk import Chunk, DocumentChunk
from core.models.documents import (
Document,
@ -14,6 +14,7 @@ from core.models.documents import (
)
from ..models.auth import AuthContext
from ..models.graph import Graph
from colpali_engine.models import ColIdefics3, ColIdefics3Processor
from core.services.graph_service import GraphService
from core.database.base_database import BaseDatabase
from core.storage.base_storage import BaseStorage
@ -108,13 +109,15 @@ class DocumentService:
return []
logger.info(f"Found {len(doc_ids)} authorized documents")
search_multi = use_colpali and self.colpali_vector_store and query_embedding_multivector is not None
should_rerank = should_rerank and (not search_multi) # colpali has a different re-ranking method
# Search chunks with vector similarity
chunks = await self.vector_store.query_similar(
query_embedding_regular, k=10 * k if should_rerank else k, doc_ids=doc_ids
)
search_multi = use_colpali and self.colpali_vector_store and query_embedding_multivector is not None
chunks_multivector = (
await self.colpali_vector_store.query_similar(
query_embedding_multivector, k=k, doc_ids=doc_ids
@ -123,9 +126,7 @@ class DocumentService:
logger.debug(f"Found {len(chunks)} similar chunks via regular embedding")
if use_colpali:
logger.debug(
f"Found {len(chunks_multivector)} similar chunks via multivector embedding since we are also using colpali"
)
logger.debug(f"Found {len(chunks_multivector)} similar chunks via multivector embedding since we are also using colpali")
# Rerank chunks using the reranker if enabled and available
if chunks and should_rerank and self.reranker is not None:
@ -134,13 +135,51 @@ class DocumentService:
chunks = chunks[:k]
logger.debug(f"Reranked {k*10} chunks and selected the top {k}")
chunks = chunks_multivector + chunks
chunks = await self._combine_multi_and_regular_chunks(query, chunks, chunks_multivector)
# Create and return chunk results
results = await self._create_chunk_results(auth, chunks)
logger.info(f"Returning {len(results)} chunk results")
return results
async def _combine_multi_and_regular_chunks(self, query: str, chunks: List[DocumentChunk], chunks_multivector: List[DocumentChunk]):
# use colpali as a reranker to get the same level of similarity score for both the chunks as well as the multi-vector chunks
# TODO: Note that the chunks only need to be rescored in case they weren't ingested with colpali-enabled as true.
# In the other case, we know that chunks_multivector can just come ahead of the regular chunks (since we already
# considered the regular chunks when performing the original similarity search). there is scope for optimization here
# by filtering for only the chunks which weren't ingested via colpali...
if len(chunks_multivector) == 0:
return chunks
# TODO: this is duct tape, fix it properly later
model_name = "vidore/colSmol-256M"
device = (
"mps"
if torch.backends.mps.is_available()
else "cuda" if torch.cuda.is_available() else "cpu"
)
model = ColIdefics3.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device, # "cuda:0", # or "mps" if on Apple Silicon
attn_implementation="eager", # "flash_attention_2" if is_flash_attn_2_available() else None, # or "eager" if "mps"
).eval()
processor = ColIdefics3Processor.from_pretrained(model_name)
# new_chunks = [Chunk(chunk.content, chunk.metadata) for chunk in chunks]
batch_chunks = processor.process_queries([chunk.content for chunk in chunks]).to(device)
query_rep = processor.process_queries([query]).to(device)
multi_vec_representations = model(**batch_chunks)
query_rep = model(**query_rep)
scores = processor.score_multi_vector(query_rep, multi_vec_representations)
for chunk, score in zip(chunks, scores[0]):
chunk.score = score
full_chunks = chunks + chunks_multivector
full_chunks.sort(key=lambda x: x.score, reverse=True)
return full_chunks
async def retrieve_docs(
self,
query: str,