mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
hotfix for document ranking in ColPali (#64)
* hotfix for document ranking in ColPali * add re-ranking clause
This commit is contained in:
parent
2f8fb50797
commit
a19ff3cc5a
@ -3,7 +3,7 @@ from io import BytesIO
|
|||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
|
import torch
|
||||||
from core.models.chunk import Chunk, DocumentChunk
|
from core.models.chunk import Chunk, DocumentChunk
|
||||||
from core.models.documents import (
|
from core.models.documents import (
|
||||||
Document,
|
Document,
|
||||||
@ -14,6 +14,7 @@ from core.models.documents import (
|
|||||||
)
|
)
|
||||||
from ..models.auth import AuthContext
|
from ..models.auth import AuthContext
|
||||||
from ..models.graph import Graph
|
from ..models.graph import Graph
|
||||||
|
from colpali_engine.models import ColIdefics3, ColIdefics3Processor
|
||||||
from core.services.graph_service import GraphService
|
from core.services.graph_service import GraphService
|
||||||
from core.database.base_database import BaseDatabase
|
from core.database.base_database import BaseDatabase
|
||||||
from core.storage.base_storage import BaseStorage
|
from core.storage.base_storage import BaseStorage
|
||||||
@ -108,13 +109,15 @@ class DocumentService:
|
|||||||
return []
|
return []
|
||||||
logger.info(f"Found {len(doc_ids)} authorized documents")
|
logger.info(f"Found {len(doc_ids)} authorized documents")
|
||||||
|
|
||||||
|
|
||||||
|
search_multi = use_colpali and self.colpali_vector_store and query_embedding_multivector is not None
|
||||||
|
should_rerank = should_rerank and (not search_multi) # colpali has a different re-ranking method
|
||||||
|
|
||||||
# Search chunks with vector similarity
|
# Search chunks with vector similarity
|
||||||
chunks = await self.vector_store.query_similar(
|
chunks = await self.vector_store.query_similar(
|
||||||
query_embedding_regular, k=10 * k if should_rerank else k, doc_ids=doc_ids
|
query_embedding_regular, k=10 * k if should_rerank else k, doc_ids=doc_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
search_multi = use_colpali and self.colpali_vector_store and query_embedding_multivector is not None
|
|
||||||
|
|
||||||
chunks_multivector = (
|
chunks_multivector = (
|
||||||
await self.colpali_vector_store.query_similar(
|
await self.colpali_vector_store.query_similar(
|
||||||
query_embedding_multivector, k=k, doc_ids=doc_ids
|
query_embedding_multivector, k=k, doc_ids=doc_ids
|
||||||
@ -123,9 +126,7 @@ class DocumentService:
|
|||||||
|
|
||||||
logger.debug(f"Found {len(chunks)} similar chunks via regular embedding")
|
logger.debug(f"Found {len(chunks)} similar chunks via regular embedding")
|
||||||
if use_colpali:
|
if use_colpali:
|
||||||
logger.debug(
|
logger.debug(f"Found {len(chunks_multivector)} similar chunks via multivector embedding since we are also using colpali")
|
||||||
f"Found {len(chunks_multivector)} similar chunks via multivector embedding since we are also using colpali"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Rerank chunks using the reranker if enabled and available
|
# Rerank chunks using the reranker if enabled and available
|
||||||
if chunks and should_rerank and self.reranker is not None:
|
if chunks and should_rerank and self.reranker is not None:
|
||||||
@ -134,13 +135,51 @@ class DocumentService:
|
|||||||
chunks = chunks[:k]
|
chunks = chunks[:k]
|
||||||
logger.debug(f"Reranked {k*10} chunks and selected the top {k}")
|
logger.debug(f"Reranked {k*10} chunks and selected the top {k}")
|
||||||
|
|
||||||
chunks = chunks_multivector + chunks
|
chunks = await self._combine_multi_and_regular_chunks(query, chunks, chunks_multivector)
|
||||||
|
|
||||||
# Create and return chunk results
|
# Create and return chunk results
|
||||||
results = await self._create_chunk_results(auth, chunks)
|
results = await self._create_chunk_results(auth, chunks)
|
||||||
logger.info(f"Returning {len(results)} chunk results")
|
logger.info(f"Returning {len(results)} chunk results")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
async def _combine_multi_and_regular_chunks(self, query: str, chunks: List[DocumentChunk], chunks_multivector: List[DocumentChunk]):
    """Merge regular-embedding chunks and ColPali multivector chunks into one ranked list.

    Regular chunks are rescored with a ColPali model so their similarity scores
    are on the same scale as the multivector chunks' scores; both lists are then
    concatenated and sorted by score, descending.

    Args:
        query: The user query to score chunks against.
        chunks: Chunks retrieved via the regular embedding index.
        chunks_multivector: Chunks retrieved via the ColPali multivector index.

    Returns:
        A single list of DocumentChunk sorted by descending relevance score.
    """
    # TODO: Note that the chunks only need to be rescored in case they weren't
    # ingested with colpali-enabled as true. In the other case, we know that
    # chunks_multivector can just come ahead of the regular chunks (since we
    # already considered the regular chunks when performing the original
    # similarity search). There is scope for optimization here by filtering
    # for only the chunks which weren't ingested via colpali.
    if len(chunks_multivector) == 0:
        # Nothing to merge — regular scores are already mutually consistent.
        return chunks
    if not chunks:
        # No regular chunks to rescore; ColPali scoring on an empty batch
        # would fail, so return the multivector results directly.
        return sorted(chunks_multivector, key=lambda x: x.score, reverse=True)

    # TODO: this is duct tape, fix it properly later.
    # Lazily load and cache the reranker model/processor on the instance so
    # repeated calls don't pay the (very expensive) from_pretrained cost.
    model = getattr(self, "_colpali_rerank_model", None)
    processor = getattr(self, "_colpali_rerank_processor", None)
    if model is None or processor is None:
        model_name = "vidore/colSmol-256M"
        device = (
            "mps"
            if torch.backends.mps.is_available()
            else "cuda" if torch.cuda.is_available() else "cpu"
        )
        model = ColIdefics3.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=device,  # "cuda:0", # or "mps" if on Apple Silicon
            attn_implementation="eager",  # "flash_attention_2" if is_flash_attn_2_available() else None, # or "eager" if "mps"
        ).eval()
        processor = ColIdefics3Processor.from_pretrained(model_name)
        self._colpali_rerank_model = model
        self._colpali_rerank_processor = processor

    # Recover the device the cached model lives on (cache hits skip the
    # device selection above).
    device = next(model.parameters()).device

    # Pure inference — disable gradient tracking to save time and memory.
    with torch.no_grad():
        batch_chunks = processor.process_queries([chunk.content for chunk in chunks]).to(device)
        query_rep = processor.process_queries([query]).to(device)
        multi_vec_representations = model(**batch_chunks)
        query_rep = model(**query_rep)
        scores = processor.score_multi_vector(query_rep, multi_vec_representations)

    for chunk, score in zip(chunks, scores[0]):
        # Store a plain float (not a tensor element) so downstream sorting
        # and serialization behave predictably.
        chunk.score = float(score)

    full_chunks = chunks + chunks_multivector
    full_chunks.sort(key=lambda x: x.score, reverse=True)
    return full_chunks
|
||||||
|
|
||||||
async def retrieve_docs(
|
async def retrieve_docs(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user