hotfix for document ranking in ColPali (#64)

* hotfix for document ranking in ColPali

* add re-ranking clause
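
For context on the `_combine_multi_and_regular_chunks` helper added in this diff: it rescores the regular chunks with the same ColPali model so that regular and multi-vector chunks end up on one comparable score scale before sorting. Below is a minimal, illustrative sketch of the ColBERT-style late-interaction ("MaxSim") score that `processor.score_multi_vector` computes; the `maxsim_score` helper and the random tensors are assumptions for demonstration only, not part of the codebase.

```python
import torch


def maxsim_score(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> float:
    """Late-interaction score between one query and one document.

    query_emb: (num_query_tokens, dim); doc_emb: (num_doc_tokens, dim).
    """
    # similarity of every query token against every document token/patch
    sim = query_emb @ doc_emb.T  # (num_query_tokens, num_doc_tokens)
    # keep the best-matching document token per query token, then sum over the query
    return sim.max(dim=1).values.sum().item()


# toy shapes only; real embeddings come from the ColPali model and processor
query_emb = torch.randn(16, 128)   # 16 query tokens, 128-dim embeddings
doc_emb = torch.randn(1024, 128)   # e.g. one page's patch embeddings
print(maxsim_score(query_emb, doc_emb))
```

Because the rescored regular chunks and the colpali-ingested chunks are scored the same way, sorting the merged list by `score` gives a consistent ranking.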
Arnav Agrawal 2025-03-30 00:26:05 -07:00 committed by GitHub
parent 2f8fb50797
commit a19ff3cc5a

@@ -3,7 +3,7 @@ from io import BytesIO
 from typing import Dict, Any, List, Optional
 from fastapi import UploadFile
 from datetime import datetime, UTC
+import torch
 from core.models.chunk import Chunk, DocumentChunk
 from core.models.documents import (
     Document,
@@ -14,6 +14,7 @@ from core.models.documents import (
 )
 from ..models.auth import AuthContext
 from ..models.graph import Graph
+from colpali_engine.models import ColIdefics3, ColIdefics3Processor
 from core.services.graph_service import GraphService
 from core.database.base_database import BaseDatabase
 from core.storage.base_storage import BaseStorage
@@ -108,13 +109,15 @@ class DocumentService:
             return []
         logger.info(f"Found {len(doc_ids)} authorized documents")

+        search_multi = use_colpali and self.colpali_vector_store and query_embedding_multivector is not None
+        should_rerank = should_rerank and (not search_multi)  # colpali has a different re-ranking method
+
         # Search chunks with vector similarity
         chunks = await self.vector_store.query_similar(
             query_embedding_regular, k=10 * k if should_rerank else k, doc_ids=doc_ids
         )

-        search_multi = use_colpali and self.colpali_vector_store and query_embedding_multivector is not None
         chunks_multivector = (
             await self.colpali_vector_store.query_similar(
                 query_embedding_multivector, k=k, doc_ids=doc_ids
@@ -123,9 +126,7 @@ class DocumentService:
         logger.debug(f"Found {len(chunks)} similar chunks via regular embedding")
         if use_colpali:
-            logger.debug(
-                f"Found {len(chunks_multivector)} similar chunks via multivector embedding since we are also using colpali"
-            )
+            logger.debug(f"Found {len(chunks_multivector)} similar chunks via multivector embedding since we are also using colpali")

         # Rerank chunks using the reranker if enabled and available
         if chunks and should_rerank and self.reranker is not None:
@@ -134,13 +135,51 @@ class DocumentService:
             chunks = chunks[:k]
             logger.debug(f"Reranked {k*10} chunks and selected the top {k}")

-        chunks = chunks_multivector + chunks
+        chunks = await self._combine_multi_and_regular_chunks(query, chunks, chunks_multivector)

         # Create and return chunk results
         results = await self._create_chunk_results(auth, chunks)
         logger.info(f"Returning {len(results)} chunk results")
         return results

+    async def _combine_multi_and_regular_chunks(self, query: str, chunks: List[DocumentChunk], chunks_multivector: List[DocumentChunk]):
+        # Use ColPali as a reranker so the regular chunks and the multi-vector chunks get similarity scores on the same scale.
+        # TODO: Chunks only need to be rescored when they weren't ingested with colpali enabled. Otherwise,
+        # chunks_multivector can simply be placed ahead of the regular chunks (the regular chunks were already
+        # considered in the original similarity search). There is scope for optimization here by filtering for
+        # only the chunks that weren't ingested via colpali.
+        if len(chunks_multivector) == 0:
+            return chunks
+
+        # TODO: this is duct tape, fix it properly later
+        model_name = "vidore/colSmol-256M"
+        device = (
+            "mps"
+            if torch.backends.mps.is_available()
+            else "cuda" if torch.cuda.is_available() else "cpu"
+        )
+
+        model = ColIdefics3.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16,
+            device_map=device,  # e.g. "cuda:0", or "mps" on Apple Silicon
+            attn_implementation="eager",  # "flash_attention_2" if is_flash_attn_2_available() else None; "eager" is needed on "mps"
+        ).eval()
+        processor = ColIdefics3Processor.from_pretrained(model_name)
+
+        # Embed the regular chunks and the query with ColPali, then score them with late interaction
+        # so their scores are comparable to those of the multi-vector chunks.
+        batch_chunks = processor.process_queries([chunk.content for chunk in chunks]).to(device)
+        query_rep = processor.process_queries([query]).to(device)
+        multi_vec_representations = model(**batch_chunks)
+        query_rep = model(**query_rep)
+        scores = processor.score_multi_vector(query_rep, multi_vec_representations)
+        for chunk, score in zip(chunks, scores[0]):
+            chunk.score = score
+
+        # Merge both sets and sort by the now-comparable scores.
+        full_chunks = chunks + chunks_multivector
+        full_chunks.sort(key=lambda x: x.score, reverse=True)
+        return full_chunks
+
     async def retrieve_docs(
         self,
         query: str,