mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
hotfix for document ranking in ColPali (#64)
* hotfix for document ranking in ColPali * add re-ranking clause
This commit is contained in:
parent
2f8fb50797
commit
a19ff3cc5a
@ -3,7 +3,7 @@ from io import BytesIO
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import UploadFile
|
||||
from datetime import datetime, UTC
|
||||
|
||||
import torch
|
||||
from core.models.chunk import Chunk, DocumentChunk
|
||||
from core.models.documents import (
|
||||
Document,
|
||||
@ -14,6 +14,7 @@ from core.models.documents import (
|
||||
)
|
||||
from ..models.auth import AuthContext
|
||||
from ..models.graph import Graph
|
||||
from colpali_engine.models import ColIdefics3, ColIdefics3Processor
|
||||
from core.services.graph_service import GraphService
|
||||
from core.database.base_database import BaseDatabase
|
||||
from core.storage.base_storage import BaseStorage
|
||||
@ -108,13 +109,15 @@ class DocumentService:
|
||||
return []
|
||||
logger.info(f"Found {len(doc_ids)} authorized documents")
|
||||
|
||||
|
||||
search_multi = use_colpali and self.colpali_vector_store and query_embedding_multivector is not None
|
||||
should_rerank = should_rerank and (not search_multi) # colpali has a different re-ranking method
|
||||
|
||||
# Search chunks with vector similarity
|
||||
chunks = await self.vector_store.query_similar(
|
||||
query_embedding_regular, k=10 * k if should_rerank else k, doc_ids=doc_ids
|
||||
)
|
||||
|
||||
search_multi = use_colpali and self.colpali_vector_store and query_embedding_multivector is not None
|
||||
|
||||
chunks_multivector = (
|
||||
await self.colpali_vector_store.query_similar(
|
||||
query_embedding_multivector, k=k, doc_ids=doc_ids
|
||||
@ -123,9 +126,7 @@ class DocumentService:
|
||||
|
||||
logger.debug(f"Found {len(chunks)} similar chunks via regular embedding")
|
||||
if use_colpali:
|
||||
logger.debug(
|
||||
f"Found {len(chunks_multivector)} similar chunks via multivector embedding since we are also using colpali"
|
||||
)
|
||||
logger.debug(f"Found {len(chunks_multivector)} similar chunks via multivector embedding since we are also using colpali")
|
||||
|
||||
# Rerank chunks using the reranker if enabled and available
|
||||
if chunks and should_rerank and self.reranker is not None:
|
||||
@ -134,13 +135,51 @@ class DocumentService:
|
||||
chunks = chunks[:k]
|
||||
logger.debug(f"Reranked {k*10} chunks and selected the top {k}")
|
||||
|
||||
chunks = chunks_multivector + chunks
|
||||
chunks = await self._combine_multi_and_regular_chunks(query, chunks, chunks_multivector)
|
||||
|
||||
# Create and return chunk results
|
||||
results = await self._create_chunk_results(auth, chunks)
|
||||
logger.info(f"Returning {len(results)} chunk results")
|
||||
return results
|
||||
|
||||
async def _combine_multi_and_regular_chunks(self, query: str, chunks: List[DocumentChunk], chunks_multivector: List[DocumentChunk]):
|
||||
# use colpali as a reranker to get the same level of similarity score for both the chunks as well as the multi-vector chunks
|
||||
# TODO: Note that the chunks only need to be rescored in case they weren't ingested with colpali-enabled as true.
|
||||
# In the other case, we know that chunks_multivector can just come ahead of the regular chunks (since we already
|
||||
# considered the regular chunks when performing the original similarity search). there is scope for optimization here
|
||||
# by filtering for only the chunks which weren't ingested via colpali...
|
||||
if len(chunks_multivector) == 0:
|
||||
return chunks
|
||||
|
||||
# TODO: this is duct tape, fix it properly later
|
||||
|
||||
model_name = "vidore/colSmol-256M"
|
||||
device = (
|
||||
"mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
model = ColIdefics3.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=device, # "cuda:0", # or "mps" if on Apple Silicon
|
||||
attn_implementation="eager", # "flash_attention_2" if is_flash_attn_2_available() else None, # or "eager" if "mps"
|
||||
).eval()
|
||||
processor = ColIdefics3Processor.from_pretrained(model_name)
|
||||
|
||||
# new_chunks = [Chunk(chunk.content, chunk.metadata) for chunk in chunks]
|
||||
batch_chunks = processor.process_queries([chunk.content for chunk in chunks]).to(device)
|
||||
query_rep = processor.process_queries([query]).to(device)
|
||||
multi_vec_representations = model(**batch_chunks)
|
||||
query_rep = model(**query_rep)
|
||||
scores = processor.score_multi_vector(query_rep, multi_vec_representations)
|
||||
for chunk, score in zip(chunks, scores[0]):
|
||||
chunk.score = score
|
||||
full_chunks = chunks + chunks_multivector
|
||||
full_chunks.sort(key=lambda x: x.score, reverse=True)
|
||||
return full_chunks
|
||||
|
||||
async def retrieve_docs(
|
||||
self,
|
||||
query: str,
|
||||
|
Loading…
x
Reference in New Issue
Block a user