import base64
import logging
from typing import Dict, Any, List, Optional

from fastapi import UploadFile

from core.models.request import IngestTextRequest
from core.models.chunk import Chunk, DocumentChunk
from core.models.documents import (
    Document,
    ChunkResult,
    DocumentContent,
    DocumentResult,
)
from ..models.auth import AuthContext
from core.database.base_database import BaseDatabase
from core.storage.base_storage import BaseStorage
from core.vector_store.base_vector_store import BaseVectorStore
from core.embedding.base_embedding_model import BaseEmbeddingModel
from core.parser.base_parser import BaseParser
from core.completion.base_completion import (
    BaseCompletionModel,
    CompletionRequest,
    CompletionResponse,
)
from core.reranker.base_reranker import BaseReranker
from core.config import get_settings
from core.cache.base_cache import BaseCache
from core.cache.base_cache_factory import BaseCacheFactory

logger = logging.getLogger(__name__)

class DocumentService:
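    """Service layer that orchestrates document ingestion, chunk/document
    retrieval, completion generation, and cache management on top of the
    pluggable database, vector store, storage, parser, embedding, completion,
    and reranker backends supplied at construction time."""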

    def __init__(
        self,
        database: BaseDatabase,
        vector_store: BaseVectorStore,
        storage: BaseStorage,
        parser: BaseParser,
        embedding_model: BaseEmbeddingModel,
        completion_model: BaseCompletionModel,
        cache_factory: BaseCacheFactory,
        reranker: Optional[BaseReranker] = None,
    ):
        self.db = database
        self.vector_store = vector_store
        self.storage = storage
        self.parser = parser
        self.embedding_model = embedding_model
        self.completion_model = completion_model
        self.reranker = reranker
        self.cache_factory = cache_factory

        # Cache-related data structures
        # Maps cache name to active cache object
        self.active_caches: Dict[str, BaseCache] = {}
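
    # Retrieval pipeline: embed the query, restrict the search to documents the
    # caller is authorized to read, run vector similarity search (over-fetching
    # candidates when a reranker will be applied), optionally rerank, and
    # finally hydrate the chunks with document metadata and download URLs.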
    async def retrieve_chunks(
        self,
        query: str,
        auth: AuthContext,
        filters: Optional[Dict[str, Any]] = None,
        k: int = 5,
        min_score: float = 0.0,
        use_reranking: Optional[bool] = None,
    ) -> List[ChunkResult]:
        """Retrieve relevant chunks."""
        settings = get_settings()
        should_rerank = use_reranking if use_reranking is not None else settings.USE_RERANKING

        # Get embedding for query
        query_embedding = await self.embedding_model.embed_for_query(query)
        logger.info("Generated query embedding")

        # Find authorized documents
        doc_ids = await self.db.find_authorized_and_filtered_documents(auth, filters)
        if not doc_ids:
            logger.info("No authorized documents found")
            return []
        logger.info(f"Found {len(doc_ids)} authorized documents")

        # Search chunks with vector similarity; over-fetch 10x candidates when
        # reranking so the reranker has a wider pool to choose from
        chunks = await self.vector_store.query_similar(
            query_embedding, k=10 * k if should_rerank else k, doc_ids=doc_ids
        )
        logger.info(f"Found {len(chunks)} similar chunks")

        # Rerank chunks using the reranker if enabled and available
        if chunks and should_rerank and self.reranker is not None:
            num_candidates = len(chunks)
            chunks = await self.reranker.rerank(query, chunks)
            chunks.sort(key=lambda x: x.score, reverse=True)
            chunks = chunks[:k]
            logger.info(f"Reranked {num_candidates} chunks and selected the top {len(chunks)}")

        # Create and return chunk results
        results = await self._create_chunk_results(auth, chunks)
        logger.info(f"Returning {len(results)} chunk results")
        return results

    async def retrieve_docs(
        self,
        query: str,
        auth: AuthContext,
        filters: Optional[Dict[str, Any]] = None,
        k: int = 5,
        min_score: float = 0.0,
        use_reranking: Optional[bool] = None,
    ) -> List[DocumentResult]:
        """Retrieve relevant documents."""
        # Get chunks first
        chunks = await self.retrieve_chunks(query, auth, filters, k, min_score, use_reranking)
        # Convert to document results
        results = await self._create_document_results(auth, chunks)
        documents = list(results.values())
        logger.info(f"Returning {len(documents)} document results")
        return documents

    async def query(
        self,
        query: str,
        auth: AuthContext,
        filters: Optional[Dict[str, Any]] = None,
        k: int = 20,  # from contextual embedding paper
        min_score: float = 0.0,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        use_reranking: Optional[bool] = None,
    ) -> CompletionResponse:
        """Generate completion using relevant chunks as context."""
        # Get relevant chunks
        chunks = await self.retrieve_chunks(query, auth, filters, k, min_score, use_reranking)
        documents = await self._create_document_results(auth, chunks)

        # Augment each chunk with context from its parent document before
        # handing it to the completion model
        chunk_contents = [chunk.augmented_content(documents[chunk.document_id]) for chunk in chunks]

        # Generate completion
        request = CompletionRequest(
            query=query,
            context_chunks=chunk_contents,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        response = await self.completion_model.complete(request)
        return response
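
    # Ingestion pipeline: verify write permission, create the Document record,
    # split the content into chunks, embed the chunks, then persist the chunk
    # embeddings in the vector store and the document metadata in the database.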
    async def ingest_text(self, request: IngestTextRequest, auth: AuthContext) -> Document:
        """Ingest a text document."""
        if "write" not in auth.permissions:
            logger.error(f"User {auth.entity_id} does not have write permission")
            raise PermissionError("User does not have write permission")

        # 1. Create document record
        doc = Document(
            content_type="text/plain",
            metadata=request.metadata,
            owner={"type": auth.entity_type, "id": auth.entity_id},
            access_control={
                "readers": [auth.entity_id],
                "writers": [auth.entity_id],
                "admins": [auth.entity_id],
            },
        )
        logger.info(f"Created text document record with ID {doc.external_id}")
        doc.system_metadata["content"] = request.content

        # 2. Parse content into chunks
        chunks = await self.parser.split_text(request.content)
        if not chunks:
            raise ValueError("No content chunks extracted from text")
        logger.info(f"Split text into {len(chunks)} chunks")

        # 3. Generate embeddings for chunks
        embeddings = await self.embedding_model.embed_for_ingestion(chunks)
        logger.info(f"Generated {len(embeddings)} embeddings")

        # 4. Create and store chunk objects
        chunk_objects = self._create_chunk_objects(doc.external_id, chunks, embeddings)
        logger.info(f"Created {len(chunk_objects)} chunk objects")

        # 5. Store everything
        await self._store_chunks_and_doc(chunk_objects, doc)
        logger.info(f"Successfully stored text document {doc.external_id}")

        return doc

    async def ingest_file(
        self, file: UploadFile, metadata: Dict[str, Any], auth: AuthContext
    ) -> Document:
        """Ingest a file document."""
        if "write" not in auth.permissions:
            raise PermissionError("User does not have write permission")

        # 1. Read and parse the file into chunks
        file_content = await file.read()
        additional_metadata, chunks = await self.parser.parse_file(
            file_content, file.content_type or "", file.filename
        )

        # 2. Create document record
        doc = Document(
            content_type=file.content_type or "",
            filename=file.filename,
            metadata=metadata,
            owner={"type": auth.entity_type, "id": auth.entity_id},
            access_control={
                "readers": [auth.entity_id],
                "writers": [auth.entity_id],
                "admins": [auth.entity_id],
            },
            additional_metadata=additional_metadata,
        )
        doc.system_metadata["content"] = "\n".join(chunk.content for chunk in chunks)
        logger.info(f"Created file document record with ID {doc.external_id}")

        # 3. Upload the original file to storage
        storage_info = await self.storage.upload_from_base64(
            base64.b64encode(file_content).decode(), doc.external_id, file.content_type
        )
        doc.storage_info = {"bucket": storage_info[0], "key": storage_info[1]}
        logger.info(f"Stored file in bucket `{storage_info[0]}` with key `{storage_info[1]}`")

        if not chunks:
            raise ValueError("No content chunks extracted from file")
        logger.info(f"Parsed file into {len(chunks)} chunks")

        # 4. Generate embeddings for chunks
        embeddings = await self.embedding_model.embed_for_ingestion(chunks)
        logger.info(f"Generated {len(embeddings)} embeddings")

        # 5. Create and store chunk objects
        chunk_objects = self._create_chunk_objects(doc.external_id, chunks, embeddings)
        logger.info(f"Created {len(chunk_objects)} chunk objects")

        # 6. Store everything
        doc.chunk_ids = await self._store_chunks_and_doc(chunk_objects, doc)
        logger.info(f"Successfully stored file document {doc.external_id}")

        return doc

    def _create_chunk_objects(
        self,
        doc_id: str,
        chunks: List[Chunk],
        embeddings: List[List[float]],
    ) -> List[DocumentChunk]:
        """Helper to create chunk objects."""
        return [
            c.to_document_chunk(chunk_number=i, embedding=embedding, document_id=doc_id)
            for i, (embedding, c) in enumerate(zip(embeddings, chunks))
        ]

    async def _store_chunks_and_doc(
        self, chunk_objects: List[DocumentChunk], doc: Document
    ) -> List[str]:
        """Helper to store chunks and document."""
        # Store chunks in vector store
        success, result = await self.vector_store.store_embeddings(chunk_objects)
        if not success:
            raise Exception("Failed to store chunk embeddings")
        logger.debug("Stored chunk embeddings in vector store")
        doc.chunk_ids = result

        # Store document metadata
        if not await self.db.store_document(doc):
            raise Exception("Failed to store document metadata")
        logger.debug("Stored document metadata in database")
        logger.debug(f"Chunk IDs stored: {result}")
        return result

    async def _create_chunk_results(
        self, auth: AuthContext, chunks: List[DocumentChunk]
    ) -> List[ChunkResult]:
        """Create ChunkResult objects with document metadata."""
        results = []
        for chunk in chunks:
            # Get document metadata
            doc = await self.db.get_document(chunk.document_id, auth)
            if not doc:
                logger.warning(f"Document {chunk.document_id} not found")
                continue
            logger.debug(f"Retrieved metadata for document {chunk.document_id}")

            # Generate download URL if needed
            download_url = None
            if doc.storage_info:
                download_url = await self.storage.get_download_url(
                    doc.storage_info["bucket"], doc.storage_info["key"]
                )
                logger.debug(f"Generated download URL for document {chunk.document_id}")

            results.append(
                ChunkResult(
                    content=chunk.content,
                    score=chunk.score,
                    document_id=chunk.document_id,
                    chunk_number=chunk.chunk_number,
                    metadata=doc.metadata,
                    content_type=doc.content_type,
                    filename=doc.filename,
                    download_url=download_url,
                )
            )

        logger.info(f"Created {len(results)} chunk results")
        return results

    async def _create_document_results(
        self, auth: AuthContext, chunks: List[ChunkResult]
    ) -> Dict[str, DocumentResult]:
        """Group chunks by document and create DocumentResult objects."""
        # Group chunks by document and get highest scoring chunk per doc
        doc_chunks: Dict[str, ChunkResult] = {}
        for chunk in chunks:
            if (
                chunk.document_id not in doc_chunks
                or chunk.score > doc_chunks[chunk.document_id].score
            ):
                doc_chunks[chunk.document_id] = chunk
        logger.info(f"Grouped chunks into {len(doc_chunks)} documents")
        logger.info(f"Document chunks: {doc_chunks}")

        results = {}
        for doc_id, chunk in doc_chunks.items():
            # Get document metadata
            doc = await self.db.get_document(doc_id, auth)
            if not doc:
                logger.warning(f"Document {doc_id} not found")
                continue
            logger.info(f"Retrieved metadata for document {doc_id}")

            # Create DocumentContent based on content type
            if doc.content_type == "text/plain":
                content = DocumentContent(type="string", value=chunk.content, filename=None)
                logger.debug(f"Created text content for document {doc_id}")
            else:
                # Generate download URL for file types
                download_url = await self.storage.get_download_url(
                    doc.storage_info["bucket"], doc.storage_info["key"]
                )
                content = DocumentContent(type="url", value=download_url, filename=doc.filename)
                logger.debug(f"Created URL content for document {doc_id}")

            results[doc_id] = DocumentResult(
                score=chunk.score,
                document_id=doc_id,
                metadata=doc.metadata,
                content=content,
                additional_metadata=doc.additional_metadata,
            )

        logger.info(f"Created {len(results)} document results")
        return results
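
    # Cache lifecycle: create_cache builds a model-specific cache over a set of
    # documents, persists its serialized state to storage and its metadata to
    # the database; load_cache restores that state into self.active_caches.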
    async def create_cache(
        self,
        name: str,
        model: str,
        gguf_file: str,
        docs: List[Document | None],
        filters: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Create a new cache with the specified configuration.

        Args:
            name: Name of the cache to create
            model: Name of the model to use
            gguf_file: Name of the GGUF file to use
            docs: Documents to include in the cache
            filters: Optional metadata filters for the documents to include

        Returns:
            Dict with `success` (bool) and `message` (str) keys describing the outcome
        """
        # Create cache metadata
        metadata = {
            "model": model,
            "model_file": gguf_file,
            "filters": filters,
            "docs": [doc.model_dump_json() for doc in docs],
            "storage_info": {
                "bucket": "caches",
                "key": f"{name}_state.pkl",
            },
        }

        # Store metadata in database
        success = await self.db.store_cache_metadata(name, metadata)
        if not success:
            logger.error(f"Failed to store cache metadata for cache {name}")
            return {"success": False, "message": f"Failed to store cache metadata for cache {name}"}

        # Create cache instance and persist its serialized state to storage
        cache = self.cache_factory.create_new_cache(
            name=name, model=model, model_file=gguf_file, filters=filters, docs=docs
        )
        cache_bytes = cache.saveable_state
        base64_cache_bytes = base64.b64encode(cache_bytes).decode()
        bucket, key = await self.storage.upload_from_base64(
            base64_cache_bytes,
            key=metadata["storage_info"]["key"],
            bucket=metadata["storage_info"]["bucket"],
        )
        return {
            "success": True,
            "message": f"Cache created successfully, state stored in bucket `{bucket}` with key `{key}`",
        }

    async def load_cache(self, name: str) -> bool:
        """Load a cache into memory.

        Args:
            name: Name of the cache to load

        Returns:
            bool: Whether the cache exists and was loaded successfully
        """
        try:
            # Get cache metadata from database
            metadata = await self.db.get_cache_metadata(name)
            if not metadata:
                logger.error(f"No metadata found for cache {name}")
                return False

            # Get cache bytes from storage
            cache_bytes = await self.storage.download_file(
                metadata["storage_info"]["bucket"], "caches/" + metadata["storage_info"]["key"]
            )
            cache_bytes = cache_bytes.read()
            cache = self.cache_factory.load_cache_from_bytes(
                name=name, cache_bytes=cache_bytes, metadata=metadata
            )
            self.active_caches[name] = cache
            return True
        except Exception as e:
            logger.error(f"Failed to load cache {name}: {e}")
            return False

    def close(self):
        """Close all resources."""
        # Close any active caches
        self.active_caches.clear()
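

# ---------------------------------------------------------------------------
# Illustrative wiring sketch (comments only, not executed). The variable names
# below are placeholders for whatever concrete BaseDatabase, BaseVectorStore,
# BaseStorage, BaseParser, BaseEmbeddingModel, BaseCompletionModel,
# BaseCacheFactory, and BaseReranker implementations a deployment provides.
#
#   service = DocumentService(
#       database=db,
#       vector_store=vector_store,
#       storage=storage,
#       parser=parser,
#       embedding_model=embedder,
#       completion_model=completion_model,
#       cache_factory=cache_factory,
#       reranker=reranker,  # optional
#   )
#   doc = await service.ingest_text(IngestTextRequest(...), auth)
#   chunks = await service.retrieve_chunks("my query", auth, k=5)
#   answer = await service.query("my question", auth)
# ---------------------------------------------------------------------------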