From d124e6aa0d59cd72edb1cd9393780c8184998566 Mon Sep 17 00:00:00 2001 From: Arnav Agrawal <88790414+ArnavAgrawal03@users.noreply.github.com> Date: Wed, 29 Jan 2025 10:19:28 +0530 Subject: [PATCH] Add support for cache-augmented-generation (#30) --- core/api.py | 135 ++++++++++++++ core/cache/base_cache.py | 68 +++++++ core/cache/base_cache_factory.py | 64 +++++++ core/cache/hf_cache.py | 285 +++++++++++++++++++++++++++++ core/cache/llama_cache.py | 196 ++++++++++++++++++++ core/cache/llama_cache_factory.py | 15 ++ core/completion/base_completion.py | 19 +- core/database/base_database.py | 25 +++ core/database/mongo_database.py | 43 +++++ core/database/postgres_database.py | 68 +++++++ core/models/completion.py | 19 ++ core/models/documents.py | 8 + core/services/document_service.py | 97 ++++++++++ core/storage/base_storage.py | 4 +- core/storage/local_storage.py | 6 +- core/storage/s3_storage.py | 8 +- core/tests/integration/test_api.py | 10 +- core/tests/unit/test_cache.py | 184 +++++++++++++++++++ core/tests/unit/test_reranker.py | 12 ++ dockerfile | 11 ++ init.sql | 12 +- quick_setup.py | 17 +- requirements.txt | 1 + sdks/python/databridge/__init__.py | 2 +- sdks/python/databridge/async_.py | 90 +++++++++ sdks/python/databridge/sync.py | 91 +++++++++ sdks/python/pyproject.toml | 2 +- shell.py | 48 +++++ 28 files changed, 1509 insertions(+), 31 deletions(-) create mode 100644 core/cache/base_cache.py create mode 100644 core/cache/base_cache_factory.py create mode 100644 core/cache/hf_cache.py create mode 100644 core/cache/llama_cache.py create mode 100644 core/cache/llama_cache_factory.py create mode 100644 core/models/completion.py create mode 100644 core/tests/unit/test_cache.py diff --git a/core/api.py b/core/api.py index 609f123..ccff64b 100644 --- a/core/api.py +++ b/core/api.py @@ -1,5 +1,7 @@ import json from datetime import datetime, UTC, timedelta +from pathlib import Path +import sys from typing import Any, Dict, List, Optional from fastapi import FastAPI, Form, HTTPException, Depends, Header, UploadFile from fastapi.middleware.cors import CORSMiddleware @@ -30,6 +32,7 @@ from core.embedding.openai_embedding_model import OpenAIEmbeddingModel from core.completion.ollama_completion import OllamaCompletionModel from core.parser.contextual_parser import ContextualParser from core.reranker.flag_reranker import FlagReranker +from core.cache.llama_cache_factory import LlamaCacheFactory import tomli # Initialize FastAPI app @@ -214,6 +217,9 @@ if settings.USE_RERANKING: case _: raise ValueError(f"Unsupported reranker provider: {settings.RERANKER_PROVIDER}") +# Initialize cache factory +cache_factory = LlamaCacheFactory(Path(settings.STORAGE_PATH)) + # Initialize document service with configured components document_service = DocumentService( storage=storage, @@ -223,6 +229,7 @@ document_service = DocumentService( completion_model=completion_model, parser=parser, reranker=reranker, + cache_factory=cache_factory, ) @@ -462,6 +469,134 @@ async def get_recent_usage( ] +# Cache endpoints +@app.post("/cache/create") +async def create_cache( + name: str, + model: str, + gguf_file: str, + filters: Optional[Dict[str, Any]] = None, + docs: Optional[List[str]] = None, + auth: AuthContext = Depends(verify_token), +) -> Dict[str, Any]: + """Create a new cache with specified configuration.""" + try: + async with telemetry.track_operation( + operation_type="create_cache", + user_id=auth.entity_id, + metadata={ + "name": name, + "model": model, + "gguf_file": gguf_file, + "filters": filters, + "docs": 
docs, + }, + ): + filter_docs = set(await document_service.db.get_documents(auth, filters=filters)) + additional_docs = ( + { + await document_service.db.get_document(document_id=doc_id, auth=auth) + for doc_id in docs + } + if docs + else set() + ) + docs_to_add = list(filter_docs.union(additional_docs)) + if not docs_to_add: + raise HTTPException(status_code=400, detail="No documents to add to cache") + response = await document_service.create_cache( + name, model, gguf_file, docs_to_add, filters + ) + return response + except PermissionError as e: + raise HTTPException(status_code=403, detail=str(e)) + + +@app.get("/cache/{name}") +async def get_cache(name: str, auth: AuthContext = Depends(verify_token)) -> Dict[str, Any]: + """Get cache configuration by name.""" + try: + async with telemetry.track_operation( + operation_type="get_cache", + user_id=auth.entity_id, + metadata={"name": name}, + ): + exists = await document_service.load_cache(name) + return {"exists": exists} + except PermissionError as e: + raise HTTPException(status_code=403, detail=str(e)) + + +@app.post("/cache/{name}/update") +async def update_cache(name: str, auth: AuthContext = Depends(verify_token)) -> Dict[str, bool]: + """Update cache with new documents matching its filter.""" + try: + async with telemetry.track_operation( + operation_type="update_cache", + user_id=auth.entity_id, + metadata={"name": name}, + ): + if name not in document_service.active_caches: + exists = await document_service.load_cache(name) + if not exists: + raise HTTPException(status_code=404, detail=f"Cache '{name}' not found") + cache = document_service.active_caches[name] + docs = await document_service.db.get_documents(auth, filters=cache.filters) + docs_to_add = [doc for doc in docs if doc.id not in cache.docs] + return cache.add_docs(docs_to_add) + except PermissionError as e: + raise HTTPException(status_code=403, detail=str(e)) + + +@app.post("/cache/{name}/add_docs") +async def add_docs_to_cache( + name: str, docs: List[str], auth: AuthContext = Depends(verify_token) +) -> Dict[str, bool]: + """Add specific documents to the cache.""" + try: + async with telemetry.track_operation( + operation_type="add_docs_to_cache", + user_id=auth.entity_id, + metadata={"name": name, "docs": docs}, + ): + cache = document_service.active_caches[name] + docs_to_add = [ + await document_service.db.get_document(doc_id, auth) + for doc_id in docs + if doc_id not in cache.docs + ] + return cache.add_docs(docs_to_add) + except PermissionError as e: + raise HTTPException(status_code=403, detail=str(e)) + + +@app.post("/cache/{name}/query") +async def query_cache( + name: str, + query: str, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + auth: AuthContext = Depends(verify_token), +) -> CompletionResponse: + """Query the cache with a prompt.""" + try: + async with telemetry.track_operation( + operation_type="query_cache", + user_id=auth.entity_id, + metadata={ + "name": name, + "query": query, + "max_tokens": max_tokens, + "temperature": temperature, + }, + ): + cache = document_service.active_caches[name] + print(f"Cache state: {cache.state.n_tokens}", file=sys.stderr) + return cache.query(query) # , max_tokens, temperature) + except PermissionError as e: + raise HTTPException(status_code=403, detail=str(e)) + + @app.post("/local/generate_uri", include_in_schema=True) async def generate_local_uri( name: str = Form("admin"), diff --git a/core/cache/base_cache.py b/core/cache/base_cache.py new file mode 100644 index 
0000000..1180c50 --- /dev/null +++ b/core/cache/base_cache.py @@ -0,0 +1,68 @@ +from abc import ABC, abstractmethod +from typing import List, Dict, Any +from core.models.completion import CompletionResponse +from core.models.documents import Document + + +class BaseCache(ABC): + """Base class for cache implementations. + + This class defines the interface for cache implementations that support + document ingestion and cache-augmented querying. + """ + + def __init__( + self, name: str, model: str, gguf_file: str, filters: Dict[str, Any], docs: List[Document] + ): + """Initialize the cache with the given parameters. + + Args: + name: Name of the cache instance + model: Model identifier + gguf_file: Path to the GGUF model file + filters: Filters used to create the cache context + docs: Initial documents to ingest into the cache + """ + self.name = name + self.filters = filters + self.docs = [] # List of document IDs that have been ingested + self._initialize(model, gguf_file, docs) + + @abstractmethod + def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None: + """Internal initialization method to be implemented by subclasses.""" + pass + + @abstractmethod + async def add_docs(self, docs: List[Document]) -> bool: + """Add documents to the cache. + + Args: + docs: List of documents to add to the cache + + Returns: + bool: True if documents were successfully added + """ + pass + + @abstractmethod + async def query(self, query: str) -> CompletionResponse: + """Query the cache for relevant documents and generate a response. + + Args: + query: Query string to search for relevant documents + + Returns: + CompletionResponse: Generated response based on cached context + """ + pass + + @property + @abstractmethod + def saveable_state(self) -> bytes: + """Get the saveable state of the cache as bytes. + + Returns: + bytes: Serialized state that can be used to restore the cache + """ + pass diff --git a/core/cache/base_cache_factory.py b/core/cache/base_cache_factory.py new file mode 100644 index 0000000..707a4e1 --- /dev/null +++ b/core/cache/base_cache_factory.py @@ -0,0 +1,64 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, Any +from .base_cache import BaseCache + + +class BaseCacheFactory(ABC): + """Abstract base factory for creating and loading caches.""" + + def __init__(self, storage_path: Path): + """Initialize the cache factory. + + Args: + storage_path: Base path for storing cache files + """ + self.storage_path = storage_path + self.storage_path.mkdir(parents=True, exist_ok=True) + + @abstractmethod + def create_new_cache( + self, name: str, model: str, model_file: str, **kwargs: Dict[str, Any] + ) -> BaseCache: + """Create a new cache instance. + + Args: + name: Name of the cache + model: Name/type of the model to use + model_file: Path or identifier for the model file + **kwargs: Additional arguments for cache creation + + Returns: + BaseCache: The created cache instance + """ + pass + + @abstractmethod + def load_cache_from_bytes( + self, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs: Dict[str, Any] + ) -> BaseCache: + """Load a cache from its serialized bytes. + + Args: + name: Name of the cache + cache_bytes: Serialized cache data + metadata: Cache metadata including model info + **kwargs: Additional arguments for cache loading + + Returns: + BaseCache: The loaded cache instance + """ + pass + + def get_cache_path(self, name: str) -> Path: + """Get the storage path for a cache. 
+ + Args: + name: Name of the cache + + Returns: + Path: Directory path for the cache + """ + path = self.storage_path / name + path.mkdir(parents=True, exist_ok=True) + return path diff --git a/core/cache/hf_cache.py b/core/cache/hf_cache.py new file mode 100644 index 0000000..29789e9 --- /dev/null +++ b/core/cache/hf_cache.py @@ -0,0 +1,285 @@ +# hugging face cache implementation. + +from core.cache.base_cache import BaseCache +from typing import List, Optional, Union +from pathlib import Path +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers.cache_utils import DynamicCache +from core.models.completion import CompletionRequest, CompletionResponse + + +class HuggingFaceCache(BaseCache): + """Hugging Face Cache implementation for cache-augmented generation""" + + def __init__( + self, + cache_path: Path, + model_name: str = "distilgpt2", + device: str = "cpu", + default_max_new_tokens: int = 100, + use_fp16: bool = False, + ): + """Initialize the HuggingFace cache. + + Args: + cache_path: Path to store cache files + model_name: Name of the HuggingFace model to use + device: Device to run the model on (e.g. "cpu", "cuda", "mps") + default_max_new_tokens: Default maximum number of new tokens to generate + use_fp16: Whether to use FP16 precision + """ + super().__init__() + self.cache_path = cache_path + self.model_name = model_name + self.device = device + self.default_max_new_tokens = default_max_new_tokens + self.use_fp16 = use_fp16 + + # Initialize tokenizer and model + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + + # Configure model loading based on device + model_kwargs = {"low_cpu_mem_usage": True} + + if device == "cpu": + # For CPU, use standard loading + model_kwargs.update({"torch_dtype": torch.float32}) + self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs).to(device) + else: + # For GPU/MPS, use automatic device mapping and optional FP16 + model_kwargs.update( + {"device_map": "auto", "torch_dtype": torch.float16 if use_fp16 else torch.float32} + ) + self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) + + self.kv_cache = None + self.origin_len = None + + def get_kv_cache(self, prompt: str) -> DynamicCache: + """Build KV cache from prompt""" + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) + cache = DynamicCache() + + with torch.no_grad(): + _ = self.model(input_ids=input_ids, past_key_values=cache, use_cache=True) + return cache + + def clean_up_cache(self, cache: DynamicCache, origin_len: int): + """Clean up cache by removing appended tokens""" + for i in range(len(cache.key_cache)): + cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :] + cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :] + + def generate( + self, input_ids: torch.Tensor, past_key_values, max_new_tokens: Optional[int] = None + ) -> torch.Tensor: + """Generate text using the model and cache""" + device = next(self.model.parameters()).device + origin_len = input_ids.shape[-1] + input_ids = input_ids.to(device) + output_ids = input_ids.clone() + next_token = input_ids + + with torch.no_grad(): + for _ in range(max_new_tokens or self.default_max_new_tokens): + out = self.model( + input_ids=next_token, past_key_values=past_key_values, use_cache=True + ) + logits = out.logits[:, -1, :] + token = torch.argmax(logits, dim=-1, keepdim=True) + output_ids = torch.cat([output_ids, token], dim=-1) + past_key_values = out.past_key_values + next_token = 
token.to(device) + + if ( + self.model.config.eos_token_id is not None + and token.item() == self.model.config.eos_token_id + ): + break + + return output_ids[:, origin_len:] + + async def ingest(self, docs: List[str]) -> bool: + """Ingest documents into cache""" + try: + # Create system prompt with documents + system_prompt = f""" +<|system|> +You are an assistant who provides concise factual answers. +<|user|> +Context: +{' '.join(docs)} +Question: +""".strip() + + # Build the cache + input_ids = self.tokenizer(system_prompt, return_tensors="pt").input_ids.to(self.device) + self.kv_cache = DynamicCache() + + with torch.no_grad(): + # First run to get the cache shape + outputs = self.model(input_ids=input_ids, use_cache=True) + # Initialize cache with empty tensors of the right shape + n_layers = len(outputs.past_key_values) + batch_size = input_ids.shape[0] + + # Handle different model architectures + + if hasattr(self.model.config, "num_key_value_heads"): + # Models with grouped query attention (GQA) like Llama + n_kv_heads = self.model.config.num_key_value_heads + head_dim = self.model.config.head_dim + elif hasattr(self.model.config, "n_head"): + # GPT-style models + n_kv_heads = self.model.config.n_head + head_dim = self.model.config.n_embd // self.model.config.n_head + elif hasattr(self.model.config, "num_attention_heads"): + # OPT-style models + n_kv_heads = self.model.config.num_attention_heads + head_dim = ( + self.model.config.hidden_size // self.model.config.num_attention_heads + ) + else: + raise ValueError( + f"Unsupported model architecture: {self.model.config.model_type}" + ) + + seq_len = input_ids.shape[1] + + for i in range(n_layers): + key_shape = (batch_size, n_kv_heads, seq_len, head_dim) + value_shape = key_shape + self.kv_cache.key_cache.append(torch.zeros(key_shape, device=self.device)) + self.kv_cache.value_cache.append(torch.zeros(value_shape, device=self.device)) + + # Now run with the initialized cache + outputs = self.model( + input_ids=input_ids, past_key_values=self.kv_cache, use_cache=True + ) + # Update cache with actual values + self.kv_cache.key_cache = [layer[0] for layer in outputs.past_key_values] + self.kv_cache.value_cache = [layer[1] for layer in outputs.past_key_values] + self.origin_len = self.kv_cache.key_cache[0].shape[-2] + return True + except Exception as e: + print(f"Error ingesting documents: {e}") + return False + + async def update(self, new_doc: str) -> bool: + """Update cache with new document""" + try: + if self.kv_cache is None: + return await self.ingest([new_doc]) + + # Clean up existing cache + self.clean_up_cache(self.kv_cache, self.origin_len) + + # Add new document to cache + input_ids = self.tokenizer(new_doc + "\n", return_tensors="pt").input_ids.to( + self.device + ) + + # First run to get the cache shape + outputs = self.model(input_ids=input_ids, use_cache=True) + # Initialize cache with empty tensors of the right shape + n_layers = len(outputs.past_key_values) + batch_size = input_ids.shape[0] + + # Handle different model architectures + if hasattr(self.model.config, "num_key_value_heads"): + # Models with grouped query attention (GQA) like Llama + n_kv_heads = self.model.config.num_key_value_heads + head_dim = self.model.config.head_dim + elif hasattr(self.model.config, "n_head"): + # GPT-style models + n_kv_heads = self.model.config.n_head + head_dim = self.model.config.n_embd // self.model.config.n_head + elif hasattr(self.model.config, "num_attention_heads"): + # OPT-style models + n_kv_heads = 
self.model.config.num_attention_heads + head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads + else: + raise ValueError(f"Unsupported model architecture: {self.model.config.model_type}") + + seq_len = input_ids.shape[1] + + # Create a new cache for the update + new_cache = DynamicCache() + for i in range(n_layers): + key_shape = (batch_size, n_kv_heads, seq_len, head_dim) + value_shape = key_shape + new_cache.key_cache.append(torch.zeros(key_shape, device=self.device)) + new_cache.value_cache.append(torch.zeros(value_shape, device=self.device)) + + # Run with the initialized cache + outputs = self.model(input_ids=input_ids, past_key_values=new_cache, use_cache=True) + # Update cache with actual values + self.kv_cache.key_cache = [layer[0] for layer in outputs.past_key_values] + self.kv_cache.value_cache = [layer[1] for layer in outputs.past_key_values] + return True + except Exception as e: + print(f"Error updating cache: {e}") + return False + + async def complete(self, request: CompletionRequest) -> CompletionResponse: + """Generate completion using cache-augmented generation""" + try: + if self.kv_cache is None: + raise ValueError("Cache not initialized. Please ingest documents first.") + + # Clean up cache + self.clean_up_cache(self.kv_cache, self.origin_len) + + # Generate completion + input_ids = self.tokenizer(request.query + "\n", return_tensors="pt").input_ids.to( + self.device + ) + gen_ids = self.generate(input_ids, self.kv_cache, max_new_tokens=request.max_tokens) + completion = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True) + + # Calculate token usage + usage = { + "prompt_tokens": len(input_ids[0]), + "completion_tokens": len(gen_ids[0]), + "total_tokens": len(input_ids[0]) + len(gen_ids[0]), + } + + return CompletionResponse(completion=completion, usage=usage) + except Exception as e: + print(f"Error generating completion: {e}") + return CompletionResponse( + completion=f"Error: {str(e)}", + usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + ) + + def save_cache(self) -> Path: + """Save the KV cache to disk""" + if self.kv_cache is None: + raise ValueError("No cache to save") + + cache_dir = self.cache_path / "kv_cache" + cache_dir.mkdir(parents=True, exist_ok=True) + + # Save key and value caches + cache_data = { + "key_cache": self.kv_cache.key_cache, + "value_cache": self.kv_cache.value_cache, + "origin_len": self.origin_len, + } + cache_path = cache_dir / "cache.pt" + torch.save(cache_data, cache_path) + return cache_path + + def load_cache(self, cache_path: Union[str, Path]) -> None: + """Load KV cache from disk""" + cache_path = Path(cache_path) + if not cache_path.exists(): + raise FileNotFoundError(f"Cache file not found at {cache_path}") + + cache_data = torch.load(cache_path, map_location=self.device) + + self.kv_cache = DynamicCache() + self.kv_cache.key_cache = cache_data["key_cache"] + self.kv_cache.value_cache = cache_data["value_cache"] + self.origin_len = cache_data["origin_len"] diff --git a/core/cache/llama_cache.py b/core/cache/llama_cache.py new file mode 100644 index 0000000..243aa1c --- /dev/null +++ b/core/cache/llama_cache.py @@ -0,0 +1,196 @@ +import json +import pickle +import logging +from core.cache.base_cache import BaseCache +from typing import Dict, Any, List +from core.models.completion import CompletionResponse +from core.models.documents import Document +from llama_cpp import Llama + +logger = logging.getLogger(__name__) + +INITIAL_SYSTEM_PROMPT = """<|im_start|>system +You are 
a helpful AI assistant with access to provided documents. Your role is to: +1. Answer questions accurately based on the documents provided +2. Stay focused on the document content and avoid speculation +3. Admit when you don't have enough information to answer +4. Be clear and concise in your responses +5. Use direct quotes from documents when relevant + +Provided documents: {documents} +<|im_end|> +""".strip() + +ADD_DOC_SYSTEM_PROMPT = """<|im_start|>system +I'm adding some additional documents for your reference: +{documents} + +Please incorporate this new information along with what you already know from previous documents while maintaining the same guidelines for responses. +<|im_end|> +""".strip() + +QUERY_PROMPT = """<|im_start|>user +{query} +<|im_end|> +<|im_start|>assistant +""".strip() + + +class LlamaCache(BaseCache): + def __init__( + self, + name: str, + model: str, + gguf_file: str, + filters: Dict[str, Any], + docs: List[Document], + **kwargs, + ): + logger.info(f"Initializing LlamaCache with name={name}, model={model}") + # cache related + self.name = name + self.model = model + self.filters = filters + self.docs = docs + + # llama specific + self.gguf_file = gguf_file + self.n_gpu_layers = kwargs.get("n_gpu_layers", -1) + logger.info(f"Using {self.n_gpu_layers} GPU layers") + + # late init (when we call _initialize) + self.llama = None + self.state = None + self.cached_tokens = 0 + + self._initialize(model, gguf_file, docs) + logger.info("LlamaCache initialization complete") + + def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None: + logger.info(f"Loading Llama model from {model} with file {gguf_file}") + try: + # Set a reasonable default context size (32K tokens) + default_ctx_size = 32768 + + self.llama = Llama.from_pretrained( + repo_id=model, + filename=gguf_file, + n_gpu_layers=self.n_gpu_layers, + n_ctx=default_ctx_size, + verbose=False, # Enable verbose mode for better error reporting + ) + logger.info("Model loaded successfully") + + # Format and tokenize system prompt + documents = "\n".join(doc.system_metadata.get("content", "") for doc in docs) + system_prompt = INITIAL_SYSTEM_PROMPT.format(documents=documents) + logger.info(f"Built system prompt: {system_prompt[:200]}...") + + try: + tokens = self.llama.tokenize(system_prompt.encode()) + logger.info(f"System prompt tokenized to {len(tokens)} tokens") + + # Process tokens to build KV cache + logger.info("Evaluating system prompt") + self.llama.eval(tokens) + logger.info("Saving initial KV cache state") + self.state = self.llama.save_state() + self.cached_tokens = len(tokens) + logger.info(f"Initial KV cache built with {self.cached_tokens} tokens") + except Exception as e: + logger.error(f"Error during prompt processing: {str(e)}") + raise ValueError(f"Failed to process system prompt: {str(e)}") + + except Exception as e: + logger.error(f"Failed to initialize Llama model: {str(e)}") + raise ValueError(f"Failed to initialize Llama model: {str(e)}") + + def add_docs(self, docs: List[Document]) -> bool: + logger.info(f"Adding {len(docs)} new documents to cache") + documents = "\n".join(doc.system_metadata.get("content", "") for doc in docs) + system_prompt = ADD_DOC_SYSTEM_PROMPT.format(documents=documents) + + # Tokenize and process + new_tokens = self.llama.tokenize(system_prompt.encode()) + self.llama.eval(new_tokens) + self.state = self.llama.save_state() + self.cached_tokens += len(new_tokens) + logger.info(f"Added {len(new_tokens)} tokens, total: {self.cached_tokens}") + return True 
+ + def query(self, query: str) -> CompletionResponse: + # Format query with proper chat template + formatted_query = QUERY_PROMPT.format(query=query) + logger.info(f"Processing query: {formatted_query}") + + # Reset and load cached state + self.llama.reset() + self.llama.load_state(self.state) + logger.info(f"Loaded state with {self.state.n_tokens} tokens") + # print(f"Loaded state with {self.state.n_tokens} tokens", file=sys.stderr) + # Tokenize and process query + query_tokens = self.llama.tokenize(formatted_query.encode()) + self.llama.eval(query_tokens) + logger.info(f"Evaluated query tokens: {query_tokens}") + # print(f"Evaluated query tokens: {query_tokens}", file=sys.stderr) + + # Generate response + output_tokens = [] + for token in self.llama.generate(tokens=[], reset=False): + output_tokens.append(token) + # Stop generation when EOT token is encountered + if token == self.llama.token_eos(): + break + + # Decode and return + completion = self.llama.detokenize(output_tokens).decode() + logger.info(f"Generated completion: {completion}") + + return CompletionResponse( + completion=completion, + usage={"prompt_tokens": self.cached_tokens, "completion_tokens": len(output_tokens)}, + ) + + @property + def saveable_state(self) -> bytes: + logger.info("Serializing cache state") + state_bytes = pickle.dumps(self.state) + logger.info(f"Serialized state size: {len(state_bytes)} bytes") + return state_bytes + + @classmethod + def from_bytes( + cls, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs + ) -> "LlamaCache": + """Load a cache from its serialized state. + + Args: + name: Name of the cache + cache_bytes: Pickled state bytes + metadata: Cache metadata including model info + **kwargs: Additional arguments + + Returns: + LlamaCache: Loaded cache instance + """ + logger.info(f"Loading cache from bytes with name={name}") + logger.info(f"Cache metadata: {metadata}") + # Create new instance with metadata + # logger.info(f"Docs: {metadata['docs']}") + docs = [json.loads(doc) for doc in metadata["docs"]] + # time.sleep(10) + cache = cls( + name=name, + model=metadata["model"], + gguf_file=metadata["model_file"], + filters=metadata["filters"], + docs=[Document(**doc) for doc in docs], + ) + + # Load the saved state + logger.info(f"Loading saved KV cache state of size {len(cache_bytes)} bytes") + cache.state = pickle.loads(cache_bytes) + cache.llama.load_state(cache.state) + logger.info("Cache successfully loaded from bytes") + + return cache diff --git a/core/cache/llama_cache_factory.py b/core/cache/llama_cache_factory.py new file mode 100644 index 0000000..db401d1 --- /dev/null +++ b/core/cache/llama_cache_factory.py @@ -0,0 +1,15 @@ +from core.cache.base_cache_factory import BaseCacheFactory +from core.cache.llama_cache import LlamaCache +from typing import Dict, Any + + +class LlamaCacheFactory(BaseCacheFactory): + def create_new_cache( + self, name: str, model: str, model_file: str, **kwargs: Dict[str, Any] + ) -> LlamaCache: + return LlamaCache(name, model, model_file, **kwargs) + + def load_cache_from_bytes( + self, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs: Dict[str, Any] + ) -> LlamaCache: + return LlamaCache.from_bytes(name, cache_bytes, metadata, **kwargs) diff --git a/core/completion/base_completion.py b/core/completion/base_completion.py index bf0ed80..6c2872a 100644 --- a/core/completion/base_completion.py +++ b/core/completion/base_completion.py @@ -1,22 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Dict 
-from pydantic import BaseModel - - -class CompletionResponse(BaseModel): - """Response from completion generation""" - - completion: str - usage: Dict[str, int] - - -class CompletionRequest(BaseModel): - """Request for completion generation""" - - query: str - context_chunks: List[str] - max_tokens: Optional[int] = 1000 - temperature: Optional[float] = 0.7 +from core.models.completion import CompletionRequest, CompletionResponse class BaseCompletionModel(ABC): diff --git a/core/database/base_database.py b/core/database/base_database.py index c48dc6d..dbeb96f 100644 --- a/core/database/base_database.py +++ b/core/database/base_database.py @@ -72,3 +72,28 @@ class BaseDatabase(ABC): Returns: True if user has required access, False otherwise """ pass + + @abstractmethod + async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool: + """Store metadata for a cache. + + Args: + name: Name of the cache + metadata: Cache metadata including model info and storage location + + Returns: + bool: Whether the operation was successful + """ + pass + + @abstractmethod + async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]: + """Get metadata for a cache. + + Args: + name: Name of the cache + + Returns: + Optional[Dict[str, Any]]: Cache metadata if found, None otherwise + """ + pass diff --git a/core/database/mongo_database.py b/core/database/mongo_database.py index 7d07b6a..3933b3c 100644 --- a/core/database/mongo_database.py +++ b/core/database/mongo_database.py @@ -26,6 +26,7 @@ class MongoDatabase(BaseDatabase): self.client = AsyncIOMotorClient(uri) self.db = self.client[db_name] self.collection = self.db[collection_name] + self.caches = self.db["caches"] # Collection for cache metadata async def initialize(self): """Initialize database indexes.""" @@ -217,3 +218,45 @@ class MongoDatabase(BaseDatabase): for key, value in filters.items(): filter_dict[f"metadata.{key}"] = value return filter_dict + + async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool: + """Store metadata for a cache in MongoDB. + + Args: + name: Name of the cache + metadata: Cache metadata including model info and storage location + + Returns: + bool: Whether the operation was successful + """ + try: + # Add timestamp and ensure name is included + doc = { + "name": name, + "metadata": metadata, + "created_at": datetime.now(UTC), + "updated_at": datetime.now(UTC), + } + + # Upsert the document + result = await self.caches.update_one({"name": name}, {"$set": doc}, upsert=True) + return bool(result.modified_count or result.upserted_id) + except Exception as e: + logger.error(f"Failed to store cache metadata: {e}") + return False + + async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]: + """Get metadata for a cache from MongoDB. 
+ + Args: + name: Name of the cache + + Returns: + Optional[Dict[str, Any]]: Cache metadata if found, None otherwise + """ + try: + doc = await self.caches.find_one({"name": name}) + return doc["metadata"] if doc else None + except Exception as e: + logger.error(f"Failed to get cache metadata: {e}") + return None diff --git a/core/database/postgres_database.py b/core/database/postgres_database.py index f20b740..04a4076 100644 --- a/core/database/postgres_database.py +++ b/core/database/postgres_database.py @@ -1,3 +1,4 @@ +import json from typing import List, Optional, Dict, Any from datetime import datetime, UTC import logging @@ -65,8 +66,24 @@ class PostgresDatabase(BaseDatabase): try: async with self.engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) + + # Create caches table if it doesn't exist + await conn.execute( + text( + """ + CREATE TABLE IF NOT EXISTS caches ( + name TEXT PRIMARY KEY, + metadata JSONB NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + ) + logger.info("PostgreSQL tables and indexes created successfully") return True + except Exception as e: logger.error(f"Error creating PostgreSQL tables and indexes: {str(e)}") return False @@ -324,3 +341,54 @@ class PostgresDatabase(BaseDatabase): filter_conditions.append(f"doc_metadata->>'{key}' = '{value}'") return " AND ".join(filter_conditions) + + async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool: + """Store metadata for a cache in PostgreSQL. + + Args: + name: Name of the cache + metadata: Cache metadata including model info and storage location + + Returns: + bool: Whether the operation was successful + """ + try: + async with self.async_session() as session: + await session.execute( + text( + """ + INSERT INTO caches (name, metadata, updated_at) + VALUES (:name, :metadata, CURRENT_TIMESTAMP) + ON CONFLICT (name) + DO UPDATE SET + metadata = :metadata, + updated_at = CURRENT_TIMESTAMP + """ + ), + {"name": name, "metadata": json.dumps(metadata)}, + ) + await session.commit() + return True + except Exception as e: + logger.error(f"Failed to store cache metadata: {e}") + return False + + async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]: + """Get metadata for a cache from PostgreSQL. 
+ + Args: + name: Name of the cache + + Returns: + Optional[Dict[str, Any]]: Cache metadata if found, None otherwise + """ + try: + async with self.async_session() as session: + result = await session.execute( + text("SELECT metadata FROM caches WHERE name = :name"), {"name": name} + ) + row = result.first() + return row[0] if row else None + except Exception as e: + logger.error(f"Failed to get cache metadata: {e}") + return None diff --git a/core/models/completion.py b/core/models/completion.py new file mode 100644 index 0000000..167deae --- /dev/null +++ b/core/models/completion.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel +from typing import Dict, List, Optional + + +class CompletionResponse(BaseModel): + """Response from completion generation""" + + completion: str + usage: Dict[str, int] + finish_reason: Optional[str] = None + + +class CompletionRequest(BaseModel): + """Request for completion generation""" + + query: str + context_chunks: List[str] + max_tokens: Optional[int] = 1000 + temperature: Optional[float] = 0.7 diff --git a/core/models/documents.py b/core/models/documents.py index bdc3310..39d68f8 100644 --- a/core/models/documents.py +++ b/core/models/documents.py @@ -40,6 +40,14 @@ class Document(BaseModel): ) chunk_ids: List[str] = Field(default_factory=list) + def __hash__(self): + return hash(self.external_id) + + def __eq__(self, other): + if not isinstance(other, Document): + return False + return self.external_id == other.external_id + class DocumentContent(BaseModel): """Represents either a URL or content string""" diff --git a/core/services/document_service.py b/core/services/document_service.py index 14d9316..7cfe2bf 100644 --- a/core/services/document_service.py +++ b/core/services/document_service.py @@ -21,6 +21,8 @@ from core.completion.base_completion import CompletionRequest, CompletionRespons import logging from core.reranker.base_reranker import BaseReranker from core.config import get_settings +from core.cache.base_cache import BaseCache +from core.cache.base_cache_factory import BaseCacheFactory logger = logging.getLogger(__name__) @@ -34,6 +36,7 @@ class DocumentService: parser: BaseParser, embedding_model: BaseEmbeddingModel, completion_model: BaseCompletionModel, + cache_factory: BaseCacheFactory, reranker: Optional[BaseReranker] = None, ): self.db = database @@ -43,6 +46,11 @@ class DocumentService: self.embedding_model = embedding_model self.completion_model = completion_model self.reranker = reranker + self.cache_factory = cache_factory + + # Cache-related data structures + # Maps cache name to active cache object + self.active_caches: Dict[str, BaseCache] = {} async def retrieve_chunks( self, @@ -151,6 +159,7 @@ class DocumentService: }, ) logger.info(f"Created text document record with ID {doc.external_id}") + doc.system_metadata["content"] = request.content # 2. 
Parse content into chunks chunks = await self.parser.split_text(request.content) @@ -196,6 +205,7 @@ class DocumentService: }, additional_metadata=additional_metadata, ) + doc.system_metadata["content"] = "\n".join(chunk.content for chunk in chunks) logger.info(f"Created file document record with ID {doc.external_id}") storage_info = await self.storage.upload_from_base64( @@ -333,3 +343,90 @@ class DocumentService: logger.info(f"Created {len(results)} document results") return results + + async def create_cache( + self, + name: str, + model: str, + gguf_file: str, + docs: List[Document | None], + filters: Optional[Dict[str, Any]] = None, + ) -> Dict[str, str]: + """Create a new cache with specified configuration. + + Args: + name: Name of the cache to create + model: Name of the model to use + gguf_file: Name of the GGUF file to use + filters: Optional metadata filters for documents to include + docs: Optional list of specific document IDs to include + """ + # Create cache metadata + metadata = { + "model": model, + "model_file": gguf_file, + "filters": filters, + "docs": [doc.model_dump_json() for doc in docs], + "storage_info": { + "bucket": "caches", + "key": f"{name}_state.pkl", + }, + } + + # Store metadata in database + success = await self.db.store_cache_metadata(name, metadata) + if not success: + logger.error(f"Failed to store cache metadata for cache {name}") + return {"success": False, "message": f"Failed to store cache metadata for cache {name}"} + + # Create cache instance + cache = self.cache_factory.create_new_cache( + name=name, model=model, model_file=gguf_file, filters=filters, docs=docs + ) + cache_bytes = cache.saveable_state + base64_cache_bytes = base64.b64encode(cache_bytes).decode() + bucket, key = await self.storage.upload_from_base64( + base64_cache_bytes, + key=metadata["storage_info"]["key"], + bucket=metadata["storage_info"]["bucket"], + ) + return { + "success": True, + "message": f"Cache created successfully, state stored in bucket `{bucket}` with key `{key}`", + } + + async def load_cache(self, name: str) -> bool: + """Load a cache into memory. 
+ + Args: + name: Name of the cache to load + + Returns: + bool: Whether the cache exists and was loaded successfully + """ + try: + # Get cache metadata from database + metadata = await self.db.get_cache_metadata(name) + if not metadata: + logger.error(f"No metadata found for cache {name}") + return False + + # Get cache bytes from storage + cache_bytes = await self.storage.download_file( + metadata["storage_info"]["bucket"], "caches/" + metadata["storage_info"]["key"] + ) + cache_bytes = cache_bytes.read() + cache = self.cache_factory.load_cache_from_bytes( + name=name, cache_bytes=cache_bytes, metadata=metadata + ) + self.active_caches[name] = cache + return {"success": True, "message": "Cache loaded successfully"} + except Exception as e: + logger.error(f"Failed to load cache {name}: {e}") + # raise e + return {"success": False, "message": f"Failed to load cache {name}: {e}"} + + def close(self): + """Close all resources.""" + # Close any active caches + self.active_caches.clear() diff --git a/core/storage/base_storage.py b/core/storage/base_storage.py index b76ee2d..ad9ee1f 100644 --- a/core/storage/base_storage.py +++ b/core/storage/base_storage.py @@ -7,7 +7,7 @@ class BaseStorage(ABC): @abstractmethod async def upload_from_base64( - self, content: str, key: str, content_type: Optional[str] = None + self, content: str, key: str, content_type: Optional[str] = None, bucket: str = "" ) -> Tuple[str, str]: """ Upload base64 encoded content. @@ -16,7 +16,7 @@ class BaseStorage(ABC): content: Base64 encoded content key: Storage key/path content_type: Optional MIME type - + bucket: Optional bucket/folder name Returns: Tuple[str, str]: (bucket/container name, storage key) """ diff --git a/core/storage/local_storage.py b/core/storage/local_storage.py index 48db6b5..6ea7e29 100644 --- a/core/storage/local_storage.py +++ b/core/storage/local_storage.py @@ -19,16 +19,20 @@ class LocalStorage(BaseStorage): return open(file_path, "rb") async def upload_from_base64( - self, base64_content: str, key: str, content_type: Optional[str] = None + self, content: str, key: str, content_type: Optional[str] = None, bucket: str = "" ) -> Tuple[str, str]: + base64_content = content """Upload base64 encoded content to local storage.""" # Decode base64 content file_content = base64.b64decode(base64_content) + key = f"{bucket}/{key}" if bucket else key # Create file path file_path = self.storage_path / key # Write content to file + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.unlink(missing_ok=True) with open(file_path, "wb") as f: f.write(file_content) diff --git a/core/storage/s3_storage.py b/core/storage/s3_storage.py index e0e2790..d193325 100644 --- a/core/storage/s3_storage.py +++ b/core/storage/s3_storage.py @@ -37,6 +37,7 @@ class S3Storage(BaseStorage): file: Union[str, bytes, BinaryIO], key: str, content_type: Optional[str] = None, + bucket: str = "", ) -> Tuple[str, str]: """Upload a file to S3.""" try: @@ -70,15 +71,18 @@ class S3Storage(BaseStorage): raise async def upload_from_base64( - self, content: str, key: str, content_type: Optional[str] = None + self, content: str, key: str, content_type: Optional[str] = None, bucket: str = "" ) -> Tuple[str, str]: """Upload base64 encoded content to S3.""" + key = f"{bucket}/{key}" if bucket else key try: decoded_content = base64.b64decode(content) extension = detect_file_type(content) key = f"{key}{extension}" - return await self.upload_file(file=decoded_content, key=key, content_type=content_type) + return await self.upload_file( + 
file=decoded_content, key=key, content_type=content_type, bucket=bucket + ) except Exception as e: logger.error(f"Error uploading base64 content to S3: {e}") diff --git a/core/tests/integration/test_api.py b/core/tests/integration/test_api.py index 6a504f4..10ba8e2 100644 --- a/core/tests/integration/test_api.py +++ b/core/tests/integration/test_api.py @@ -238,6 +238,8 @@ async def test_ingest_oversized_content(client: AsyncClient): @pytest.mark.asyncio async def test_auth_missing_header(client: AsyncClient): """Test authentication with missing auth header""" + if get_settings().dev_mode: + pytest.skip("Auth tests skipped in dev mode") response = await client.post("/ingest/text") assert response.status_code == 401 @@ -245,6 +247,8 @@ async def test_auth_missing_header(client: AsyncClient): @pytest.mark.asyncio async def test_auth_invalid_token(client: AsyncClient): """Test authentication with invalid token""" + if get_settings().dev_mode: + pytest.skip("Auth tests skipped in dev mode") headers = {"Authorization": "Bearer invalid_token"} response = await client.post("/ingest/file", headers=headers) assert response.status_code == 401 @@ -253,6 +257,8 @@ async def test_auth_invalid_token(client: AsyncClient): @pytest.mark.asyncio async def test_auth_expired_token(client: AsyncClient): """Test authentication with expired token""" + if get_settings().dev_mode: + pytest.skip("Auth tests skipped in dev mode") headers = create_auth_header(expired=True) response = await client.post("/ingest/text", headers=headers) assert response.status_code == 401 @@ -261,6 +267,8 @@ async def test_auth_expired_token(client: AsyncClient): @pytest.mark.asyncio async def test_auth_insufficient_permissions(client: AsyncClient): """Test authentication with insufficient permissions""" + if get_settings().dev_mode: + pytest.skip("Auth tests skipped in dev mode") headers = create_auth_header(permissions=["read"]) response = await client.post( "/ingest/text", @@ -335,7 +343,7 @@ async def test_retrieve_chunks(client: AsyncClient): assert response.status_code == 200 results = list(response.json()) assert len(results) > 0 - assert results[0]["score"] > 0.5 + assert (not get_settings().USE_RERANKING) or results[0]["score"] > 0.5 assert any(upload_string == result["content"] for result in results) diff --git a/core/tests/unit/test_cache.py b/core/tests/unit/test_cache.py new file mode 100644 index 0000000..eea11fc --- /dev/null +++ b/core/tests/unit/test_cache.py @@ -0,0 +1,184 @@ +import pytest +from core.models.documents import Document +from core.cache.llama_cache import LlamaCache +from core.models.completion import CompletionResponse + +# TEST_MODEL = "QuantFactory/Llama3.2-3B-Enigma-GGUF" +TEST_MODEL = "QuantFactory/Dolphin3.0-Llama3.2-1B-GGUF" +# TEST_MODEL = "unsloth/DeepSeek-R1-Distill-Qwen-1.5B-GGUF" +TEST_GGUF_FILE = "*Q4_K_S.gguf" +# TEST_GGUF_FILE = "*Q4_K_M.gguf" + + +def get_test_document(): + """Load the example.txt file as a test document.""" + # test_file = Path(__file__).parent.parent / "assets" / "example.txt" + # with open(test_file, "r") as f: + # content = f.read() + + content = """ + In darkest hours of coding's fierce domain, + Where bugs lurk deep in shadows, hard to find, + Each error message brings fresh waves of pain, + As stack traces drive madness through the mind. + + Through endless loops of print statements we wade, + Debug flags raised like torches in the night, + While segfaults mock each careful plan we made, + And race conditions laugh at our plight. 
+ + O rubber duck, my silent debugging friend, + Your plastic gaze holds wisdom yet untold, + As line by line we trace paths without end, + Seeking that elusive bug of gold. + + Yet hope remains while coffee still flows strong, + Through debugging hell, we'll debug right from wrong. + """.strip() + + return Document( + external_id="alice_ch1", + owner={"id": "test_user", "name": "Test User"}, + content_type="text/plain", + system_metadata={ + "content": content, + "title": "Alice in Wonderland - Chapter 1", + "source": "test_document", + }, + ) + + +@pytest.fixture +def llama_cache(): + """Create a LlamaCache instance with the test document.""" + doc = get_test_document() + cache = LlamaCache( + name="test_cache", model=TEST_MODEL, gguf_file=TEST_GGUF_FILE, filters={}, docs=[doc] + ) + return cache + + +def test_basic_rag_capabilities(llama_cache): + """Test that the cache can answer basic questions about the document content.""" + # Test question about whether ingestion is actually happening + response = llama_cache.query( + "Summarize the content of the document. Please respond in a single sentence. Summary: " + ) + assert isinstance(response, CompletionResponse) + # assert "alice" in response.completion.lower() + + # # Test question about a specific detail + # response = llama_cache.query( + # "What did Alice see the White Rabbit do with its watch? Please respond in a single sentence. Answer: " + # ) + # assert isinstance(response, CompletionResponse) + # # assert "waistcoat-pocket" in response.completion.lower() or "looked at it" in response.completion.lower() + + # # Test question about character description + # response = llama_cache.query( + # "How did Alice's size change during the story? Please respond in a single sentence. Answer: " + # ) + # assert isinstance(response, CompletionResponse) + # # assert any(phrase in response.completion.lower() for phrase in ["grew larger", "grew smaller", "nine feet", "telescope"]) + + # # Test question about plot elements + # response = llama_cache.query( + # "What was written on the bottle Alice found? Please respond in a single sentence. Answer: " + # ) + # assert isinstance(response, CompletionResponse) + # # assert "drink me" in response.completion.lower() + + +# def test_cache_memory_persistence(llama_cache): +# """Test that the cache maintains context across multiple queries.""" + +# # First query to establish context +# llama_cache.query( +# "What was Alice doing before she saw the White Rabbit? Please respond in a single sentence. Answer: " +# ) + +# # Follow-up query that requires remembering previous context +# response = llama_cache.query( +# "What book was her sister reading? Please respond in a single sentence. Answer: " +# ) +# assert isinstance(response, CompletionResponse) +# # assert "no pictures" in response.completion.lower() or "conversations" in response.completion.lower() + + +def test_adding_new_documents(llama_cache): + """Test that the cache can incorporate new documents into its knowledge.""" + + # Create a new document with additional content + new_doc = Document( + external_id="alice_ch2", + owner={"id": "test_user", "name": "Test User"}, + content_type="text/plain", + system_metadata={ + "content": "Alice found herself in a pool of tears. 
She met a Mouse swimming in the pool.", + "title": "Alice in Wonderland - Additional Content", + "source": "test_document", + }, + ) + + # Add the new document + success = llama_cache.add_docs([new_doc]) + assert success + + # Query about the new content + response = llama_cache.query( + "What did Alice find in the pool of tears? Please respond in a single sentence. Answer: " + ) + assert isinstance(response, CompletionResponse) + assert "mouse" in response.completion.lower() + + +def test_cache_state_persistence(): + """Test that the cache state can be saved and loaded.""" + + # Create initial cache + doc = get_test_document() + original_cache = LlamaCache( + name="test_cache", model=TEST_MODEL, gguf_file=TEST_GGUF_FILE, filters={}, docs=[doc] + ) + + # Get the state + state_bytes = original_cache.saveable_state + # Save state bytes to temporary file + import tempfile + import os + import pickle + + with tempfile.TemporaryDirectory() as temp_dir: + cache_file = os.path.join(temp_dir, "cache.pkl") + + # Save to file + with open(cache_file, "wb") as f: + pickle.dump(state_bytes, f) + + # Load from file + with open(cache_file, "rb") as f: + loaded_state_bytes = pickle.load(f) + + # # Verify state bytes match + # assert state_bytes == loaded_state_bytes + # state_bytes = loaded_state_bytes # Use loaded bytes for rest of test + + # Create new cache from state + loaded_cache = LlamaCache.from_bytes( + name="test_cache", + cache_bytes=loaded_state_bytes, + metadata={ + "model": TEST_MODEL, + "model_file": TEST_GGUF_FILE, + "filters": {}, + "docs": [doc.model_dump_json()], + }, + ) + + # Verify the loaded cache works + response = loaded_cache.query( + "Summarize the content of the document. Please respond in a single sentence. Summary: " + ) + assert isinstance(response, CompletionResponse) + assert "coding" in response.completion.lower() or "debug" in response.completion.lower() + # assert "bottle" in response.completion.lower() and "drink me" in response.completion.lower() diff --git a/core/tests/unit/test_reranker.py b/core/tests/unit/test_reranker.py index fed68d3..69cc1a7 100644 --- a/core/tests/unit/test_reranker.py +++ b/core/tests/unit/test_reranker.py @@ -37,6 +37,8 @@ def sample_chunks() -> List[DocumentChunk]: def reranker(): """Fixture to create and reuse a flag reranker instance""" settings = get_settings() + if not settings.USE_RERANKING: + pytest.skip("Reranker is disabled in settings") return FlagReranker( model_name=settings.RERANKER_MODEL, device=settings.RERANKER_DEVICE, @@ -49,6 +51,8 @@ def reranker(): @pytest.mark.asyncio async def test_reranker_relevance(reranker, sample_chunks): """Test that reranker improves relevance for programming-related query""" + if not get_settings().USE_RERANKING: + pytest.skip("Reranker is disabled in settings") print("\n=== Testing Reranker Relevance ===") query = "What is Python programming language?" 
@@ -70,6 +74,8 @@ async def test_reranker_relevance(reranker, sample_chunks): @pytest.mark.asyncio async def test_reranker_score_distribution(reranker, sample_chunks): """Test that reranker produces reasonable score distribution""" + if not get_settings().USE_RERANKING: + pytest.skip("Reranker is disabled in settings") print("\n=== Testing Score Distribution ===") query = "Tell me about machine learning and data science" @@ -95,6 +101,8 @@ async def test_reranker_score_distribution(reranker, sample_chunks): @pytest.mark.asyncio async def test_reranker_batch_scoring(reranker): """Test that reranker can handle multiple queries/passages efficiently""" + if not get_settings().USE_RERANKING: + pytest.skip("Reranker is disabled in settings") print("\n=== Testing Batch Scoring ===") texts = [ "Python is a programming language", @@ -118,6 +126,8 @@ async def test_reranker_batch_scoring(reranker): @pytest.mark.asyncio async def test_reranker_empty_and_edge_cases(reranker, sample_chunks): """Test reranker behavior with empty or edge case inputs""" + if not get_settings().USE_RERANKING: + pytest.skip("Reranker is disabled in settings") print("\n=== Testing Edge Cases ===") # Empty chunks list @@ -147,6 +157,8 @@ async def test_reranker_empty_and_edge_cases(reranker, sample_chunks): @pytest.mark.asyncio async def test_reranker_consistency(reranker, sample_chunks): """Test that reranker produces consistent results for same input""" + if not get_settings().USE_RERANKING: + pytest.skip("Reranker is disabled in settings") print("\n=== Testing Consistency ===") query = "What is Python programming?" diff --git a/dockerfile b/dockerfile index f8ed828..798e5d0 100644 --- a/dockerfile +++ b/dockerfile @@ -9,6 +9,8 @@ WORKDIR /app # Install build dependencies RUN apt-get update && apt-get install -y \ gcc \ + g++ \ + cmake \ python3-dev \ && rm -rf /var/lib/apt/lists/* @@ -16,6 +18,9 @@ RUN apt-get update && apt-get install -y \ COPY requirements.txt . 
RUN pip install --no-cache-dir --user -r requirements.txt +# Download NLTK data +RUN python -m nltk.downloader -d /usr/local/share/nltk_data punkt averaged_perceptron_tagger + # Production stage FROM python:3.12.5-slim @@ -31,11 +36,17 @@ RUN apt-get update && apt-get install -y \ tesseract-ocr \ postgresql-client \ poppler-utils \ + gcc \ + g++ \ + cmake \ + python3-dev \ && rm -rf /var/lib/apt/lists/* # Copy installed packages from builder COPY --from=builder /root/.local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages COPY --from=builder /root/.local/bin /usr/local/bin +# Copy NLTK data from builder +COPY --from=builder /usr/local/share/nltk_data /usr/local/share/nltk_data # Create necessary directories RUN mkdir -p storage logs diff --git a/init.sql b/init.sql index 179a0d0..021e79e 100644 --- a/init.sql +++ b/init.sql @@ -31,7 +31,15 @@ CREATE TABLE IF NOT EXISTS vector_embeddings ( created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP ); +-- Create caches table +CREATE TABLE IF NOT EXISTS caches ( + name TEXT PRIMARY KEY, + metadata JSON NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); + -- Create vector index -CREATE INDEX IF NOT EXISTS vector_idx +CREATE INDEX IF NOT EXISTS vector_idx ON vector_embeddings USING ivfflat (embedding vector_l2_ops) -WITH (lists = 100); \ No newline at end of file +WITH (lists = 100); diff --git a/quick_setup.py b/quick_setup.py index 649b7b4..85ed4bd 100644 --- a/quick_setup.py +++ b/quick_setup.py @@ -241,12 +241,23 @@ def setup_postgres(): # Import and create all tables from core.database.postgres_database import Base - from core.vector_store.pgvector_store import Base as VectorBase # Create regular tables first await conn.run_sync(Base.metadata.create_all) LOGGER.info("Created base PostgreSQL tables") + # Create caches table + create_caches_table = """ + CREATE TABLE IF NOT EXISTS caches ( + name TEXT PRIMARY KEY, + metadata JSON NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP + ) + """ + await conn.execute(text(create_caches_table)) + LOGGER.info("Created caches table") + # Get vector dimensions from config dimensions = CONFIG["embedding"]["dimensions"] @@ -278,8 +289,8 @@ def setup_postgres(): LOGGER.info("Created vector_embeddings table with vector column") # Create the vector index - index_sql = f""" - CREATE INDEX vector_idx + index_sql = """ + CREATE INDEX vector_idx ON vector_embeddings USING ivfflat (embedding vector_l2_ops) WITH (lists = 100); """ diff --git a/requirements.txt b/requirements.txt index 6a8149f..8f76884 100644 --- a/requirements.txt +++ b/requirements.txt @@ -293,3 +293,4 @@ zlib-state==0.1.9 pgvector==0.2.5 psycopg[binary]==3.1.18 psycopg-binary==3.1.18 +llama-cpp-python==0.3.5 diff --git a/sdks/python/databridge/__init__.py b/sdks/python/databridge/__init__.py index 8bede8a..8b0c0b8 100644 --- a/sdks/python/databridge/__init__.py +++ b/sdks/python/databridge/__init__.py @@ -13,4 +13,4 @@ __all__ = [ "IngestTextRequest", ] -__version__ = "0.1.8" +__version__ = "0.2.0" diff --git a/sdks/python/databridge/async_.py b/sdks/python/databridge/async_.py index c9aed84..bb6287f 100644 --- a/sdks/python/databridge/async_.py +++ b/sdks/python/databridge/async_.py @@ -16,6 +16,30 @@ from .models import ( ) +class AsyncCache: + def __init__(self, db: "AsyncDataBridge", name: str): + self._db = db + self._name = name + + async 
diff --git a/sdks/python/databridge/__init__.py b/sdks/python/databridge/__init__.py
index 8bede8a..8b0c0b8 100644
--- a/sdks/python/databridge/__init__.py
+++ b/sdks/python/databridge/__init__.py
@@ -13,4 +13,4 @@ __all__ = [
     "IngestTextRequest",
 ]
 
-__version__ = "0.1.8"
+__version__ = "0.2.0"
diff --git a/sdks/python/databridge/async_.py b/sdks/python/databridge/async_.py
index c9aed84..bb6287f 100644
--- a/sdks/python/databridge/async_.py
+++ b/sdks/python/databridge/async_.py
@@ -16,6 +16,30 @@ from .models import (
 )
 
 
+class AsyncCache:
+    def __init__(self, db: "AsyncDataBridge", name: str):
+        self._db = db
+        self._name = name
+
+    async def update(self) -> bool:
+        response = await self._db._request("POST", f"cache/{self._name}/update")
+        return response.get("success", False)
+
+    async def add_docs(self, docs: List[str]) -> bool:
+        response = await self._db._request("POST", f"cache/{self._name}/add_docs", {"docs": docs})
+        return response.get("success", False)
+
+    async def query(
+        self, query: str, max_tokens: Optional[int] = None, temperature: Optional[float] = None
+    ) -> CompletionResponse:
+        response = await self._db._request(
+            "POST",
+            f"cache/{self._name}/query",
+            {"query": query, "max_tokens": max_tokens, "temperature": temperature},
+        )
+        return CompletionResponse(**response)
+
+
 class AsyncDataBridge:
     """
     DataBridge client for document operations.
@@ -345,6 +369,72 @@ class AsyncDataBridge:
         response = await self._request("GET", f"documents/{document_id}")
         return Document(**response)
 
+    async def create_cache(
+        self,
+        name: str,
+        model: str,
+        gguf_file: str,
+        filters: Optional[Dict[str, Any]] = None,
+        docs: Optional[List[str]] = None,
+    ) -> Dict[str, Any]:
+        """
+        Create a new cache with specified configuration.
+
+        Args:
+            name: Name of the cache to create
+            model: Name of the model to use (e.g. "llama2")
+            gguf_file: Name of the GGUF file to use for the model
+            filters: Optional metadata filters to determine which documents to include. These filters will be applied in addition to any specific docs provided.
+            docs: Optional list of specific document IDs to include. These docs will be included in addition to any documents matching the filters.
+
+        Returns:
+            Dict[str, Any]: Created cache configuration
+
+        Example:
+            ```python
+            # This will include both:
+            # 1. Any documents with category="programming"
+            # 2. The specific documents "doc1" and "doc2" (regardless of their category)
+            cache = await db.create_cache(
+                name="programming_cache",
+                model="llama2",
+                gguf_file="llama-2-7b-chat.Q4_K_M.gguf",
+                filters={"category": "programming"},
+                docs=["doc1", "doc2"]
+            )
+            ```
+        """
+        request = {
+            "name": name,
+            "model": model,
+            "gguf_file": gguf_file,
+            "filters": filters,
+            "docs": docs,
+        }
+
+        response = await self._request("POST", "cache/create", request)
+        return response
+
+    async def get_cache(self, name: str) -> AsyncCache:
+        """
+        Get a cache by name.
+
+        Args:
+            name: Name of the cache to retrieve
+
+        Returns:
+            cache: A cache object that is used to interact with the cache.
+
+        Example:
+            ```python
+            cache = await db.get_cache("programming_cache")
+            ```
+        """
+        response = await self._request("GET", f"cache/{name}")
+        if response.get("exists", False):
+            return AsyncCache(self, name)
+        raise ValueError(f"Cache '{name}' not found")
+
     async def close(self):
         """Close the HTTP client"""
         await self._client.aclose()
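For reference, an end-to-end async usage sketch of the client API added above (illustrative only; assumes `db` is an already-constructed `AsyncDataBridge` instance and that the document IDs exist):

```python
# Illustrative async flow over the new cache endpoints (not part of the patch).
# Assumes `db` is an existing AsyncDataBridge client and the doc IDs are real.
async def cache_demo(db):
    await db.create_cache(
        name="programming_cache",
        model="llama2",
        gguf_file="llama-2-7b-chat.Q4_K_M.gguf",
        filters={"category": "programming"},
        docs=["doc1", "doc2"],
    )
    cache = await db.get_cache("programming_cache")
    await cache.add_docs(["doc3"])  # grow the cache after creation
    response = await cache.query("Which documents discuss recursion?", max_tokens=128)
    return response  # CompletionResponse; exact field names come from core.models.completion
```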
diff --git a/sdks/python/databridge/sync.py b/sdks/python/databridge/sync.py
index e1db50a..3bc8b5c 100644
--- a/sdks/python/databridge/sync.py
+++ b/sdks/python/databridge/sync.py
@@ -16,6 +16,31 @@ from .models import (
 )
 
 
+class Cache:
+    def __init__(self, db: "DataBridge", name: str):
+        self._db = db
+        self._name = name
+
+    def update(self) -> bool:
+        response = self._db._request("POST", f"cache/{self._name}/update")
+        return response.get("success", False)
+
+    def add_docs(self, docs: List[str]) -> bool:
+        response = self._db._request("POST", f"cache/{self._name}/add_docs", {"docs": docs})
+        return response.get("success", False)
+
+    def query(
+        self, query: str, max_tokens: Optional[int] = None, temperature: Optional[float] = None
+    ) -> CompletionResponse:
+        response = self._db._request(
+            "POST",
+            f"cache/{self._name}/query",
+            params={"query": query, "max_tokens": max_tokens, "temperature": temperature},
+            data="",
+        )
+        return CompletionResponse(**response)
+
+
 class DataBridge:
     """
     DataBridge client for document operations.
@@ -71,6 +96,7 @@ class DataBridge:
         endpoint: str,
         data: Optional[Dict[str, Any]] = None,
         files: Optional[Dict[str, Any]] = None,
+        params: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         """Make HTTP request"""
         headers = {}
@@ -88,6 +114,7 @@
             data=data if files else None,
             headers=headers,
             timeout=self._timeout,
+            params=params,
         )
         response.raise_for_status()
         return response.json()
@@ -334,6 +361,70 @@ class DataBridge:
         response = self._request("GET", f"documents/{document_id}")
         return Document(**response)
 
+    def create_cache(
+        self,
+        name: str,
+        model: str,
+        gguf_file: str,
+        filters: Optional[Dict[str, Any]] = None,
+        docs: Optional[List[str]] = None,
+    ) -> Dict[str, Any]:
+        """
+        Create a new cache with specified configuration.
+
+        Args:
+            name: Name of the cache to create
+            model: Name of the model to use (e.g. "llama2")
+            gguf_file: Name of the GGUF file to use for the model
+            filters: Optional metadata filters to determine which documents to include. These filters will be applied in addition to any specific docs provided.
+            docs: Optional list of specific document IDs to include. These docs will be included in addition to any documents matching the filters.
+
+        Returns:
+            Dict[str, Any]: Created cache configuration
+
+        Example:
+            ```python
+            # This will include both:
+            # 1. Any documents with category="programming"
+            # 2. The specific documents "doc1" and "doc2" (regardless of their category)
+            cache = db.create_cache(
+                name="programming_cache",
+                model="llama2",
+                gguf_file="llama-2-7b-chat.Q4_K_M.gguf",
+                filters={"category": "programming"},
+                docs=["doc1", "doc2"]
+            )
+            ```
+        """
+        # Build query parameters for name, model and gguf_file
+        params = {"name": name, "model": model, "gguf_file": gguf_file}
+
+        # Build request body for filters and docs
+        request = {"filters": filters, "docs": docs}
+
+        response = self._request("POST", "cache/create", request, params=params)
+        return response
+
+    def get_cache(self, name: str) -> Cache:
+        """
+        Get a cache by name.
+
+        Args:
+            name: Name of the cache to retrieve
+
+        Returns:
+            cache: A cache object that is used to interact with the cache.
+
+        Example:
+            ```python
+            cache = db.get_cache("programming_cache")
+            ```
+        """
+        response = self._request("GET", f"cache/{name}")
+        if response.get("exists", False):
+            return Cache(self, name)
+        raise ValueError(f"Cache '{name}' not found")
+
     def close(self):
         """Close the HTTP session"""
         self._session.close()
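Note the design choice in the sync client above: `name`, `model`, and `gguf_file` travel as query parameters while `filters` and `docs` go in the request body. A rough sketch of the equivalent raw request (illustrative; the base URL and auth header are assumptions, not taken from this patch):

```python
# Approximate wire-level equivalent of DataBridge.create_cache (illustrative only).
# The base URL and Authorization header are assumptions.
import httpx

resp = httpx.post(
    "http://localhost:8000/cache/create",
    params={
        "name": "programming_cache",
        "model": "llama2",
        "gguf_file": "llama-2-7b-chat.Q4_K_M.gguf",
    },
    json={"filters": {"category": "programming"}, "docs": ["doc1", "doc2"]},
    headers={"Authorization": "Bearer <token>"},
)
resp.raise_for_status()
print(resp.json())
```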
diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml
index 96b19d4..6173a9e 100644
--- a/sdks/python/pyproject.toml
+++ b/sdks/python/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "databridge-client"
-version = "0.1.8"
+version = "0.2.0"
 authors = [
     { name = "DataBridge", email = "databridgesuperuser@gmail.com" },
 ]
diff --git a/shell.py b/shell.py
index 38c1b02..edb568a 100644
--- a/shell.py
+++ b/shell.py
@@ -118,11 +118,59 @@ class DB:
         doc = self._client.get_document(document_id)
         return doc.model_dump()
 
+    def create_cache(
+        self,
+        name: str,
+        model: str,
+        gguf_file: str,
+        filters: dict = None,
+        docs: list = None,
+    ) -> dict:
+        """Create a new cache with specified configuration"""
+        response = self._client.create_cache(
+            name=name,
+            model=model,
+            gguf_file=gguf_file,
+            filters=filters or {},
+            docs=docs,
+        )
+        return response
+
+    def get_cache(self, name: str) -> "Cache":
+        """Get a cache by name"""
+        return Cache(self, name)
+
     def close(self):
         """Close the client connection"""
         self._client.close()
 
 
+class Cache:
+    def __init__(self, db: DB, name: str):
+        self._db = db
+        self._name = name
+        self._client_cache = db._client.get_cache(name)
+
+    def update(self) -> bool:
+        """Update the cache"""
+        return self._client_cache.update()
+
+    def add_docs(self, docs: list) -> bool:
+        """Add documents to the cache"""
+        return self._client_cache.add_docs(docs)
+
+    def query(
+        self, query: str, max_tokens: int = None, temperature: float = None
+    ) -> dict:
+        """Query the cache"""
+        response = self._client_cache.query(
+            query=query,
+            max_tokens=max_tokens,
+            temperature=temperature,
+        )
+        return response.model_dump()
+
+
 if __name__ == "__main__":
     uri = sys.argv[1] if len(sys.argv) > 1 else None
     db = DB(uri)
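Finally, a short interactive sketch of the new shell commands (illustrative; the connection URI is an assumption, and `query` returns a plain dict via `model_dump()`):

```python
# Hypothetical shell.py session exercising the cache wrapper added above.
# The URI is an assumption; use whatever DataBridge URI your deployment provides.
db = DB("databridge://owner:token@localhost:8000")
db.create_cache(
    name="programming_cache",
    model="llama2",
    gguf_file="llama-2-7b-chat.Q4_K_M.gguf",
    filters={"category": "programming"},
)
cache = db.get_cache("programming_cache")
cache.update()  # refresh the server-side cache state
print(cache.query("Which documents mention recursion?", max_tokens=64))
```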