Add support for cache-augmented-generation (#30)

Arnav Agrawal, 2025-01-29 10:19:28 +05:30 (committed by GitHub)
parent 4f0cf62008
commit d124e6aa0d
28 changed files with 1509 additions and 31 deletions

View File

@ -1,5 +1,7 @@
import json
from datetime import datetime, UTC, timedelta
from pathlib import Path
import sys
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, Form, HTTPException, Depends, Header, UploadFile
from fastapi.middleware.cors import CORSMiddleware
@ -30,6 +32,7 @@ from core.embedding.openai_embedding_model import OpenAIEmbeddingModel
from core.completion.ollama_completion import OllamaCompletionModel
from core.parser.contextual_parser import ContextualParser
from core.reranker.flag_reranker import FlagReranker
from core.cache.llama_cache_factory import LlamaCacheFactory
import tomli
# Initialize FastAPI app
@ -214,6 +217,9 @@ if settings.USE_RERANKING:
case _:
raise ValueError(f"Unsupported reranker provider: {settings.RERANKER_PROVIDER}")
# Initialize cache factory
cache_factory = LlamaCacheFactory(Path(settings.STORAGE_PATH))
# Initialize document service with configured components
document_service = DocumentService(
storage=storage,
@ -223,6 +229,7 @@ document_service = DocumentService(
completion_model=completion_model,
parser=parser,
reranker=reranker,
cache_factory=cache_factory,
)
@ -462,6 +469,134 @@ async def get_recent_usage(
]
# Cache endpoints
@app.post("/cache/create")
async def create_cache(
name: str,
model: str,
gguf_file: str,
filters: Optional[Dict[str, Any]] = None,
docs: Optional[List[str]] = None,
auth: AuthContext = Depends(verify_token),
) -> Dict[str, Any]:
"""Create a new cache with specified configuration."""
try:
async with telemetry.track_operation(
operation_type="create_cache",
user_id=auth.entity_id,
metadata={
"name": name,
"model": model,
"gguf_file": gguf_file,
"filters": filters,
"docs": docs,
},
):
filter_docs = set(await document_service.db.get_documents(auth, filters=filters))
additional_docs = (
{
await document_service.db.get_document(document_id=doc_id, auth=auth)
for doc_id in docs
}
if docs
else set()
)
docs_to_add = list(filter_docs.union(additional_docs))
if not docs_to_add:
raise HTTPException(status_code=400, detail="No documents to add to cache")
response = await document_service.create_cache(
name, model, gguf_file, docs_to_add, filters
)
return response
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.get("/cache/{name}")
async def get_cache(name: str, auth: AuthContext = Depends(verify_token)) -> Dict[str, Any]:
"""Check whether the named cache exists, loading it into memory if needed."""
try:
async with telemetry.track_operation(
operation_type="get_cache",
user_id=auth.entity_id,
metadata={"name": name},
):
exists = await document_service.load_cache(name)
return {"exists": exists}
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/cache/{name}/update")
async def update_cache(name: str, auth: AuthContext = Depends(verify_token)) -> Dict[str, bool]:
"""Update cache with new documents matching its filter."""
try:
async with telemetry.track_operation(
operation_type="update_cache",
user_id=auth.entity_id,
metadata={"name": name},
):
if name not in document_service.active_caches:
exists = await document_service.load_cache(name)
if not exists:
raise HTTPException(status_code=404, detail=f"Cache '{name}' not found")
cache = document_service.active_caches[name]
docs = await document_service.db.get_documents(auth, filters=cache.filters)
docs_to_add = [doc for doc in docs if doc.id not in cache.docs]
return cache.add_docs(docs_to_add)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/cache/{name}/add_docs")
async def add_docs_to_cache(
name: str, docs: List[str], auth: AuthContext = Depends(verify_token)
) -> Dict[str, bool]:
"""Add specific documents to the cache."""
try:
async with telemetry.track_operation(
operation_type="add_docs_to_cache",
user_id=auth.entity_id,
metadata={"name": name, "docs": docs},
):
cache = document_service.active_caches[name]
docs_to_add = [
await document_service.db.get_document(doc_id, auth)
for doc_id in docs
if doc_id not in cache.docs
]
return cache.add_docs(docs_to_add)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/cache/{name}/query")
async def query_cache(
name: str,
query: str,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
auth: AuthContext = Depends(verify_token),
) -> CompletionResponse:
"""Query the cache with a prompt."""
try:
async with telemetry.track_operation(
operation_type="query_cache",
user_id=auth.entity_id,
metadata={
"name": name,
"query": query,
"max_tokens": max_tokens,
"temperature": temperature,
},
):
cache = document_service.active_caches[name]
print(f"Cache state: {cache.state.n_tokens}", file=sys.stderr)
return cache.query(query) # , max_tokens, temperature)
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/local/generate_uri", include_in_schema=True)
async def generate_local_uri(
name: str = Form("admin"),
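Taken together, these endpoints cover the whole cache lifecycle. As a rough illustration of how a client might drive them with plain `httpx` (not part of this commit; the server URL, bearer token, model repo, and GGUF filename are placeholders):

```python
# Sketch only: exercising the new cache endpoints with httpx.
# URL, token, model repo, and GGUF filename are hypothetical placeholders.
import asyncio

import httpx

BASE_URL = "http://localhost:8000"
HEADERS = {"Authorization": "Bearer <token>"}


async def demo() -> None:
    async with httpx.AsyncClient(base_url=BASE_URL, headers=HEADERS) as client:
        # name/model/gguf_file travel as query parameters; filters/docs as the JSON
        # body (mirroring the synchronous SDK changes further down in this diff).
        created = await client.post(
            "/cache/create",
            params={
                "name": "demo_cache",
                "model": "QuantFactory/Dolphin3.0-Llama3.2-1B-GGUF",
                "gguf_file": "*Q4_K_S.gguf",
            },
            json={"filters": {"category": "programming"}, "docs": None},
        )
        print(created.json())

        # GET /cache/{name} loads the cache server-side and reports whether it exists.
        print((await client.get("/cache/demo_cache")).json())

        # Query the cached context; the query string is passed as a query parameter.
        answer = await client.post(
            "/cache/demo_cache/query",
            params={"query": "Summarize the cached documents in one sentence."},
        )
        print(answer.json())


asyncio.run(demo())
```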

core/cache/base_cache.py (new file)

@ -0,0 +1,68 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from core.models.completion import CompletionResponse
from core.models.documents import Document
class BaseCache(ABC):
"""Base class for cache implementations.
This class defines the interface for cache implementations that support
document ingestion and cache-augmented querying.
"""
def __init__(
self, name: str, model: str, gguf_file: str, filters: Dict[str, Any], docs: List[Document]
):
"""Initialize the cache with the given parameters.
Args:
name: Name of the cache instance
model: Model identifier
gguf_file: Path to the GGUF model file
filters: Filters used to create the cache context
docs: Initial documents to ingest into the cache
"""
self.name = name
self.filters = filters
self.docs = [] # List of document IDs that have been ingested
self._initialize(model, gguf_file, docs)
@abstractmethod
def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None:
"""Internal initialization method to be implemented by subclasses."""
pass
@abstractmethod
async def add_docs(self, docs: List[Document]) -> bool:
"""Add documents to the cache.
Args:
docs: List of documents to add to the cache
Returns:
bool: True if documents were successfully added
"""
pass
@abstractmethod
async def query(self, query: str) -> CompletionResponse:
"""Query the cache for relevant documents and generate a response.
Args:
query: Query string to search for relevant documents
Returns:
CompletionResponse: Generated response based on cached context
"""
pass
@property
@abstractmethod
def saveable_state(self) -> bytes:
"""Get the saveable state of the cache as bytes.
Returns:
bytes: Serialized state that can be used to restore the cache
"""
pass
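A concrete implementation only needs to fill in the four abstract members above. The toy subclass below is purely illustrative and not part of this commit: it keeps raw document text in memory instead of a model KV cache, just to show the interface.

```python
# Illustrative only: a toy BaseCache subclass with no real model or KV cache.
import pickle
from typing import List

from core.cache.base_cache import BaseCache
from core.models.completion import CompletionResponse
from core.models.documents import Document


class EchoCache(BaseCache):
    def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None:
        # Keep document text in memory; a real cache would build a KV cache here.
        self._context: List[str] = [doc.system_metadata.get("content", "") for doc in docs]
        self.docs = [doc.external_id for doc in docs]

    async def add_docs(self, docs: List[Document]) -> bool:
        self._context.extend(doc.system_metadata.get("content", "") for doc in docs)
        self.docs.extend(doc.external_id for doc in docs)
        return True

    async def query(self, query: str) -> CompletionResponse:
        # Echo the query plus the amount of cached context; no model involved.
        completion = f"[{len(self._context)} docs cached] {query}"
        return CompletionResponse(
            completion=completion,
            usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        )

    @property
    def saveable_state(self) -> bytes:
        return pickle.dumps(self._context)
```

Constructing it as `EchoCache("demo", "some-model", "some-file.gguf", {}, docs)` runs the same `__init__` / `_initialize` path that the factory-created caches below rely on.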

core/cache/base_cache_factory.py (new file)

@ -0,0 +1,64 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Any
from .base_cache import BaseCache
class BaseCacheFactory(ABC):
"""Abstract base factory for creating and loading caches."""
def __init__(self, storage_path: Path):
"""Initialize the cache factory.
Args:
storage_path: Base path for storing cache files
"""
self.storage_path = storage_path
self.storage_path.mkdir(parents=True, exist_ok=True)
@abstractmethod
def create_new_cache(
self, name: str, model: str, model_file: str, **kwargs: Dict[str, Any]
) -> BaseCache:
"""Create a new cache instance.
Args:
name: Name of the cache
model: Name/type of the model to use
model_file: Path or identifier for the model file
**kwargs: Additional arguments for cache creation
Returns:
BaseCache: The created cache instance
"""
pass
@abstractmethod
def load_cache_from_bytes(
self, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs: Dict[str, Any]
) -> BaseCache:
"""Load a cache from its serialized bytes.
Args:
name: Name of the cache
cache_bytes: Serialized cache data
metadata: Cache metadata including model info
**kwargs: Additional arguments for cache loading
Returns:
BaseCache: The loaded cache instance
"""
pass
def get_cache_path(self, name: str) -> Path:
"""Get the storage path for a cache.
Args:
name: Name of the cache
Returns:
Path: Directory path for the cache
"""
path = self.storage_path / name
path.mkdir(parents=True, exist_ok=True)
return path

core/cache/hf_cache.py (new file)

@ -0,0 +1,285 @@
# Hugging Face cache implementation.
from core.cache.base_cache import BaseCache
from typing import List, Optional, Union
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
from core.models.completion import CompletionRequest, CompletionResponse
class HuggingFaceCache(BaseCache):
"""Hugging Face Cache implementation for cache-augmented generation"""
def __init__(
self,
cache_path: Path,
model_name: str = "distilgpt2",
device: str = "cpu",
default_max_new_tokens: int = 100,
use_fp16: bool = False,
):
"""Initialize the HuggingFace cache.
Args:
cache_path: Path to store cache files
model_name: Name of the HuggingFace model to use
device: Device to run the model on (e.g. "cpu", "cuda", "mps")
default_max_new_tokens: Default maximum number of new tokens to generate
use_fp16: Whether to use FP16 precision
"""
super().__init__()
self.cache_path = cache_path
self.model_name = model_name
self.device = device
self.default_max_new_tokens = default_max_new_tokens
self.use_fp16 = use_fp16
# Initialize tokenizer and model
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Configure model loading based on device
model_kwargs = {"low_cpu_mem_usage": True}
if device == "cpu":
# For CPU, use standard loading
model_kwargs.update({"torch_dtype": torch.float32})
self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs).to(device)
else:
# For GPU/MPS, use automatic device mapping and optional FP16
model_kwargs.update(
{"device_map": "auto", "torch_dtype": torch.float16 if use_fp16 else torch.float32}
)
self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
self.kv_cache = None
self.origin_len = None
def get_kv_cache(self, prompt: str) -> DynamicCache:
"""Build KV cache from prompt"""
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
cache = DynamicCache()
with torch.no_grad():
_ = self.model(input_ids=input_ids, past_key_values=cache, use_cache=True)
return cache
def clean_up_cache(self, cache: DynamicCache, origin_len: int):
"""Clean up cache by removing appended tokens"""
for i in range(len(cache.key_cache)):
cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
def generate(
self, input_ids: torch.Tensor, past_key_values, max_new_tokens: Optional[int] = None
) -> torch.Tensor:
"""Generate text using the model and cache"""
device = next(self.model.parameters()).device
origin_len = input_ids.shape[-1]
input_ids = input_ids.to(device)
output_ids = input_ids.clone()
next_token = input_ids
with torch.no_grad():
for _ in range(max_new_tokens or self.default_max_new_tokens):
out = self.model(
input_ids=next_token, past_key_values=past_key_values, use_cache=True
)
logits = out.logits[:, -1, :]
token = torch.argmax(logits, dim=-1, keepdim=True)
output_ids = torch.cat([output_ids, token], dim=-1)
past_key_values = out.past_key_values
next_token = token.to(device)
if (
self.model.config.eos_token_id is not None
and token.item() == self.model.config.eos_token_id
):
break
return output_ids[:, origin_len:]
async def ingest(self, docs: List[str]) -> bool:
"""Ingest documents into cache"""
try:
# Create system prompt with documents
system_prompt = f"""
<|system|>
You are an assistant who provides concise factual answers.
<|user|>
Context:
{' '.join(docs)}
Question:
""".strip()
# Build the cache
input_ids = self.tokenizer(system_prompt, return_tensors="pt").input_ids.to(self.device)
self.kv_cache = DynamicCache()
with torch.no_grad():
# First run to get the cache shape
outputs = self.model(input_ids=input_ids, use_cache=True)
# Initialize cache with empty tensors of the right shape
n_layers = len(outputs.past_key_values)
batch_size = input_ids.shape[0]
# Handle different model architectures
if hasattr(self.model.config, "num_key_value_heads"):
# Models with grouped query attention (GQA) like Llama
n_kv_heads = self.model.config.num_key_value_heads
head_dim = self.model.config.head_dim
elif hasattr(self.model.config, "n_head"):
# GPT-style models
n_kv_heads = self.model.config.n_head
head_dim = self.model.config.n_embd // self.model.config.n_head
elif hasattr(self.model.config, "num_attention_heads"):
# OPT-style models
n_kv_heads = self.model.config.num_attention_heads
head_dim = (
self.model.config.hidden_size // self.model.config.num_attention_heads
)
else:
raise ValueError(
f"Unsupported model architecture: {self.model.config.model_type}"
)
seq_len = input_ids.shape[1]
for i in range(n_layers):
key_shape = (batch_size, n_kv_heads, seq_len, head_dim)
value_shape = key_shape
self.kv_cache.key_cache.append(torch.zeros(key_shape, device=self.device))
self.kv_cache.value_cache.append(torch.zeros(value_shape, device=self.device))
# Now run with the initialized cache
outputs = self.model(
input_ids=input_ids, past_key_values=self.kv_cache, use_cache=True
)
# Update cache with actual values
self.kv_cache.key_cache = [layer[0] for layer in outputs.past_key_values]
self.kv_cache.value_cache = [layer[1] for layer in outputs.past_key_values]
self.origin_len = self.kv_cache.key_cache[0].shape[-2]
return True
except Exception as e:
print(f"Error ingesting documents: {e}")
return False
async def update(self, new_doc: str) -> bool:
"""Update cache with new document"""
try:
if self.kv_cache is None:
return await self.ingest([new_doc])
# Clean up existing cache
self.clean_up_cache(self.kv_cache, self.origin_len)
# Add new document to cache
input_ids = self.tokenizer(new_doc + "\n", return_tensors="pt").input_ids.to(
self.device
)
# First run to get the cache shape
outputs = self.model(input_ids=input_ids, use_cache=True)
# Initialize cache with empty tensors of the right shape
n_layers = len(outputs.past_key_values)
batch_size = input_ids.shape[0]
# Handle different model architectures
if hasattr(self.model.config, "num_key_value_heads"):
# Models with grouped query attention (GQA) like Llama
n_kv_heads = self.model.config.num_key_value_heads
head_dim = self.model.config.head_dim
elif hasattr(self.model.config, "n_head"):
# GPT-style models
n_kv_heads = self.model.config.n_head
head_dim = self.model.config.n_embd // self.model.config.n_head
elif hasattr(self.model.config, "num_attention_heads"):
# OPT-style models
n_kv_heads = self.model.config.num_attention_heads
head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
else:
raise ValueError(f"Unsupported model architecture: {self.model.config.model_type}")
seq_len = input_ids.shape[1]
# Create a new cache for the update
new_cache = DynamicCache()
for i in range(n_layers):
key_shape = (batch_size, n_kv_heads, seq_len, head_dim)
value_shape = key_shape
new_cache.key_cache.append(torch.zeros(key_shape, device=self.device))
new_cache.value_cache.append(torch.zeros(value_shape, device=self.device))
# Run with the initialized cache
outputs = self.model(input_ids=input_ids, past_key_values=new_cache, use_cache=True)
# Update cache with actual values
self.kv_cache.key_cache = [layer[0] for layer in outputs.past_key_values]
self.kv_cache.value_cache = [layer[1] for layer in outputs.past_key_values]
return True
except Exception as e:
print(f"Error updating cache: {e}")
return False
async def complete(self, request: CompletionRequest) -> CompletionResponse:
"""Generate completion using cache-augmented generation"""
try:
if self.kv_cache is None:
raise ValueError("Cache not initialized. Please ingest documents first.")
# Clean up cache
self.clean_up_cache(self.kv_cache, self.origin_len)
# Generate completion
input_ids = self.tokenizer(request.query + "\n", return_tensors="pt").input_ids.to(
self.device
)
gen_ids = self.generate(input_ids, self.kv_cache, max_new_tokens=request.max_tokens)
completion = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True)
# Calculate token usage
usage = {
"prompt_tokens": len(input_ids[0]),
"completion_tokens": len(gen_ids[0]),
"total_tokens": len(input_ids[0]) + len(gen_ids[0]),
}
return CompletionResponse(completion=completion, usage=usage)
except Exception as e:
print(f"Error generating completion: {e}")
return CompletionResponse(
completion=f"Error: {str(e)}",
usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
)
def save_cache(self) -> Path:
"""Save the KV cache to disk"""
if self.kv_cache is None:
raise ValueError("No cache to save")
cache_dir = self.cache_path / "kv_cache"
cache_dir.mkdir(parents=True, exist_ok=True)
# Save key and value caches
cache_data = {
"key_cache": self.kv_cache.key_cache,
"value_cache": self.kv_cache.value_cache,
"origin_len": self.origin_len,
}
cache_path = cache_dir / "cache.pt"
torch.save(cache_data, cache_path)
return cache_path
def load_cache(self, cache_path: Union[str, Path]) -> None:
"""Load KV cache from disk"""
cache_path = Path(cache_path)
if not cache_path.exists():
raise FileNotFoundError(f"Cache file not found at {cache_path}")
cache_data = torch.load(cache_path, map_location=self.device)
self.kv_cache = DynamicCache()
self.kv_cache.key_cache = cache_data["key_cache"]
self.kv_cache.value_cache = cache_data["value_cache"]
self.origin_len = cache_data["origin_len"]
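Both `ingest` and `update` derive the KV tensor shapes from the model config with the same branching. That logic can be checked in isolation with a small sketch (assumption: it only needs `transformers.AutoConfig`, and the distilgpt2 lookup downloads the config on first run):

```python
# Standalone sketch of the head-count / head-dim branching used in HuggingFaceCache.
from transformers import AutoConfig


def kv_shape_params(model_name: str) -> tuple[int, int]:
    config = AutoConfig.from_pretrained(model_name)
    if hasattr(config, "num_key_value_heads"):
        # Grouped-query attention (e.g. Llama): fewer KV heads than query heads.
        return config.num_key_value_heads, config.head_dim
    if hasattr(config, "n_head"):
        # GPT-2 style configs.
        return config.n_head, config.n_embd // config.n_head
    if hasattr(config, "num_attention_heads"):
        # OPT-style configs.
        return config.num_attention_heads, config.hidden_size // config.num_attention_heads
    raise ValueError(f"Unsupported model architecture: {config.model_type}")


# Prints the KV head count and per-head dimension, e.g. (12, 64) for distilgpt2.
print(kv_shape_params("distilgpt2"))
```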

core/cache/llama_cache.py (new file)

@ -0,0 +1,196 @@
import json
import pickle
import logging
from core.cache.base_cache import BaseCache
from typing import Dict, Any, List
from core.models.completion import CompletionResponse
from core.models.documents import Document
from llama_cpp import Llama
logger = logging.getLogger(__name__)
INITIAL_SYSTEM_PROMPT = """<|im_start|>system
You are a helpful AI assistant with access to provided documents. Your role is to:
1. Answer questions accurately based on the documents provided
2. Stay focused on the document content and avoid speculation
3. Admit when you don't have enough information to answer
4. Be clear and concise in your responses
5. Use direct quotes from documents when relevant
Provided documents: {documents}
<|im_end|>
""".strip()
ADD_DOC_SYSTEM_PROMPT = """<|im_start|>system
I'm adding some additional documents for your reference:
{documents}
Please incorporate this new information along with what you already know from previous documents while maintaining the same guidelines for responses.
<|im_end|>
""".strip()
QUERY_PROMPT = """<|im_start|>user
{query}
<|im_end|>
<|im_start|>assistant
""".strip()
class LlamaCache(BaseCache):
def __init__(
self,
name: str,
model: str,
gguf_file: str,
filters: Dict[str, Any],
docs: List[Document],
**kwargs,
):
logger.info(f"Initializing LlamaCache with name={name}, model={model}")
# cache related
self.name = name
self.model = model
self.filters = filters
self.docs = docs
# llama specific
self.gguf_file = gguf_file
self.n_gpu_layers = kwargs.get("n_gpu_layers", -1)
logger.info(f"Using {self.n_gpu_layers} GPU layers")
# late init (when we call _initialize)
self.llama = None
self.state = None
self.cached_tokens = 0
self._initialize(model, gguf_file, docs)
logger.info("LlamaCache initialization complete")
def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None:
logger.info(f"Loading Llama model from {model} with file {gguf_file}")
try:
# Set a reasonable default context size (32K tokens)
default_ctx_size = 32768
self.llama = Llama.from_pretrained(
repo_id=model,
filename=gguf_file,
n_gpu_layers=self.n_gpu_layers,
n_ctx=default_ctx_size,
verbose=False, # Keep llama.cpp's own logging quiet
)
logger.info("Model loaded successfully")
# Format and tokenize system prompt
documents = "\n".join(doc.system_metadata.get("content", "") for doc in docs)
system_prompt = INITIAL_SYSTEM_PROMPT.format(documents=documents)
logger.info(f"Built system prompt: {system_prompt[:200]}...")
try:
tokens = self.llama.tokenize(system_prompt.encode())
logger.info(f"System prompt tokenized to {len(tokens)} tokens")
# Process tokens to build KV cache
logger.info("Evaluating system prompt")
self.llama.eval(tokens)
logger.info("Saving initial KV cache state")
self.state = self.llama.save_state()
self.cached_tokens = len(tokens)
logger.info(f"Initial KV cache built with {self.cached_tokens} tokens")
except Exception as e:
logger.error(f"Error during prompt processing: {str(e)}")
raise ValueError(f"Failed to process system prompt: {str(e)}")
except Exception as e:
logger.error(f"Failed to initialize Llama model: {str(e)}")
raise ValueError(f"Failed to initialize Llama model: {str(e)}")
def add_docs(self, docs: List[Document]) -> bool:
logger.info(f"Adding {len(docs)} new documents to cache")
documents = "\n".join(doc.system_metadata.get("content", "") for doc in docs)
system_prompt = ADD_DOC_SYSTEM_PROMPT.format(documents=documents)
# Tokenize and process
new_tokens = self.llama.tokenize(system_prompt.encode())
self.llama.eval(new_tokens)
self.state = self.llama.save_state()
self.cached_tokens += len(new_tokens)
logger.info(f"Added {len(new_tokens)} tokens, total: {self.cached_tokens}")
return True
def query(self, query: str) -> CompletionResponse:
# Format query with proper chat template
formatted_query = QUERY_PROMPT.format(query=query)
logger.info(f"Processing query: {formatted_query}")
# Reset and load cached state
self.llama.reset()
self.llama.load_state(self.state)
logger.info(f"Loaded state with {self.state.n_tokens} tokens")
# print(f"Loaded state with {self.state.n_tokens} tokens", file=sys.stderr)
# Tokenize and process query
query_tokens = self.llama.tokenize(formatted_query.encode())
self.llama.eval(query_tokens)
logger.info(f"Evaluated query tokens: {query_tokens}")
# print(f"Evaluated query tokens: {query_tokens}", file=sys.stderr)
# Generate response
output_tokens = []
for token in self.llama.generate(tokens=[], reset=False):
output_tokens.append(token)
# Stop generation when EOT token is encountered
if token == self.llama.token_eos():
break
# Decode and return
completion = self.llama.detokenize(output_tokens).decode()
logger.info(f"Generated completion: {completion}")
return CompletionResponse(
completion=completion,
usage={"prompt_tokens": self.cached_tokens, "completion_tokens": len(output_tokens)},
)
@property
def saveable_state(self) -> bytes:
logger.info("Serializing cache state")
state_bytes = pickle.dumps(self.state)
logger.info(f"Serialized state size: {len(state_bytes)} bytes")
return state_bytes
@classmethod
def from_bytes(
cls, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs
) -> "LlamaCache":
"""Load a cache from its serialized state.
Args:
name: Name of the cache
cache_bytes: Pickled state bytes
metadata: Cache metadata including model info
**kwargs: Additional arguments
Returns:
LlamaCache: Loaded cache instance
"""
logger.info(f"Loading cache from bytes with name={name}")
logger.info(f"Cache metadata: {metadata}")
# Create new instance with metadata
# logger.info(f"Docs: {metadata['docs']}")
docs = [json.loads(doc) for doc in metadata["docs"]]
# time.sleep(10)
cache = cls(
name=name,
model=metadata["model"],
gguf_file=metadata["model_file"],
filters=metadata["filters"],
docs=[Document(**doc) for doc in docs],
)
# Load the saved state
logger.info(f"Loading saved KV cache state of size {len(cache_bytes)} bytes")
cache.state = pickle.loads(cache_bytes)
cache.llama.load_state(cache.state)
logger.info("Cache successfully loaded from bytes")
return cache
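For reference, the pickled state plus the metadata dictionary passed to `from_bytes` is all that is needed to rebuild a cache; the test file later in this diff exercises the same path. A condensed sketch (the model repo and GGUF pattern are placeholders, and `Llama.from_pretrained` downloads the weights on first use):

```python
# Sketch: LlamaCache persistence round trip. Model repo / GGUF pattern are placeholders.
from core.cache.llama_cache import LlamaCache
from core.models.documents import Document

doc = Document(
    external_id="doc1",
    owner={"id": "u1", "name": "User"},
    content_type="text/plain",
    system_metadata={"content": "DataBridge supports cache-augmented generation."},
)

cache = LlamaCache(
    name="demo",
    model="QuantFactory/Dolphin3.0-Llama3.2-1B-GGUF",
    gguf_file="*Q4_K_S.gguf",
    filters={},
    docs=[doc],
)
state = cache.saveable_state  # pickled llama_cpp state, ready for object storage

# from_bytes expects the same metadata layout DocumentService persists alongside the state.
restored = LlamaCache.from_bytes(
    name="demo",
    cache_bytes=state,
    metadata={
        "model": "QuantFactory/Dolphin3.0-Llama3.2-1B-GGUF",
        "model_file": "*Q4_K_S.gguf",
        "filters": {},
        "docs": [doc.model_dump_json()],
    },
)
print(restored.query("What does DataBridge support?").completion)
```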

core/cache/llama_cache_factory.py (new file)

@ -0,0 +1,15 @@
from core.cache.base_cache_factory import BaseCacheFactory
from core.cache.llama_cache import LlamaCache
from typing import Dict, Any
class LlamaCacheFactory(BaseCacheFactory):
def create_new_cache(
self, name: str, model: str, model_file: str, **kwargs: Dict[str, Any]
) -> LlamaCache:
return LlamaCache(name, model, model_file, **kwargs)
def load_cache_from_bytes(
self, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs: Dict[str, Any]
) -> LlamaCache:
return LlamaCache.from_bytes(name, cache_bytes, metadata, **kwargs)

View File

@ -1,22 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Dict
from pydantic import BaseModel
class CompletionResponse(BaseModel):
"""Response from completion generation"""
completion: str
usage: Dict[str, int]
class CompletionRequest(BaseModel):
"""Request for completion generation"""
query: str
context_chunks: List[str]
max_tokens: Optional[int] = 1000
temperature: Optional[float] = 0.7
from core.models.completion import CompletionRequest, CompletionResponse
class BaseCompletionModel(ABC):

View File

@ -72,3 +72,28 @@ class BaseDatabase(ABC):
Returns: True if user has required access, False otherwise
"""
pass
@abstractmethod
async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
"""Store metadata for a cache.
Args:
name: Name of the cache
metadata: Cache metadata including model info and storage location
Returns:
bool: Whether the operation was successful
"""
pass
@abstractmethod
async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]:
"""Get metadata for a cache.
Args:
name: Name of the cache
Returns:
Optional[Dict[str, Any]]: Cache metadata if found, None otherwise
"""
pass

View File

@ -26,6 +26,7 @@ class MongoDatabase(BaseDatabase):
self.client = AsyncIOMotorClient(uri)
self.db = self.client[db_name]
self.collection = self.db[collection_name]
self.caches = self.db["caches"] # Collection for cache metadata
async def initialize(self):
"""Initialize database indexes."""
@ -217,3 +218,45 @@ class MongoDatabase(BaseDatabase):
for key, value in filters.items():
filter_dict[f"metadata.{key}"] = value
return filter_dict
async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
"""Store metadata for a cache in MongoDB.
Args:
name: Name of the cache
metadata: Cache metadata including model info and storage location
Returns:
bool: Whether the operation was successful
"""
try:
# Add timestamp and ensure name is included
doc = {
"name": name,
"metadata": metadata,
"created_at": datetime.now(UTC),
"updated_at": datetime.now(UTC),
}
# Upsert the document
result = await self.caches.update_one({"name": name}, {"$set": doc}, upsert=True)
return bool(result.modified_count or result.upserted_id)
except Exception as e:
logger.error(f"Failed to store cache metadata: {e}")
return False
async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]:
"""Get metadata for a cache from MongoDB.
Args:
name: Name of the cache
Returns:
Optional[Dict[str, Any]]: Cache metadata if found, None otherwise
"""
try:
doc = await self.caches.find_one({"name": name})
return doc["metadata"] if doc else None
except Exception as e:
logger.error(f"Failed to get cache metadata: {e}")
return None

View File

@ -1,3 +1,4 @@
import json
from typing import List, Optional, Dict, Any
from datetime import datetime, UTC
import logging
@ -65,8 +66,24 @@ class PostgresDatabase(BaseDatabase):
try:
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Create caches table if it doesn't exist
await conn.execute(
text(
"""
CREATE TABLE IF NOT EXISTS caches (
name TEXT PRIMARY KEY,
metadata JSONB NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
)
"""
)
)
logger.info("PostgreSQL tables and indexes created successfully")
return True
except Exception as e:
logger.error(f"Error creating PostgreSQL tables and indexes: {str(e)}")
return False
@ -324,3 +341,54 @@ class PostgresDatabase(BaseDatabase):
filter_conditions.append(f"doc_metadata->>'{key}' = '{value}'")
return " AND ".join(filter_conditions)
async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
"""Store metadata for a cache in PostgreSQL.
Args:
name: Name of the cache
metadata: Cache metadata including model info and storage location
Returns:
bool: Whether the operation was successful
"""
try:
async with self.async_session() as session:
await session.execute(
text(
"""
INSERT INTO caches (name, metadata, updated_at)
VALUES (:name, :metadata, CURRENT_TIMESTAMP)
ON CONFLICT (name)
DO UPDATE SET
metadata = :metadata,
updated_at = CURRENT_TIMESTAMP
"""
),
{"name": name, "metadata": json.dumps(metadata)},
)
await session.commit()
return True
except Exception as e:
logger.error(f"Failed to store cache metadata: {e}")
return False
async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]:
"""Get metadata for a cache from PostgreSQL.
Args:
name: Name of the cache
Returns:
Optional[Dict[str, Any]]: Cache metadata if found, None otherwise
"""
try:
async with self.async_session() as session:
result = await session.execute(
text("SELECT metadata FROM caches WHERE name = :name"), {"name": name}
)
row = result.first()
return row[0] if row else None
except Exception as e:
logger.error(f"Failed to get cache metadata: {e}")
return None

core/models/completion.py (new file)

@ -0,0 +1,19 @@
from pydantic import BaseModel
from typing import Dict, List, Optional
class CompletionResponse(BaseModel):
"""Response from completion generation"""
completion: str
usage: Dict[str, int]
finish_reason: Optional[str] = None
class CompletionRequest(BaseModel):
"""Request for completion generation"""
query: str
context_chunks: List[str]
max_tokens: Optional[int] = 1000
temperature: Optional[float] = 0.7

View File

@ -40,6 +40,14 @@ class Document(BaseModel):
)
chunk_ids: List[str] = Field(default_factory=list)
def __hash__(self):
return hash(self.external_id)
def __eq__(self, other):
if not isinstance(other, Document):
return False
return self.external_id == other.external_id
class DocumentContent(BaseModel):
"""Represents either a URL or content string"""

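The new `__hash__`/`__eq__` pair makes `Document` hashable by `external_id`, which is what lets the `/cache/create` endpoint merge filter-matched and explicitly listed documents with set operations. A small illustration (the field values are placeholders):

```python
# Illustration: Document hashing/equality by external_id enables set-based
# de-duplication, as used when building a cache's document list.
from core.models.documents import Document


def make(doc_id: str) -> Document:
    return Document(
        external_id=doc_id,
        owner={"id": "u1", "name": "User"},
        content_type="text/plain",
        system_metadata={"content": f"content of {doc_id}"},
    )


filter_docs = {make("a"), make("b")}
additional_docs = {make("b"), make("c")}  # "b" appears in both sets

docs_to_add = list(filter_docs.union(additional_docs))
assert len(docs_to_add) == 3  # the duplicate collapses thanks to __hash__/__eq__
```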
View File

@ -21,6 +21,8 @@ from core.completion.base_completion import CompletionRequest, CompletionRespons
import logging
from core.reranker.base_reranker import BaseReranker
from core.config import get_settings
from core.cache.base_cache import BaseCache
from core.cache.base_cache_factory import BaseCacheFactory
logger = logging.getLogger(__name__)
@ -34,6 +36,7 @@ class DocumentService:
parser: BaseParser,
embedding_model: BaseEmbeddingModel,
completion_model: BaseCompletionModel,
cache_factory: BaseCacheFactory,
reranker: Optional[BaseReranker] = None,
):
self.db = database
@ -43,6 +46,11 @@ class DocumentService:
self.embedding_model = embedding_model
self.completion_model = completion_model
self.reranker = reranker
self.cache_factory = cache_factory
# Cache-related data structures
# Maps cache name to active cache object
self.active_caches: Dict[str, BaseCache] = {}
async def retrieve_chunks(
self,
@ -151,6 +159,7 @@ class DocumentService:
},
)
logger.info(f"Created text document record with ID {doc.external_id}")
doc.system_metadata["content"] = request.content
# 2. Parse content into chunks
chunks = await self.parser.split_text(request.content)
@ -196,6 +205,7 @@ class DocumentService:
},
additional_metadata=additional_metadata,
)
doc.system_metadata["content"] = "\n".join(chunk.content for chunk in chunks)
logger.info(f"Created file document record with ID {doc.external_id}")
storage_info = await self.storage.upload_from_base64(
@ -333,3 +343,90 @@ class DocumentService:
logger.info(f"Created {len(results)} document results")
return results
async def create_cache(
self,
name: str,
model: str,
gguf_file: str,
docs: List[Document | None],
filters: Optional[Dict[str, Any]] = None,
) -> Dict[str, str]:
"""Create a new cache with specified configuration.
Args:
name: Name of the cache to create
model: Name of the model to use
gguf_file: Name of the GGUF file to use
docs: Documents to include in the cache
filters: Optional metadata filters that were used to select those documents
"""
# Create cache metadata
metadata = {
"model": model,
"model_file": gguf_file,
"filters": filters,
"docs": [doc.model_dump_json() for doc in docs],
"storage_info": {
"bucket": "caches",
"key": f"{name}_state.pkl",
},
}
# Store metadata in database
success = await self.db.store_cache_metadata(name, metadata)
if not success:
logger.error(f"Failed to store cache metadata for cache {name}")
return {"success": False, "message": f"Failed to store cache metadata for cache {name}"}
# Create cache instance
cache = self.cache_factory.create_new_cache(
name=name, model=model, model_file=gguf_file, filters=filters, docs=docs
)
cache_bytes = cache.saveable_state
base64_cache_bytes = base64.b64encode(cache_bytes).decode()
bucket, key = await self.storage.upload_from_base64(
base64_cache_bytes,
key=metadata["storage_info"]["key"],
bucket=metadata["storage_info"]["bucket"],
)
return {
"success": True,
"message": f"Cache created successfully, state stored in bucket `{bucket}` with key `{key}`",
}
async def load_cache(self, name: str) -> bool:
"""Load a cache into memory.
Args:
name: Name of the cache to load
Returns:
bool: Whether the cache exists and was loaded successfully
"""
try:
# Get cache metadata from database
metadata = await self.db.get_cache_metadata(name)
if not metadata:
logger.error(f"No metadata found for cache {name}")
return False
# Get cache bytes from storage
cache_bytes = await self.storage.download_file(
metadata["storage_info"]["bucket"], "caches/" + metadata["storage_info"]["key"]
)
cache_bytes = cache_bytes.read()
cache = self.cache_factory.load_cache_from_bytes(
name=name, cache_bytes=cache_bytes, metadata=metadata
)
self.active_caches[name] = cache
return {"success": True, "message": "Cache loaded successfully"}
except Exception as e:
logger.error(f"Failed to load cache {name}: {e}")
# raise e
return {"success": False, "message": f"Failed to load cache {name}: {e}"}
def close(self):
"""Close all resources."""
# Close any active caches
self.active_caches.clear()
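Putting the service pieces together: `create_cache` stores the metadata and uploads the serialized KV state to the `caches` bucket, and `load_cache` later pulls it back and rebuilds the cache through the factory. A sketch under stated assumptions (`document_service` is wired up as in the API module above, `doc` is a previously ingested Document, and the model/GGUF values are placeholders):

```python
# Sketch: cache lifecycle at the service level.
async def cache_round_trip(document_service, doc) -> str:
    result = await document_service.create_cache(
        name="programming_cache",
        model="QuantFactory/Dolphin3.0-Llama3.2-1B-GGUF",  # placeholder model repo
        gguf_file="*Q4_K_S.gguf",                          # placeholder GGUF pattern
        docs=[doc],
        filters={"category": "programming"},
    )
    if not result["success"]:
        raise RuntimeError(result["message"])

    # Later (e.g. after a restart) load_cache pulls the pickled state back from the
    # "caches" bucket and rebuilds the cache via the factory.
    await document_service.load_cache("programming_cache")
    cache = document_service.active_caches["programming_cache"]
    return cache.query("Summarize the cached documents in one sentence.").completion
```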

View File

@ -7,7 +7,7 @@ class BaseStorage(ABC):
@abstractmethod
async def upload_from_base64(
self, content: str, key: str, content_type: Optional[str] = None
self, content: str, key: str, content_type: Optional[str] = None, bucket: str = ""
) -> Tuple[str, str]:
"""
Upload base64 encoded content.
@ -16,7 +16,7 @@ class BaseStorage(ABC):
content: Base64 encoded content
key: Storage key/path
content_type: Optional MIME type
bucket: Optional bucket/folder name
Returns:
Tuple[str, str]: (bucket/container name, storage key)
"""

View File

@ -19,16 +19,20 @@ class LocalStorage(BaseStorage):
return open(file_path, "rb")
async def upload_from_base64(
self, base64_content: str, key: str, content_type: Optional[str] = None
self, content: str, key: str, content_type: Optional[str] = None, bucket: str = ""
) -> Tuple[str, str]:
base64_content = content
"""Upload base64 encoded content to local storage."""
# Decode base64 content
file_content = base64.b64decode(base64_content)
key = f"{bucket}/{key}" if bucket else key
# Create file path
file_path = self.storage_path / key
# Write content to file
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.unlink(missing_ok=True)
with open(file_path, "wb") as f:
f.write(file_content)

View File

@ -37,6 +37,7 @@ class S3Storage(BaseStorage):
file: Union[str, bytes, BinaryIO],
key: str,
content_type: Optional[str] = None,
bucket: str = "",
) -> Tuple[str, str]:
"""Upload a file to S3."""
try:
@ -70,15 +71,18 @@ class S3Storage(BaseStorage):
raise
async def upload_from_base64(
self, content: str, key: str, content_type: Optional[str] = None
self, content: str, key: str, content_type: Optional[str] = None, bucket: str = ""
) -> Tuple[str, str]:
"""Upload base64 encoded content to S3."""
key = f"{bucket}/{key}" if bucket else key
try:
decoded_content = base64.b64decode(content)
extension = detect_file_type(content)
key = f"{key}{extension}"
return await self.upload_file(file=decoded_content, key=key, content_type=content_type)
return await self.upload_file(
file=decoded_content, key=key, content_type=content_type, bucket=bucket
)
except Exception as e:
logger.error(f"Error uploading base64 content to S3: {e}")

View File

@ -238,6 +238,8 @@ async def test_ingest_oversized_content(client: AsyncClient):
@pytest.mark.asyncio
async def test_auth_missing_header(client: AsyncClient):
"""Test authentication with missing auth header"""
if get_settings().dev_mode:
pytest.skip("Auth tests skipped in dev mode")
response = await client.post("/ingest/text")
assert response.status_code == 401
@ -245,6 +247,8 @@ async def test_auth_missing_header(client: AsyncClient):
@pytest.mark.asyncio
async def test_auth_invalid_token(client: AsyncClient):
"""Test authentication with invalid token"""
if get_settings().dev_mode:
pytest.skip("Auth tests skipped in dev mode")
headers = {"Authorization": "Bearer invalid_token"}
response = await client.post("/ingest/file", headers=headers)
assert response.status_code == 401
@ -253,6 +257,8 @@ async def test_auth_invalid_token(client: AsyncClient):
@pytest.mark.asyncio
async def test_auth_expired_token(client: AsyncClient):
"""Test authentication with expired token"""
if get_settings().dev_mode:
pytest.skip("Auth tests skipped in dev mode")
headers = create_auth_header(expired=True)
response = await client.post("/ingest/text", headers=headers)
assert response.status_code == 401
@ -261,6 +267,8 @@ async def test_auth_expired_token(client: AsyncClient):
@pytest.mark.asyncio
async def test_auth_insufficient_permissions(client: AsyncClient):
"""Test authentication with insufficient permissions"""
if get_settings().dev_mode:
pytest.skip("Auth tests skipped in dev mode")
headers = create_auth_header(permissions=["read"])
response = await client.post(
"/ingest/text",
@ -335,7 +343,7 @@ async def test_retrieve_chunks(client: AsyncClient):
assert response.status_code == 200
results = list(response.json())
assert len(results) > 0
assert results[0]["score"] > 0.5
assert (not get_settings().USE_RERANKING) or results[0]["score"] > 0.5
assert any(upload_string == result["content"] for result in results)

View File

@ -0,0 +1,184 @@
import pytest
from core.models.documents import Document
from core.cache.llama_cache import LlamaCache
from core.models.completion import CompletionResponse
# TEST_MODEL = "QuantFactory/Llama3.2-3B-Enigma-GGUF"
TEST_MODEL = "QuantFactory/Dolphin3.0-Llama3.2-1B-GGUF"
# TEST_MODEL = "unsloth/DeepSeek-R1-Distill-Qwen-1.5B-GGUF"
TEST_GGUF_FILE = "*Q4_K_S.gguf"
# TEST_GGUF_FILE = "*Q4_K_M.gguf"
def get_test_document():
"""Load the example.txt file as a test document."""
# test_file = Path(__file__).parent.parent / "assets" / "example.txt"
# with open(test_file, "r") as f:
# content = f.read()
content = """
In darkest hours of coding's fierce domain,
Where bugs lurk deep in shadows, hard to find,
Each error message brings fresh waves of pain,
As stack traces drive madness through the mind.
Through endless loops of print statements we wade,
Debug flags raised like torches in the night,
While segfaults mock each careful plan we made,
And race conditions laugh at our plight.
O rubber duck, my silent debugging friend,
Your plastic gaze holds wisdom yet untold,
As line by line we trace paths without end,
Seeking that elusive bug of gold.
Yet hope remains while coffee still flows strong,
Through debugging hell, we'll debug right from wrong.
""".strip()
return Document(
external_id="alice_ch1",
owner={"id": "test_user", "name": "Test User"},
content_type="text/plain",
system_metadata={
"content": content,
"title": "Alice in Wonderland - Chapter 1",
"source": "test_document",
},
)
@pytest.fixture
def llama_cache():
"""Create a LlamaCache instance with the test document."""
doc = get_test_document()
cache = LlamaCache(
name="test_cache", model=TEST_MODEL, gguf_file=TEST_GGUF_FILE, filters={}, docs=[doc]
)
return cache
def test_basic_rag_capabilities(llama_cache):
"""Test that the cache can answer basic questions about the document content."""
# Test question about whether ingestion is actually happening
response = llama_cache.query(
"Summarize the content of the document. Please respond in a single sentence. Summary: "
)
assert isinstance(response, CompletionResponse)
# assert "alice" in response.completion.lower()
# # Test question about a specific detail
# response = llama_cache.query(
# "What did Alice see the White Rabbit do with its watch? Please respond in a single sentence. Answer: "
# )
# assert isinstance(response, CompletionResponse)
# # assert "waistcoat-pocket" in response.completion.lower() or "looked at it" in response.completion.lower()
# # Test question about character description
# response = llama_cache.query(
# "How did Alice's size change during the story? Please respond in a single sentence. Answer: "
# )
# assert isinstance(response, CompletionResponse)
# # assert any(phrase in response.completion.lower() for phrase in ["grew larger", "grew smaller", "nine feet", "telescope"])
# # Test question about plot elements
# response = llama_cache.query(
# "What was written on the bottle Alice found? Please respond in a single sentence. Answer: "
# )
# assert isinstance(response, CompletionResponse)
# # assert "drink me" in response.completion.lower()
# def test_cache_memory_persistence(llama_cache):
# """Test that the cache maintains context across multiple queries."""
# # First query to establish context
# llama_cache.query(
# "What was Alice doing before she saw the White Rabbit? Please respond in a single sentence. Answer: "
# )
# # Follow-up query that requires remembering previous context
# response = llama_cache.query(
# "What book was her sister reading? Please respond in a single sentence. Answer: "
# )
# assert isinstance(response, CompletionResponse)
# # assert "no pictures" in response.completion.lower() or "conversations" in response.completion.lower()
def test_adding_new_documents(llama_cache):
"""Test that the cache can incorporate new documents into its knowledge."""
# Create a new document with additional content
new_doc = Document(
external_id="alice_ch2",
owner={"id": "test_user", "name": "Test User"},
content_type="text/plain",
system_metadata={
"content": "Alice found herself in a pool of tears. She met a Mouse swimming in the pool.",
"title": "Alice in Wonderland - Additional Content",
"source": "test_document",
},
)
# Add the new document
success = llama_cache.add_docs([new_doc])
assert success
# Query about the new content
response = llama_cache.query(
"What did Alice find in the pool of tears? Please respond in a single sentence. Answer: "
)
assert isinstance(response, CompletionResponse)
assert "mouse" in response.completion.lower()
def test_cache_state_persistence():
"""Test that the cache state can be saved and loaded."""
# Create initial cache
doc = get_test_document()
original_cache = LlamaCache(
name="test_cache", model=TEST_MODEL, gguf_file=TEST_GGUF_FILE, filters={}, docs=[doc]
)
# Get the state
state_bytes = original_cache.saveable_state
# Save state bytes to temporary file
import tempfile
import os
import pickle
with tempfile.TemporaryDirectory() as temp_dir:
cache_file = os.path.join(temp_dir, "cache.pkl")
# Save to file
with open(cache_file, "wb") as f:
pickle.dump(state_bytes, f)
# Load from file
with open(cache_file, "rb") as f:
loaded_state_bytes = pickle.load(f)
# # Verify state bytes match
# assert state_bytes == loaded_state_bytes
# state_bytes = loaded_state_bytes # Use loaded bytes for rest of test
# Create new cache from state
loaded_cache = LlamaCache.from_bytes(
name="test_cache",
cache_bytes=loaded_state_bytes,
metadata={
"model": TEST_MODEL,
"model_file": TEST_GGUF_FILE,
"filters": {},
"docs": [doc.model_dump_json()],
},
)
# Verify the loaded cache works
response = loaded_cache.query(
"Summarize the content of the document. Please respond in a single sentence. Summary: "
)
assert isinstance(response, CompletionResponse)
assert "coding" in response.completion.lower() or "debug" in response.completion.lower()
# assert "bottle" in response.completion.lower() and "drink me" in response.completion.lower()

View File

@ -37,6 +37,8 @@ def sample_chunks() -> List[DocumentChunk]:
def reranker():
"""Fixture to create and reuse a flag reranker instance"""
settings = get_settings()
if not settings.USE_RERANKING:
pytest.skip("Reranker is disabled in settings")
return FlagReranker(
model_name=settings.RERANKER_MODEL,
device=settings.RERANKER_DEVICE,
@ -49,6 +51,8 @@ def reranker():
@pytest.mark.asyncio
async def test_reranker_relevance(reranker, sample_chunks):
"""Test that reranker improves relevance for programming-related query"""
if not get_settings().USE_RERANKING:
pytest.skip("Reranker is disabled in settings")
print("\n=== Testing Reranker Relevance ===")
query = "What is Python programming language?"
@ -70,6 +74,8 @@ async def test_reranker_relevance(reranker, sample_chunks):
@pytest.mark.asyncio
async def test_reranker_score_distribution(reranker, sample_chunks):
"""Test that reranker produces reasonable score distribution"""
if not get_settings().USE_RERANKING:
pytest.skip("Reranker is disabled in settings")
print("\n=== Testing Score Distribution ===")
query = "Tell me about machine learning and data science"
@ -95,6 +101,8 @@ async def test_reranker_score_distribution(reranker, sample_chunks):
@pytest.mark.asyncio
async def test_reranker_batch_scoring(reranker):
"""Test that reranker can handle multiple queries/passages efficiently"""
if not get_settings().USE_RERANKING:
pytest.skip("Reranker is disabled in settings")
print("\n=== Testing Batch Scoring ===")
texts = [
"Python is a programming language",
@ -118,6 +126,8 @@ async def test_reranker_batch_scoring(reranker):
@pytest.mark.asyncio
async def test_reranker_empty_and_edge_cases(reranker, sample_chunks):
"""Test reranker behavior with empty or edge case inputs"""
if not get_settings().USE_RERANKING:
pytest.skip("Reranker is disabled in settings")
print("\n=== Testing Edge Cases ===")
# Empty chunks list
@ -147,6 +157,8 @@ async def test_reranker_empty_and_edge_cases(reranker, sample_chunks):
@pytest.mark.asyncio
async def test_reranker_consistency(reranker, sample_chunks):
"""Test that reranker produces consistent results for same input"""
if not get_settings().USE_RERANKING:
pytest.skip("Reranker is disabled in settings")
print("\n=== Testing Consistency ===")
query = "What is Python programming?"

View File

@ -9,6 +9,8 @@ WORKDIR /app
# Install build dependencies
RUN apt-get update && apt-get install -y \
gcc \
g++ \
cmake \
python3-dev \
&& rm -rf /var/lib/apt/lists/*
@ -16,6 +18,9 @@ RUN apt-get update && apt-get install -y \
COPY requirements.txt .
RUN pip install --no-cache-dir --user -r requirements.txt
# Download NLTK data
RUN python -m nltk.downloader -d /usr/local/share/nltk_data punkt averaged_perceptron_tagger
# Production stage
FROM python:3.12.5-slim
@ -31,11 +36,17 @@ RUN apt-get update && apt-get install -y \
tesseract-ocr \
postgresql-client \
poppler-utils \
gcc \
g++ \
cmake \
python3-dev \
&& rm -rf /var/lib/apt/lists/*
# Copy installed packages from builder
COPY --from=builder /root/.local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
COPY --from=builder /root/.local/bin /usr/local/bin
# Copy NLTK data from builder
COPY --from=builder /usr/local/share/nltk_data /usr/local/share/nltk_data
# Create necessary directories
RUN mkdir -p storage logs

View File

@ -31,7 +31,15 @@ CREATE TABLE IF NOT EXISTS vector_embeddings (
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
-- Create caches table
CREATE TABLE IF NOT EXISTS caches (
name TEXT PRIMARY KEY,
metadata JSON NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
-- Create vector index
CREATE INDEX IF NOT EXISTS vector_idx
CREATE INDEX IF NOT EXISTS vector_idx
ON vector_embeddings USING ivfflat (embedding vector_l2_ops)
WITH (lists = 100);
WITH (lists = 100);

View File

@ -241,12 +241,23 @@ def setup_postgres():
# Import and create all tables
from core.database.postgres_database import Base
from core.vector_store.pgvector_store import Base as VectorBase
# Create regular tables first
await conn.run_sync(Base.metadata.create_all)
LOGGER.info("Created base PostgreSQL tables")
# Create caches table
create_caches_table = """
CREATE TABLE IF NOT EXISTS caches (
name TEXT PRIMARY KEY,
metadata JSON NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
)
"""
await conn.execute(text(create_caches_table))
LOGGER.info("Created caches table")
# Get vector dimensions from config
dimensions = CONFIG["embedding"]["dimensions"]
@ -278,8 +289,8 @@ def setup_postgres():
LOGGER.info("Created vector_embeddings table with vector column")
# Create the vector index
index_sql = f"""
CREATE INDEX vector_idx
index_sql = """
CREATE INDEX vector_idx
ON vector_embeddings USING ivfflat (embedding vector_l2_ops)
WITH (lists = 100);
"""

View File

@ -293,3 +293,4 @@ zlib-state==0.1.9
pgvector==0.2.5
psycopg[binary]==3.1.18
psycopg-binary==3.1.18
llama-cpp-python==0.3.5

View File

@ -13,4 +13,4 @@ __all__ = [
"IngestTextRequest",
]
__version__ = "0.1.8"
__version__ = "0.2.0"

View File

@ -16,6 +16,30 @@ from .models import (
)
class AsyncCache:
def __init__(self, db: "AsyncDataBridge", name: str):
self._db = db
self._name = name
async def update(self) -> bool:
response = await self._db._request("POST", f"cache/{self._name}/update")
return response.get("success", False)
async def add_docs(self, docs: List[str]) -> bool:
response = await self._db._request("POST", f"cache/{self._name}/add_docs", {"docs": docs})
return response.get("success", False)
async def query(
self, query: str, max_tokens: Optional[int] = None, temperature: Optional[float] = None
) -> str:
response = await self._db._request(
"POST",
f"cache/{self._name}/query",
{"query": query, "max_tokens": max_tokens, "temperature": temperature},
)
return CompletionResponse(**response)
class AsyncDataBridge:
"""
DataBridge client for document operations.
@ -345,6 +369,72 @@ class AsyncDataBridge:
response = await self._request("GET", f"documents/{document_id}")
return Document(**response)
async def create_cache(
self,
name: str,
model: str,
gguf_file: str,
filters: Optional[Dict[str, Any]] = None,
docs: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""
Create a new cache with specified configuration.
Args:
name: Name of the cache to create
model: Name of the model to use (e.g. "llama2")
gguf_file: Name of the GGUF file to use for the model
filters: Optional metadata filters to determine which documents to include. These filters will be applied in addition to any specific docs provided.
docs: Optional list of specific document IDs to include. These docs will be included in addition to any documents matching the filters.
Returns:
Dict[str, Any]: Created cache configuration
Example:
```python
# This will include both:
# 1. Any documents with category="programming"
# 2. The specific documents "doc1" and "doc2" (regardless of their category)
cache = await db.create_cache(
name="programming_cache",
model="llama2",
gguf_file="llama-2-7b-chat.Q4_K_M.gguf",
filters={"category": "programming"},
docs=["doc1", "doc2"]
)
```
"""
request = {
"name": name,
"model": model,
"gguf_file": gguf_file,
"filters": filters,
"docs": docs,
}
response = await self._request("POST", "cache/create", request)
return response
async def get_cache(self, name: str) -> AsyncCache:
"""
Get a cache by name.
Args:
name: Name of the cache to retrieve
Returns:
cache: An AsyncCache object used to interact with the named cache.
Example:
```python
cache = await db.get_cache("programming_cache")
```
"""
response = await self._request("GET", f"cache/{name}")
if response.get("exists", False):
return AsyncCache(self, name)
raise ValueError(f"Cache '{name}' not found")
async def close(self):
"""Close the HTTP client"""
await self._client.aclose()
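On the client side, the new `AsyncCache` wrapper hides the endpoint details. A hedged end-to-end sketch (the package import path, constructor argument, and connection URI format are assumptions; the model and GGUF names are the placeholder values from the docstrings above):

```python
# Sketch: client-side cache flow with the async SDK.
import asyncio

from databridge import AsyncDataBridge  # assumed package/module name


async def main() -> None:
    db = AsyncDataBridge("databridge://owner:token@localhost:8000")  # placeholder URI

    await db.create_cache(
        name="programming_cache",
        model="llama2",
        gguf_file="llama-2-7b-chat.Q4_K_M.gguf",
        filters={"category": "programming"},
    )

    cache = await db.get_cache("programming_cache")  # raises ValueError if missing
    await cache.add_docs(["doc1", "doc2"])           # add specific document IDs
    await cache.update()                             # pull in new filter matches

    response = await cache.query("What is cache-augmented generation?")
    print(response.completion)

    await db.close()


asyncio.run(main())
```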

View File

@ -16,6 +16,31 @@ from .models import (
)
class Cache:
def __init__(self, db: "DataBridge", name: str):
self._db = db
self._name = name
def update(self) -> bool:
response = self._db._request("POST", f"cache/{self._name}/update")
return response.get("success", False)
def add_docs(self, docs: List[str]) -> bool:
response = self._db._request("POST", f"cache/{self._name}/add_docs", {"docs": docs})
return response.get("success", False)
def query(
self, query: str, max_tokens: Optional[int] = None, temperature: Optional[float] = None
) -> CompletionResponse:
response = self._db._request(
"POST",
f"cache/{self._name}/query",
params={"query": query, "max_tokens": max_tokens, "temperature": temperature},
data="",
)
return CompletionResponse(**response)
class DataBridge:
"""
DataBridge client for document operations.
@ -71,6 +96,7 @@ class DataBridge:
endpoint: str,
data: Optional[Dict[str, Any]] = None,
files: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Make HTTP request"""
headers = {}
@ -88,6 +114,7 @@ class DataBridge:
data=data if files else None,
headers=headers,
timeout=self._timeout,
params=params,
)
response.raise_for_status()
return response.json()
@ -334,6 +361,70 @@ class DataBridge:
response = self._request("GET", f"documents/{document_id}")
return Document(**response)
def create_cache(
self,
name: str,
model: str,
gguf_file: str,
filters: Optional[Dict[str, Any]] = None,
docs: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""
Create a new cache with specified configuration.
Args:
name: Name of the cache to create
model: Name of the model to use (e.g. "llama2")
gguf_file: Name of the GGUF file to use for the model
filters: Optional metadata filters to determine which documents to include. These filters will be applied in addition to any specific docs provided.
docs: Optional list of specific document IDs to include. These docs will be included in addition to any documents matching the filters.
Returns:
Dict[str, Any]: Created cache configuration
Example:
```python
# This will include both:
# 1. Any documents with category="programming"
# 2. The specific documents "doc1" and "doc2" (regardless of their category)
cache = db.create_cache(
name="programming_cache",
model="llama2",
gguf_file="llama-2-7b-chat.Q4_K_M.gguf",
filters={"category": "programming"},
docs=["doc1", "doc2"]
)
```
"""
# Build query parameters for name, model and gguf_file
params = {"name": name, "model": model, "gguf_file": gguf_file}
# Build request body for filters and docs
request = {"filters": filters, "docs": docs}
response = self._request("POST", "cache/create", request, params=params)
return response
def get_cache(self, name: str) -> Cache:
"""
Get a cache by name.
Args:
name: Name of the cache to retrieve
Returns:
cache: A Cache object used to interact with the named cache.
Example:
```python
cache = db.get_cache("programming_cache")
```
"""
response = self._request("GET", f"cache/{name}")
if response.get("exists", False):
return Cache(self, name)
raise ValueError(f"Cache '{name}' not found")
def close(self):
"""Close the HTTP session"""
self._session.close()

View File

@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "databridge-client"
version = "0.1.8"
version = "0.2.0"
authors = [
{ name = "DataBridge", email = "databridgesuperuser@gmail.com" },
]

View File

@ -118,11 +118,59 @@ class DB:
doc = self._client.get_document(document_id)
return doc.model_dump()
def create_cache(
self,
name: str,
model: str,
gguf_file: str,
filters: dict = None,
docs: list = None,
) -> dict:
"""Create a new cache with specified configuration"""
response = self._client.create_cache(
name=name,
model=model,
gguf_file=gguf_file,
filters=filters or {},
docs=docs,
)
return response
def get_cache(self, name: str) -> "Cache":
"""Get a cache by name"""
return self._client.get_cache(name)
def close(self):
"""Close the client connection"""
self._client.close()
class Cache:
def __init__(self, db: DB, name: str):
self._db = db
self._name = name
self._client_cache = db._client.get_cache(name)
def update(self) -> bool:
"""Update the cache"""
return self._client_cache.update()
def add_docs(self, docs: list) -> bool:
"""Add documents to the cache"""
return self._client_cache.add_docs(docs)
def query(
self, query: str, max_tokens: int = None, temperature: float = None
) -> dict:
"""Query the cache"""
response = self._client_cache.query(
query=query,
max_tokens=max_tokens,
temperature=temperature,
)
return response.model_dump()
if __name__ == "__main__":
uri = sys.argv[1] if len(sys.argv) > 1 else None
db = DB(uri)
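Finally, the shell wrapper above exposes the same workflow interactively. A sketch of a session (how the shell is launched and the connection URI are assumptions; the model and GGUF names are the placeholder values used in the SDK docstrings):

```python
# Sketch of an interactive session with the shell wrapper, assuming `db` is the
# DB instance created above and the documents "doc1"/"doc2" already exist.
db.create_cache(
    name="programming_cache",
    model="llama2",
    gguf_file="llama-2-7b-chat.Q4_K_M.gguf",
    filters={"category": "programming"},
)

cache = Cache(db, "programming_cache")  # wraps db._client.get_cache(name)
cache.add_docs(["doc1", "doc2"])
print(cache.query("What is cache-augmented generation?", max_tokens=100))
```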