diff --git a/core/services/document_service.py b/core/services/document_service.py index b65ba32..0ed63d2 100644 --- a/core/services/document_service.py +++ b/core/services/document_service.py @@ -3,7 +3,7 @@ from io import BytesIO from typing import Dict, Any, List, Optional from fastapi import UploadFile from datetime import datetime, UTC - +import torch from core.models.chunk import Chunk, DocumentChunk from core.models.documents import ( Document, @@ -14,6 +14,7 @@ from core.models.documents import ( ) from ..models.auth import AuthContext from ..models.graph import Graph +from colpali_engine.models import ColIdefics3, ColIdefics3Processor from core.services.graph_service import GraphService from core.database.base_database import BaseDatabase from core.storage.base_storage import BaseStorage @@ -108,13 +109,15 @@ class DocumentService: return [] logger.info(f"Found {len(doc_ids)} authorized documents") + + search_multi = use_colpali and self.colpali_vector_store and query_embedding_multivector is not None + should_rerank = should_rerank and (not search_multi) # colpali has a different re-ranking method + # Search chunks with vector similarity chunks = await self.vector_store.query_similar( query_embedding_regular, k=10 * k if should_rerank else k, doc_ids=doc_ids ) - search_multi = use_colpali and self.colpali_vector_store and query_embedding_multivector is not None - chunks_multivector = ( await self.colpali_vector_store.query_similar( query_embedding_multivector, k=k, doc_ids=doc_ids @@ -123,9 +126,7 @@ class DocumentService: logger.debug(f"Found {len(chunks)} similar chunks via regular embedding") if use_colpali: - logger.debug( - f"Found {len(chunks_multivector)} similar chunks via multivector embedding since we are also using colpali" - ) + logger.debug(f"Found {len(chunks_multivector)} similar chunks via multivector embedding since we are also using colpali") # Rerank chunks using the reranker if enabled and available if chunks and should_rerank and self.reranker is not None: @@ -134,13 +135,51 @@ class DocumentService: chunks = chunks[:k] logger.debug(f"Reranked {k*10} chunks and selected the top {k}") - chunks = chunks_multivector + chunks + chunks = await self._combine_multi_and_regular_chunks(query, chunks, chunks_multivector) # Create and return chunk results results = await self._create_chunk_results(auth, chunks) logger.info(f"Returning {len(results)} chunk results") return results + async def _combine_multi_and_regular_chunks(self, query: str, chunks: List[DocumentChunk], chunks_multivector: List[DocumentChunk]): + # use colpali as a reranker to get the same level of similarity score for both the chunks as well as the multi-vector chunks + # TODO: Note that the chunks only need to be rescored in case they weren't ingested with colpali-enabled as true. + # In the other case, we know that chunks_multivector can just come ahead of the regular chunks (since we already + # considered the regular chunks when performing the original similarity search). there is scope for optimization here + # by filtering for only the chunks which weren't ingested via colpali... + if len(chunks_multivector) == 0: + return chunks + + # TODO: this is duct tape, fix it properly later + + model_name = "vidore/colSmol-256M" + device = ( + "mps" + if torch.backends.mps.is_available() + else "cuda" if torch.cuda.is_available() else "cpu" + ) + + model = ColIdefics3.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map=device, # "cuda:0", # or "mps" if on Apple Silicon + attn_implementation="eager", # "flash_attention_2" if is_flash_attn_2_available() else None, # or "eager" if "mps" + ).eval() + processor = ColIdefics3Processor.from_pretrained(model_name) + + # new_chunks = [Chunk(chunk.content, chunk.metadata) for chunk in chunks] + batch_chunks = processor.process_queries([chunk.content for chunk in chunks]).to(device) + query_rep = processor.process_queries([query]).to(device) + multi_vec_representations = model(**batch_chunks) + query_rep = model(**query_rep) + scores = processor.score_multi_vector(query_rep, multi_vec_representations) + for chunk, score in zip(chunks, scores[0]): + chunk.score = score + full_chunks = chunks + chunks_multivector + full_chunks.sort(key=lambda x: x.score, reverse=True) + return full_chunks + async def retrieve_docs( self, query: str,