mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
Add support for cache-augmented-generation (#30)
This commit is contained in:
parent
4f0cf62008
commit
d124e6aa0d
135
core/api.py
135
core/api.py
@ -1,5 +1,7 @@
|
||||
import json
|
||||
from datetime import datetime, UTC, timedelta
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional
|
||||
from fastapi import FastAPI, Form, HTTPException, Depends, Header, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@ -30,6 +32,7 @@ from core.embedding.openai_embedding_model import OpenAIEmbeddingModel
|
||||
from core.completion.ollama_completion import OllamaCompletionModel
|
||||
from core.parser.contextual_parser import ContextualParser
|
||||
from core.reranker.flag_reranker import FlagReranker
|
||||
from core.cache.llama_cache_factory import LlamaCacheFactory
|
||||
import tomli
|
||||
|
||||
# Initialize FastAPI app
|
||||
@ -214,6 +217,9 @@ if settings.USE_RERANKING:
|
||||
case _:
|
||||
raise ValueError(f"Unsupported reranker provider: {settings.RERANKER_PROVIDER}")
|
||||
|
||||
# Initialize cache factory
|
||||
cache_factory = LlamaCacheFactory(Path(settings.STORAGE_PATH))
|
||||
|
||||
# Initialize document service with configured components
|
||||
document_service = DocumentService(
|
||||
storage=storage,
|
||||
@ -223,6 +229,7 @@ document_service = DocumentService(
|
||||
completion_model=completion_model,
|
||||
parser=parser,
|
||||
reranker=reranker,
|
||||
cache_factory=cache_factory,
|
||||
)
|
||||
|
||||
|
||||
@ -462,6 +469,134 @@ async def get_recent_usage(
|
||||
]
|
||||
|
||||
|
||||
# Cache endpoints
|
||||
@app.post("/cache/create")
|
||||
async def create_cache(
|
||||
name: str,
|
||||
model: str,
|
||||
gguf_file: str,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
docs: Optional[List[str]] = None,
|
||||
auth: AuthContext = Depends(verify_token),
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a new cache with specified configuration."""
|
||||
try:
|
||||
async with telemetry.track_operation(
|
||||
operation_type="create_cache",
|
||||
user_id=auth.entity_id,
|
||||
metadata={
|
||||
"name": name,
|
||||
"model": model,
|
||||
"gguf_file": gguf_file,
|
||||
"filters": filters,
|
||||
"docs": docs,
|
||||
},
|
||||
):
|
||||
filter_docs = set(await document_service.db.get_documents(auth, filters=filters))
|
||||
additional_docs = (
|
||||
{
|
||||
await document_service.db.get_document(document_id=doc_id, auth=auth)
|
||||
for doc_id in docs
|
||||
}
|
||||
if docs
|
||||
else set()
|
||||
)
|
||||
docs_to_add = list(filter_docs.union(additional_docs))
|
||||
if not docs_to_add:
|
||||
raise HTTPException(status_code=400, detail="No documents to add to cache")
|
||||
response = await document_service.create_cache(
|
||||
name, model, gguf_file, docs_to_add, filters
|
||||
)
|
||||
return response
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/cache/{name}")
|
||||
async def get_cache(name: str, auth: AuthContext = Depends(verify_token)) -> Dict[str, Any]:
|
||||
"""Get cache configuration by name."""
|
||||
try:
|
||||
async with telemetry.track_operation(
|
||||
operation_type="get_cache",
|
||||
user_id=auth.entity_id,
|
||||
metadata={"name": name},
|
||||
):
|
||||
exists = await document_service.load_cache(name)
|
||||
return {"exists": exists}
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/cache/{name}/update")
|
||||
async def update_cache(name: str, auth: AuthContext = Depends(verify_token)) -> Dict[str, bool]:
|
||||
"""Update cache with new documents matching its filter."""
|
||||
try:
|
||||
async with telemetry.track_operation(
|
||||
operation_type="update_cache",
|
||||
user_id=auth.entity_id,
|
||||
metadata={"name": name},
|
||||
):
|
||||
if name not in document_service.active_caches:
|
||||
exists = await document_service.load_cache(name)
|
||||
if not exists:
|
||||
raise HTTPException(status_code=404, detail=f"Cache '{name}' not found")
|
||||
cache = document_service.active_caches[name]
|
||||
docs = await document_service.db.get_documents(auth, filters=cache.filters)
|
||||
docs_to_add = [doc for doc in docs if doc.id not in cache.docs]
|
||||
return cache.add_docs(docs_to_add)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/cache/{name}/add_docs")
|
||||
async def add_docs_to_cache(
|
||||
name: str, docs: List[str], auth: AuthContext = Depends(verify_token)
|
||||
) -> Dict[str, bool]:
|
||||
"""Add specific documents to the cache."""
|
||||
try:
|
||||
async with telemetry.track_operation(
|
||||
operation_type="add_docs_to_cache",
|
||||
user_id=auth.entity_id,
|
||||
metadata={"name": name, "docs": docs},
|
||||
):
|
||||
cache = document_service.active_caches[name]
|
||||
docs_to_add = [
|
||||
await document_service.db.get_document(doc_id, auth)
|
||||
for doc_id in docs
|
||||
if doc_id not in cache.docs
|
||||
]
|
||||
return cache.add_docs(docs_to_add)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/cache/{name}/query")
|
||||
async def query_cache(
|
||||
name: str,
|
||||
query: str,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
auth: AuthContext = Depends(verify_token),
|
||||
) -> CompletionResponse:
|
||||
"""Query the cache with a prompt."""
|
||||
try:
|
||||
async with telemetry.track_operation(
|
||||
operation_type="query_cache",
|
||||
user_id=auth.entity_id,
|
||||
metadata={
|
||||
"name": name,
|
||||
"query": query,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
},
|
||||
):
|
||||
cache = document_service.active_caches[name]
|
||||
print(f"Cache state: {cache.state.n_tokens}", file=sys.stderr)
|
||||
return cache.query(query) # , max_tokens, temperature)
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/local/generate_uri", include_in_schema=True)
|
||||
async def generate_local_uri(
|
||||
name: str = Form("admin"),
|
||||
|
68
core/cache/base_cache.py
vendored
Normal file
68
core/cache/base_cache.py
vendored
Normal file
@ -0,0 +1,68 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
from core.models.completion import CompletionResponse
|
||||
from core.models.documents import Document
|
||||
|
||||
|
||||
class BaseCache(ABC):
|
||||
"""Base class for cache implementations.
|
||||
|
||||
This class defines the interface for cache implementations that support
|
||||
document ingestion and cache-augmented querying.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, name: str, model: str, gguf_file: str, filters: Dict[str, Any], docs: List[Document]
|
||||
):
|
||||
"""Initialize the cache with the given parameters.
|
||||
|
||||
Args:
|
||||
name: Name of the cache instance
|
||||
model: Model identifier
|
||||
gguf_file: Path to the GGUF model file
|
||||
filters: Filters used to create the cache context
|
||||
docs: Initial documents to ingest into the cache
|
||||
"""
|
||||
self.name = name
|
||||
self.filters = filters
|
||||
self.docs = [] # List of document IDs that have been ingested
|
||||
self._initialize(model, gguf_file, docs)
|
||||
|
||||
@abstractmethod
|
||||
def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None:
|
||||
"""Internal initialization method to be implemented by subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def add_docs(self, docs: List[Document]) -> bool:
|
||||
"""Add documents to the cache.
|
||||
|
||||
Args:
|
||||
docs: List of documents to add to the cache
|
||||
|
||||
Returns:
|
||||
bool: True if documents were successfully added
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def query(self, query: str) -> CompletionResponse:
|
||||
"""Query the cache for relevant documents and generate a response.
|
||||
|
||||
Args:
|
||||
query: Query string to search for relevant documents
|
||||
|
||||
Returns:
|
||||
CompletionResponse: Generated response based on cached context
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def saveable_state(self) -> bytes:
|
||||
"""Get the saveable state of the cache as bytes.
|
||||
|
||||
Returns:
|
||||
bytes: Serialized state that can be used to restore the cache
|
||||
"""
|
||||
pass
|
64
core/cache/base_cache_factory.py
vendored
Normal file
64
core/cache/base_cache_factory.py
vendored
Normal file
@ -0,0 +1,64 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class BaseCacheFactory(ABC):
|
||||
"""Abstract base factory for creating and loading caches."""
|
||||
|
||||
def __init__(self, storage_path: Path):
|
||||
"""Initialize the cache factory.
|
||||
|
||||
Args:
|
||||
storage_path: Base path for storing cache files
|
||||
"""
|
||||
self.storage_path = storage_path
|
||||
self.storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@abstractmethod
|
||||
def create_new_cache(
|
||||
self, name: str, model: str, model_file: str, **kwargs: Dict[str, Any]
|
||||
) -> BaseCache:
|
||||
"""Create a new cache instance.
|
||||
|
||||
Args:
|
||||
name: Name of the cache
|
||||
model: Name/type of the model to use
|
||||
model_file: Path or identifier for the model file
|
||||
**kwargs: Additional arguments for cache creation
|
||||
|
||||
Returns:
|
||||
BaseCache: The created cache instance
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_cache_from_bytes(
|
||||
self, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs: Dict[str, Any]
|
||||
) -> BaseCache:
|
||||
"""Load a cache from its serialized bytes.
|
||||
|
||||
Args:
|
||||
name: Name of the cache
|
||||
cache_bytes: Serialized cache data
|
||||
metadata: Cache metadata including model info
|
||||
**kwargs: Additional arguments for cache loading
|
||||
|
||||
Returns:
|
||||
BaseCache: The loaded cache instance
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_cache_path(self, name: str) -> Path:
|
||||
"""Get the storage path for a cache.
|
||||
|
||||
Args:
|
||||
name: Name of the cache
|
||||
|
||||
Returns:
|
||||
Path: Directory path for the cache
|
||||
"""
|
||||
path = self.storage_path / name
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
285
core/cache/hf_cache.py
vendored
Normal file
285
core/cache/hf_cache.py
vendored
Normal file
@ -0,0 +1,285 @@
|
||||
# hugging face cache implementation.
|
||||
|
||||
from core.cache.base_cache import BaseCache
|
||||
from typing import List, Optional, Union
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from core.models.completion import CompletionRequest, CompletionResponse
|
||||
|
||||
|
||||
class HuggingFaceCache(BaseCache):
|
||||
"""Hugging Face Cache implementation for cache-augmented generation"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_path: Path,
|
||||
model_name: str = "distilgpt2",
|
||||
device: str = "cpu",
|
||||
default_max_new_tokens: int = 100,
|
||||
use_fp16: bool = False,
|
||||
):
|
||||
"""Initialize the HuggingFace cache.
|
||||
|
||||
Args:
|
||||
cache_path: Path to store cache files
|
||||
model_name: Name of the HuggingFace model to use
|
||||
device: Device to run the model on (e.g. "cpu", "cuda", "mps")
|
||||
default_max_new_tokens: Default maximum number of new tokens to generate
|
||||
use_fp16: Whether to use FP16 precision
|
||||
"""
|
||||
super().__init__()
|
||||
self.cache_path = cache_path
|
||||
self.model_name = model_name
|
||||
self.device = device
|
||||
self.default_max_new_tokens = default_max_new_tokens
|
||||
self.use_fp16 = use_fp16
|
||||
|
||||
# Initialize tokenizer and model
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# Configure model loading based on device
|
||||
model_kwargs = {"low_cpu_mem_usage": True}
|
||||
|
||||
if device == "cpu":
|
||||
# For CPU, use standard loading
|
||||
model_kwargs.update({"torch_dtype": torch.float32})
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs).to(device)
|
||||
else:
|
||||
# For GPU/MPS, use automatic device mapping and optional FP16
|
||||
model_kwargs.update(
|
||||
{"device_map": "auto", "torch_dtype": torch.float16 if use_fp16 else torch.float32}
|
||||
)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
|
||||
|
||||
self.kv_cache = None
|
||||
self.origin_len = None
|
||||
|
||||
def get_kv_cache(self, prompt: str) -> DynamicCache:
|
||||
"""Build KV cache from prompt"""
|
||||
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
|
||||
cache = DynamicCache()
|
||||
|
||||
with torch.no_grad():
|
||||
_ = self.model(input_ids=input_ids, past_key_values=cache, use_cache=True)
|
||||
return cache
|
||||
|
||||
def clean_up_cache(self, cache: DynamicCache, origin_len: int):
|
||||
"""Clean up cache by removing appended tokens"""
|
||||
for i in range(len(cache.key_cache)):
|
||||
cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
|
||||
cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
|
||||
|
||||
def generate(
|
||||
self, input_ids: torch.Tensor, past_key_values, max_new_tokens: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""Generate text using the model and cache"""
|
||||
device = next(self.model.parameters()).device
|
||||
origin_len = input_ids.shape[-1]
|
||||
input_ids = input_ids.to(device)
|
||||
output_ids = input_ids.clone()
|
||||
next_token = input_ids
|
||||
|
||||
with torch.no_grad():
|
||||
for _ in range(max_new_tokens or self.default_max_new_tokens):
|
||||
out = self.model(
|
||||
input_ids=next_token, past_key_values=past_key_values, use_cache=True
|
||||
)
|
||||
logits = out.logits[:, -1, :]
|
||||
token = torch.argmax(logits, dim=-1, keepdim=True)
|
||||
output_ids = torch.cat([output_ids, token], dim=-1)
|
||||
past_key_values = out.past_key_values
|
||||
next_token = token.to(device)
|
||||
|
||||
if (
|
||||
self.model.config.eos_token_id is not None
|
||||
and token.item() == self.model.config.eos_token_id
|
||||
):
|
||||
break
|
||||
|
||||
return output_ids[:, origin_len:]
|
||||
|
||||
async def ingest(self, docs: List[str]) -> bool:
|
||||
"""Ingest documents into cache"""
|
||||
try:
|
||||
# Create system prompt with documents
|
||||
system_prompt = f"""
|
||||
<|system|>
|
||||
You are an assistant who provides concise factual answers.
|
||||
<|user|>
|
||||
Context:
|
||||
{' '.join(docs)}
|
||||
Question:
|
||||
""".strip()
|
||||
|
||||
# Build the cache
|
||||
input_ids = self.tokenizer(system_prompt, return_tensors="pt").input_ids.to(self.device)
|
||||
self.kv_cache = DynamicCache()
|
||||
|
||||
with torch.no_grad():
|
||||
# First run to get the cache shape
|
||||
outputs = self.model(input_ids=input_ids, use_cache=True)
|
||||
# Initialize cache with empty tensors of the right shape
|
||||
n_layers = len(outputs.past_key_values)
|
||||
batch_size = input_ids.shape[0]
|
||||
|
||||
# Handle different model architectures
|
||||
|
||||
if hasattr(self.model.config, "num_key_value_heads"):
|
||||
# Models with grouped query attention (GQA) like Llama
|
||||
n_kv_heads = self.model.config.num_key_value_heads
|
||||
head_dim = self.model.config.head_dim
|
||||
elif hasattr(self.model.config, "n_head"):
|
||||
# GPT-style models
|
||||
n_kv_heads = self.model.config.n_head
|
||||
head_dim = self.model.config.n_embd // self.model.config.n_head
|
||||
elif hasattr(self.model.config, "num_attention_heads"):
|
||||
# OPT-style models
|
||||
n_kv_heads = self.model.config.num_attention_heads
|
||||
head_dim = (
|
||||
self.model.config.hidden_size // self.model.config.num_attention_heads
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported model architecture: {self.model.config.model_type}"
|
||||
)
|
||||
|
||||
seq_len = input_ids.shape[1]
|
||||
|
||||
for i in range(n_layers):
|
||||
key_shape = (batch_size, n_kv_heads, seq_len, head_dim)
|
||||
value_shape = key_shape
|
||||
self.kv_cache.key_cache.append(torch.zeros(key_shape, device=self.device))
|
||||
self.kv_cache.value_cache.append(torch.zeros(value_shape, device=self.device))
|
||||
|
||||
# Now run with the initialized cache
|
||||
outputs = self.model(
|
||||
input_ids=input_ids, past_key_values=self.kv_cache, use_cache=True
|
||||
)
|
||||
# Update cache with actual values
|
||||
self.kv_cache.key_cache = [layer[0] for layer in outputs.past_key_values]
|
||||
self.kv_cache.value_cache = [layer[1] for layer in outputs.past_key_values]
|
||||
self.origin_len = self.kv_cache.key_cache[0].shape[-2]
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error ingesting documents: {e}")
|
||||
return False
|
||||
|
||||
async def update(self, new_doc: str) -> bool:
|
||||
"""Update cache with new document"""
|
||||
try:
|
||||
if self.kv_cache is None:
|
||||
return await self.ingest([new_doc])
|
||||
|
||||
# Clean up existing cache
|
||||
self.clean_up_cache(self.kv_cache, self.origin_len)
|
||||
|
||||
# Add new document to cache
|
||||
input_ids = self.tokenizer(new_doc + "\n", return_tensors="pt").input_ids.to(
|
||||
self.device
|
||||
)
|
||||
|
||||
# First run to get the cache shape
|
||||
outputs = self.model(input_ids=input_ids, use_cache=True)
|
||||
# Initialize cache with empty tensors of the right shape
|
||||
n_layers = len(outputs.past_key_values)
|
||||
batch_size = input_ids.shape[0]
|
||||
|
||||
# Handle different model architectures
|
||||
if hasattr(self.model.config, "num_key_value_heads"):
|
||||
# Models with grouped query attention (GQA) like Llama
|
||||
n_kv_heads = self.model.config.num_key_value_heads
|
||||
head_dim = self.model.config.head_dim
|
||||
elif hasattr(self.model.config, "n_head"):
|
||||
# GPT-style models
|
||||
n_kv_heads = self.model.config.n_head
|
||||
head_dim = self.model.config.n_embd // self.model.config.n_head
|
||||
elif hasattr(self.model.config, "num_attention_heads"):
|
||||
# OPT-style models
|
||||
n_kv_heads = self.model.config.num_attention_heads
|
||||
head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
|
||||
else:
|
||||
raise ValueError(f"Unsupported model architecture: {self.model.config.model_type}")
|
||||
|
||||
seq_len = input_ids.shape[1]
|
||||
|
||||
# Create a new cache for the update
|
||||
new_cache = DynamicCache()
|
||||
for i in range(n_layers):
|
||||
key_shape = (batch_size, n_kv_heads, seq_len, head_dim)
|
||||
value_shape = key_shape
|
||||
new_cache.key_cache.append(torch.zeros(key_shape, device=self.device))
|
||||
new_cache.value_cache.append(torch.zeros(value_shape, device=self.device))
|
||||
|
||||
# Run with the initialized cache
|
||||
outputs = self.model(input_ids=input_ids, past_key_values=new_cache, use_cache=True)
|
||||
# Update cache with actual values
|
||||
self.kv_cache.key_cache = [layer[0] for layer in outputs.past_key_values]
|
||||
self.kv_cache.value_cache = [layer[1] for layer in outputs.past_key_values]
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error updating cache: {e}")
|
||||
return False
|
||||
|
||||
async def complete(self, request: CompletionRequest) -> CompletionResponse:
|
||||
"""Generate completion using cache-augmented generation"""
|
||||
try:
|
||||
if self.kv_cache is None:
|
||||
raise ValueError("Cache not initialized. Please ingest documents first.")
|
||||
|
||||
# Clean up cache
|
||||
self.clean_up_cache(self.kv_cache, self.origin_len)
|
||||
|
||||
# Generate completion
|
||||
input_ids = self.tokenizer(request.query + "\n", return_tensors="pt").input_ids.to(
|
||||
self.device
|
||||
)
|
||||
gen_ids = self.generate(input_ids, self.kv_cache, max_new_tokens=request.max_tokens)
|
||||
completion = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True)
|
||||
|
||||
# Calculate token usage
|
||||
usage = {
|
||||
"prompt_tokens": len(input_ids[0]),
|
||||
"completion_tokens": len(gen_ids[0]),
|
||||
"total_tokens": len(input_ids[0]) + len(gen_ids[0]),
|
||||
}
|
||||
|
||||
return CompletionResponse(completion=completion, usage=usage)
|
||||
except Exception as e:
|
||||
print(f"Error generating completion: {e}")
|
||||
return CompletionResponse(
|
||||
completion=f"Error: {str(e)}",
|
||||
usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
)
|
||||
|
||||
def save_cache(self) -> Path:
|
||||
"""Save the KV cache to disk"""
|
||||
if self.kv_cache is None:
|
||||
raise ValueError("No cache to save")
|
||||
|
||||
cache_dir = self.cache_path / "kv_cache"
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save key and value caches
|
||||
cache_data = {
|
||||
"key_cache": self.kv_cache.key_cache,
|
||||
"value_cache": self.kv_cache.value_cache,
|
||||
"origin_len": self.origin_len,
|
||||
}
|
||||
cache_path = cache_dir / "cache.pt"
|
||||
torch.save(cache_data, cache_path)
|
||||
return cache_path
|
||||
|
||||
def load_cache(self, cache_path: Union[str, Path]) -> None:
|
||||
"""Load KV cache from disk"""
|
||||
cache_path = Path(cache_path)
|
||||
if not cache_path.exists():
|
||||
raise FileNotFoundError(f"Cache file not found at {cache_path}")
|
||||
|
||||
cache_data = torch.load(cache_path, map_location=self.device)
|
||||
|
||||
self.kv_cache = DynamicCache()
|
||||
self.kv_cache.key_cache = cache_data["key_cache"]
|
||||
self.kv_cache.value_cache = cache_data["value_cache"]
|
||||
self.origin_len = cache_data["origin_len"]
|
196
core/cache/llama_cache.py
vendored
Normal file
196
core/cache/llama_cache.py
vendored
Normal file
@ -0,0 +1,196 @@
|
||||
import json
|
||||
import pickle
|
||||
import logging
|
||||
from core.cache.base_cache import BaseCache
|
||||
from typing import Dict, Any, List
|
||||
from core.models.completion import CompletionResponse
|
||||
from core.models.documents import Document
|
||||
from llama_cpp import Llama
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
INITIAL_SYSTEM_PROMPT = """<|im_start|>system
|
||||
You are a helpful AI assistant with access to provided documents. Your role is to:
|
||||
1. Answer questions accurately based on the documents provided
|
||||
2. Stay focused on the document content and avoid speculation
|
||||
3. Admit when you don't have enough information to answer
|
||||
4. Be clear and concise in your responses
|
||||
5. Use direct quotes from documents when relevant
|
||||
|
||||
Provided documents: {documents}
|
||||
<|im_end|>
|
||||
""".strip()
|
||||
|
||||
ADD_DOC_SYSTEM_PROMPT = """<|im_start|>system
|
||||
I'm adding some additional documents for your reference:
|
||||
{documents}
|
||||
|
||||
Please incorporate this new information along with what you already know from previous documents while maintaining the same guidelines for responses.
|
||||
<|im_end|>
|
||||
""".strip()
|
||||
|
||||
QUERY_PROMPT = """<|im_start|>user
|
||||
{query}
|
||||
<|im_end|>
|
||||
<|im_start|>assistant
|
||||
""".strip()
|
||||
|
||||
|
||||
class LlamaCache(BaseCache):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model: str,
|
||||
gguf_file: str,
|
||||
filters: Dict[str, Any],
|
||||
docs: List[Document],
|
||||
**kwargs,
|
||||
):
|
||||
logger.info(f"Initializing LlamaCache with name={name}, model={model}")
|
||||
# cache related
|
||||
self.name = name
|
||||
self.model = model
|
||||
self.filters = filters
|
||||
self.docs = docs
|
||||
|
||||
# llama specific
|
||||
self.gguf_file = gguf_file
|
||||
self.n_gpu_layers = kwargs.get("n_gpu_layers", -1)
|
||||
logger.info(f"Using {self.n_gpu_layers} GPU layers")
|
||||
|
||||
# late init (when we call _initialize)
|
||||
self.llama = None
|
||||
self.state = None
|
||||
self.cached_tokens = 0
|
||||
|
||||
self._initialize(model, gguf_file, docs)
|
||||
logger.info("LlamaCache initialization complete")
|
||||
|
||||
def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None:
|
||||
logger.info(f"Loading Llama model from {model} with file {gguf_file}")
|
||||
try:
|
||||
# Set a reasonable default context size (32K tokens)
|
||||
default_ctx_size = 32768
|
||||
|
||||
self.llama = Llama.from_pretrained(
|
||||
repo_id=model,
|
||||
filename=gguf_file,
|
||||
n_gpu_layers=self.n_gpu_layers,
|
||||
n_ctx=default_ctx_size,
|
||||
verbose=False, # Enable verbose mode for better error reporting
|
||||
)
|
||||
logger.info("Model loaded successfully")
|
||||
|
||||
# Format and tokenize system prompt
|
||||
documents = "\n".join(doc.system_metadata.get("content", "") for doc in docs)
|
||||
system_prompt = INITIAL_SYSTEM_PROMPT.format(documents=documents)
|
||||
logger.info(f"Built system prompt: {system_prompt[:200]}...")
|
||||
|
||||
try:
|
||||
tokens = self.llama.tokenize(system_prompt.encode())
|
||||
logger.info(f"System prompt tokenized to {len(tokens)} tokens")
|
||||
|
||||
# Process tokens to build KV cache
|
||||
logger.info("Evaluating system prompt")
|
||||
self.llama.eval(tokens)
|
||||
logger.info("Saving initial KV cache state")
|
||||
self.state = self.llama.save_state()
|
||||
self.cached_tokens = len(tokens)
|
||||
logger.info(f"Initial KV cache built with {self.cached_tokens} tokens")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during prompt processing: {str(e)}")
|
||||
raise ValueError(f"Failed to process system prompt: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Llama model: {str(e)}")
|
||||
raise ValueError(f"Failed to initialize Llama model: {str(e)}")
|
||||
|
||||
def add_docs(self, docs: List[Document]) -> bool:
|
||||
logger.info(f"Adding {len(docs)} new documents to cache")
|
||||
documents = "\n".join(doc.system_metadata.get("content", "") for doc in docs)
|
||||
system_prompt = ADD_DOC_SYSTEM_PROMPT.format(documents=documents)
|
||||
|
||||
# Tokenize and process
|
||||
new_tokens = self.llama.tokenize(system_prompt.encode())
|
||||
self.llama.eval(new_tokens)
|
||||
self.state = self.llama.save_state()
|
||||
self.cached_tokens += len(new_tokens)
|
||||
logger.info(f"Added {len(new_tokens)} tokens, total: {self.cached_tokens}")
|
||||
return True
|
||||
|
||||
def query(self, query: str) -> CompletionResponse:
|
||||
# Format query with proper chat template
|
||||
formatted_query = QUERY_PROMPT.format(query=query)
|
||||
logger.info(f"Processing query: {formatted_query}")
|
||||
|
||||
# Reset and load cached state
|
||||
self.llama.reset()
|
||||
self.llama.load_state(self.state)
|
||||
logger.info(f"Loaded state with {self.state.n_tokens} tokens")
|
||||
# print(f"Loaded state with {self.state.n_tokens} tokens", file=sys.stderr)
|
||||
# Tokenize and process query
|
||||
query_tokens = self.llama.tokenize(formatted_query.encode())
|
||||
self.llama.eval(query_tokens)
|
||||
logger.info(f"Evaluated query tokens: {query_tokens}")
|
||||
# print(f"Evaluated query tokens: {query_tokens}", file=sys.stderr)
|
||||
|
||||
# Generate response
|
||||
output_tokens = []
|
||||
for token in self.llama.generate(tokens=[], reset=False):
|
||||
output_tokens.append(token)
|
||||
# Stop generation when EOT token is encountered
|
||||
if token == self.llama.token_eos():
|
||||
break
|
||||
|
||||
# Decode and return
|
||||
completion = self.llama.detokenize(output_tokens).decode()
|
||||
logger.info(f"Generated completion: {completion}")
|
||||
|
||||
return CompletionResponse(
|
||||
completion=completion,
|
||||
usage={"prompt_tokens": self.cached_tokens, "completion_tokens": len(output_tokens)},
|
||||
)
|
||||
|
||||
@property
|
||||
def saveable_state(self) -> bytes:
|
||||
logger.info("Serializing cache state")
|
||||
state_bytes = pickle.dumps(self.state)
|
||||
logger.info(f"Serialized state size: {len(state_bytes)} bytes")
|
||||
return state_bytes
|
||||
|
||||
@classmethod
|
||||
def from_bytes(
|
||||
cls, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs
|
||||
) -> "LlamaCache":
|
||||
"""Load a cache from its serialized state.
|
||||
|
||||
Args:
|
||||
name: Name of the cache
|
||||
cache_bytes: Pickled state bytes
|
||||
metadata: Cache metadata including model info
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
LlamaCache: Loaded cache instance
|
||||
"""
|
||||
logger.info(f"Loading cache from bytes with name={name}")
|
||||
logger.info(f"Cache metadata: {metadata}")
|
||||
# Create new instance with metadata
|
||||
# logger.info(f"Docs: {metadata['docs']}")
|
||||
docs = [json.loads(doc) for doc in metadata["docs"]]
|
||||
# time.sleep(10)
|
||||
cache = cls(
|
||||
name=name,
|
||||
model=metadata["model"],
|
||||
gguf_file=metadata["model_file"],
|
||||
filters=metadata["filters"],
|
||||
docs=[Document(**doc) for doc in docs],
|
||||
)
|
||||
|
||||
# Load the saved state
|
||||
logger.info(f"Loading saved KV cache state of size {len(cache_bytes)} bytes")
|
||||
cache.state = pickle.loads(cache_bytes)
|
||||
cache.llama.load_state(cache.state)
|
||||
logger.info("Cache successfully loaded from bytes")
|
||||
|
||||
return cache
|
15
core/cache/llama_cache_factory.py
vendored
Normal file
15
core/cache/llama_cache_factory.py
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
from core.cache.base_cache_factory import BaseCacheFactory
|
||||
from core.cache.llama_cache import LlamaCache
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class LlamaCacheFactory(BaseCacheFactory):
|
||||
def create_new_cache(
|
||||
self, name: str, model: str, model_file: str, **kwargs: Dict[str, Any]
|
||||
) -> LlamaCache:
|
||||
return LlamaCache(name, model, model_file, **kwargs)
|
||||
|
||||
def load_cache_from_bytes(
|
||||
self, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs: Dict[str, Any]
|
||||
) -> LlamaCache:
|
||||
return LlamaCache.from_bytes(name, cache_bytes, metadata, **kwargs)
|
@ -1,22 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Dict
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
"""Response from completion generation"""
|
||||
|
||||
completion: str
|
||||
usage: Dict[str, int]
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
"""Request for completion generation"""
|
||||
|
||||
query: str
|
||||
context_chunks: List[str]
|
||||
max_tokens: Optional[int] = 1000
|
||||
temperature: Optional[float] = 0.7
|
||||
from core.models.completion import CompletionRequest, CompletionResponse
|
||||
|
||||
|
||||
class BaseCompletionModel(ABC):
|
||||
|
@ -72,3 +72,28 @@ class BaseDatabase(ABC):
|
||||
Returns: True if user has required access, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""Store metadata for a cache.
|
||||
|
||||
Args:
|
||||
name: Name of the cache
|
||||
metadata: Cache metadata including model info and storage location
|
||||
|
||||
Returns:
|
||||
bool: Whether the operation was successful
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get metadata for a cache.
|
||||
|
||||
Args:
|
||||
name: Name of the cache
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: Cache metadata if found, None otherwise
|
||||
"""
|
||||
pass
|
||||
|
@ -26,6 +26,7 @@ class MongoDatabase(BaseDatabase):
|
||||
self.client = AsyncIOMotorClient(uri)
|
||||
self.db = self.client[db_name]
|
||||
self.collection = self.db[collection_name]
|
||||
self.caches = self.db["caches"] # Collection for cache metadata
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize database indexes."""
|
||||
@ -217,3 +218,45 @@ class MongoDatabase(BaseDatabase):
|
||||
for key, value in filters.items():
|
||||
filter_dict[f"metadata.{key}"] = value
|
||||
return filter_dict
|
||||
|
||||
async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""Store metadata for a cache in MongoDB.
|
||||
|
||||
Args:
|
||||
name: Name of the cache
|
||||
metadata: Cache metadata including model info and storage location
|
||||
|
||||
Returns:
|
||||
bool: Whether the operation was successful
|
||||
"""
|
||||
try:
|
||||
# Add timestamp and ensure name is included
|
||||
doc = {
|
||||
"name": name,
|
||||
"metadata": metadata,
|
||||
"created_at": datetime.now(UTC),
|
||||
"updated_at": datetime.now(UTC),
|
||||
}
|
||||
|
||||
# Upsert the document
|
||||
result = await self.caches.update_one({"name": name}, {"$set": doc}, upsert=True)
|
||||
return bool(result.modified_count or result.upserted_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store cache metadata: {e}")
|
||||
return False
|
||||
|
||||
async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get metadata for a cache from MongoDB.
|
||||
|
||||
Args:
|
||||
name: Name of the cache
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: Cache metadata if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
doc = await self.caches.find_one({"name": name})
|
||||
return doc["metadata"] if doc else None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cache metadata: {e}")
|
||||
return None
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime, UTC
|
||||
import logging
|
||||
@ -65,8 +66,24 @@ class PostgresDatabase(BaseDatabase):
|
||||
try:
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Create caches table if it doesn't exist
|
||||
await conn.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS caches (
|
||||
name TEXT PRIMARY KEY,
|
||||
metadata JSONB NOT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("PostgreSQL tables and indexes created successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating PostgreSQL tables and indexes: {str(e)}")
|
||||
return False
|
||||
@ -324,3 +341,54 @@ class PostgresDatabase(BaseDatabase):
|
||||
filter_conditions.append(f"doc_metadata->>'{key}' = '{value}'")
|
||||
|
||||
return " AND ".join(filter_conditions)
|
||||
|
||||
async def store_cache_metadata(self, name: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""Store metadata for a cache in PostgreSQL.
|
||||
|
||||
Args:
|
||||
name: Name of the cache
|
||||
metadata: Cache metadata including model info and storage location
|
||||
|
||||
Returns:
|
||||
bool: Whether the operation was successful
|
||||
"""
|
||||
try:
|
||||
async with self.async_session() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO caches (name, metadata, updated_at)
|
||||
VALUES (:name, :metadata, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT (name)
|
||||
DO UPDATE SET
|
||||
metadata = :metadata,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"""
|
||||
),
|
||||
{"name": name, "metadata": json.dumps(metadata)},
|
||||
)
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store cache metadata: {e}")
|
||||
return False
|
||||
|
||||
async def get_cache_metadata(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get metadata for a cache from PostgreSQL.
|
||||
|
||||
Args:
|
||||
name: Name of the cache
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: Cache metadata if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
text("SELECT metadata FROM caches WHERE name = :name"), {"name": name}
|
||||
)
|
||||
row = result.first()
|
||||
return row[0] if row else None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cache metadata: {e}")
|
||||
return None
|
||||
|
19
core/models/completion.py
Normal file
19
core/models/completion.py
Normal file
@ -0,0 +1,19 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
"""Response from completion generation"""
|
||||
|
||||
completion: str
|
||||
usage: Dict[str, int]
|
||||
finish_reason: Optional[str] = None
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
"""Request for completion generation"""
|
||||
|
||||
query: str
|
||||
context_chunks: List[str]
|
||||
max_tokens: Optional[int] = 1000
|
||||
temperature: Optional[float] = 0.7
|
@ -40,6 +40,14 @@ class Document(BaseModel):
|
||||
)
|
||||
chunk_ids: List[str] = Field(default_factory=list)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.external_id)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Document):
|
||||
return False
|
||||
return self.external_id == other.external_id
|
||||
|
||||
|
||||
class DocumentContent(BaseModel):
|
||||
"""Represents either a URL or content string"""
|
||||
|
@ -21,6 +21,8 @@ from core.completion.base_completion import CompletionRequest, CompletionRespons
|
||||
import logging
|
||||
from core.reranker.base_reranker import BaseReranker
|
||||
from core.config import get_settings
|
||||
from core.cache.base_cache import BaseCache
|
||||
from core.cache.base_cache_factory import BaseCacheFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -34,6 +36,7 @@ class DocumentService:
|
||||
parser: BaseParser,
|
||||
embedding_model: BaseEmbeddingModel,
|
||||
completion_model: BaseCompletionModel,
|
||||
cache_factory: BaseCacheFactory,
|
||||
reranker: Optional[BaseReranker] = None,
|
||||
):
|
||||
self.db = database
|
||||
@ -43,6 +46,11 @@ class DocumentService:
|
||||
self.embedding_model = embedding_model
|
||||
self.completion_model = completion_model
|
||||
self.reranker = reranker
|
||||
self.cache_factory = cache_factory
|
||||
|
||||
# Cache-related data structures
|
||||
# Maps cache name to active cache object
|
||||
self.active_caches: Dict[str, BaseCache] = {}
|
||||
|
||||
async def retrieve_chunks(
|
||||
self,
|
||||
@ -151,6 +159,7 @@ class DocumentService:
|
||||
},
|
||||
)
|
||||
logger.info(f"Created text document record with ID {doc.external_id}")
|
||||
doc.system_metadata["content"] = request.content
|
||||
|
||||
# 2. Parse content into chunks
|
||||
chunks = await self.parser.split_text(request.content)
|
||||
@ -196,6 +205,7 @@ class DocumentService:
|
||||
},
|
||||
additional_metadata=additional_metadata,
|
||||
)
|
||||
doc.system_metadata["content"] = "\n".join(chunk.content for chunk in chunks)
|
||||
logger.info(f"Created file document record with ID {doc.external_id}")
|
||||
|
||||
storage_info = await self.storage.upload_from_base64(
|
||||
@ -333,3 +343,90 @@ class DocumentService:
|
||||
|
||||
logger.info(f"Created {len(results)} document results")
|
||||
return results
|
||||
|
||||
async def create_cache(
|
||||
self,
|
||||
name: str,
|
||||
model: str,
|
||||
gguf_file: str,
|
||||
docs: List[Document | None],
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""Create a new cache with specified configuration.
|
||||
|
||||
Args:
|
||||
name: Name of the cache to create
|
||||
model: Name of the model to use
|
||||
gguf_file: Name of the GGUF file to use
|
||||
filters: Optional metadata filters for documents to include
|
||||
docs: Optional list of specific document IDs to include
|
||||
"""
|
||||
# Create cache metadata
|
||||
metadata = {
|
||||
"model": model,
|
||||
"model_file": gguf_file,
|
||||
"filters": filters,
|
||||
"docs": [doc.model_dump_json() for doc in docs],
|
||||
"storage_info": {
|
||||
"bucket": "caches",
|
||||
"key": f"{name}_state.pkl",
|
||||
},
|
||||
}
|
||||
|
||||
# Store metadata in database
|
||||
success = await self.db.store_cache_metadata(name, metadata)
|
||||
if not success:
|
||||
logger.error(f"Failed to store cache metadata for cache {name}")
|
||||
return {"success": False, "message": f"Failed to store cache metadata for cache {name}"}
|
||||
|
||||
# Create cache instance
|
||||
cache = self.cache_factory.create_new_cache(
|
||||
name=name, model=model, model_file=gguf_file, filters=filters, docs=docs
|
||||
)
|
||||
cache_bytes = cache.saveable_state
|
||||
base64_cache_bytes = base64.b64encode(cache_bytes).decode()
|
||||
bucket, key = await self.storage.upload_from_base64(
|
||||
base64_cache_bytes,
|
||||
key=metadata["storage_info"]["key"],
|
||||
bucket=metadata["storage_info"]["bucket"],
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Cache created successfully, state stored in bucket `{bucket}` with key `{key}`",
|
||||
}
|
||||
|
||||
async def load_cache(self, name: str) -> bool:
|
||||
"""Load a cache into memory.
|
||||
|
||||
Args:
|
||||
name: Name of the cache to load
|
||||
|
||||
Returns:
|
||||
bool: Whether the cache exists and was loaded successfully
|
||||
"""
|
||||
try:
|
||||
# Get cache metadata from database
|
||||
metadata = await self.db.get_cache_metadata(name)
|
||||
if not metadata:
|
||||
logger.error(f"No metadata found for cache {name}")
|
||||
return False
|
||||
|
||||
# Get cache bytes from storage
|
||||
cache_bytes = await self.storage.download_file(
|
||||
metadata["storage_info"]["bucket"], "caches/" + metadata["storage_info"]["key"]
|
||||
)
|
||||
cache_bytes = cache_bytes.read()
|
||||
cache = self.cache_factory.load_cache_from_bytes(
|
||||
name=name, cache_bytes=cache_bytes, metadata=metadata
|
||||
)
|
||||
self.active_caches[name] = cache
|
||||
return {"success": True, "message": "Cache loaded successfully"}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load cache {name}: {e}")
|
||||
# raise e
|
||||
return {"success": False, "message": f"Failed to load cache {name}: {e}"}
|
||||
|
||||
def close(self):
|
||||
"""Close all resources."""
|
||||
# Close any active caches
|
||||
self.active_caches.clear()
|
||||
|
@ -7,7 +7,7 @@ class BaseStorage(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def upload_from_base64(
|
||||
self, content: str, key: str, content_type: Optional[str] = None
|
||||
self, content: str, key: str, content_type: Optional[str] = None, bucket: str = ""
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Upload base64 encoded content.
|
||||
@ -16,7 +16,7 @@ class BaseStorage(ABC):
|
||||
content: Base64 encoded content
|
||||
key: Storage key/path
|
||||
content_type: Optional MIME type
|
||||
|
||||
bucket: Optional bucket/folder name
|
||||
Returns:
|
||||
Tuple[str, str]: (bucket/container name, storage key)
|
||||
"""
|
||||
|
@ -19,16 +19,20 @@ class LocalStorage(BaseStorage):
|
||||
return open(file_path, "rb")
|
||||
|
||||
async def upload_from_base64(
|
||||
self, base64_content: str, key: str, content_type: Optional[str] = None
|
||||
self, content: str, key: str, content_type: Optional[str] = None, bucket: str = ""
|
||||
) -> Tuple[str, str]:
|
||||
base64_content = content
|
||||
"""Upload base64 encoded content to local storage."""
|
||||
# Decode base64 content
|
||||
file_content = base64.b64decode(base64_content)
|
||||
|
||||
key = f"{bucket}/{key}" if bucket else key
|
||||
# Create file path
|
||||
file_path = self.storage_path / key
|
||||
|
||||
# Write content to file
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.unlink(missing_ok=True)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
|
||||
|
@ -37,6 +37,7 @@ class S3Storage(BaseStorage):
|
||||
file: Union[str, bytes, BinaryIO],
|
||||
key: str,
|
||||
content_type: Optional[str] = None,
|
||||
bucket: str = "",
|
||||
) -> Tuple[str, str]:
|
||||
"""Upload a file to S3."""
|
||||
try:
|
||||
@ -70,15 +71,18 @@ class S3Storage(BaseStorage):
|
||||
raise
|
||||
|
||||
async def upload_from_base64(
|
||||
self, content: str, key: str, content_type: Optional[str] = None
|
||||
self, content: str, key: str, content_type: Optional[str] = None, bucket: str = ""
|
||||
) -> Tuple[str, str]:
|
||||
"""Upload base64 encoded content to S3."""
|
||||
key = f"{bucket}/{key}" if bucket else key
|
||||
try:
|
||||
decoded_content = base64.b64decode(content)
|
||||
extension = detect_file_type(content)
|
||||
key = f"{key}{extension}"
|
||||
|
||||
return await self.upload_file(file=decoded_content, key=key, content_type=content_type)
|
||||
return await self.upload_file(
|
||||
file=decoded_content, key=key, content_type=content_type, bucket=bucket
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading base64 content to S3: {e}")
|
||||
|
@ -238,6 +238,8 @@ async def test_ingest_oversized_content(client: AsyncClient):
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_missing_header(client: AsyncClient):
|
||||
"""Test authentication with missing auth header"""
|
||||
if get_settings().dev_mode:
|
||||
pytest.skip("Auth tests skipped in dev mode")
|
||||
response = await client.post("/ingest/text")
|
||||
assert response.status_code == 401
|
||||
|
||||
@ -245,6 +247,8 @@ async def test_auth_missing_header(client: AsyncClient):
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_invalid_token(client: AsyncClient):
|
||||
"""Test authentication with invalid token"""
|
||||
if get_settings().dev_mode:
|
||||
pytest.skip("Auth tests skipped in dev mode")
|
||||
headers = {"Authorization": "Bearer invalid_token"}
|
||||
response = await client.post("/ingest/file", headers=headers)
|
||||
assert response.status_code == 401
|
||||
@ -253,6 +257,8 @@ async def test_auth_invalid_token(client: AsyncClient):
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_expired_token(client: AsyncClient):
|
||||
"""Test authentication with expired token"""
|
||||
if get_settings().dev_mode:
|
||||
pytest.skip("Auth tests skipped in dev mode")
|
||||
headers = create_auth_header(expired=True)
|
||||
response = await client.post("/ingest/text", headers=headers)
|
||||
assert response.status_code == 401
|
||||
@ -261,6 +267,8 @@ async def test_auth_expired_token(client: AsyncClient):
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_insufficient_permissions(client: AsyncClient):
|
||||
"""Test authentication with insufficient permissions"""
|
||||
if get_settings().dev_mode:
|
||||
pytest.skip("Auth tests skipped in dev mode")
|
||||
headers = create_auth_header(permissions=["read"])
|
||||
response = await client.post(
|
||||
"/ingest/text",
|
||||
@ -335,7 +343,7 @@ async def test_retrieve_chunks(client: AsyncClient):
|
||||
assert response.status_code == 200
|
||||
results = list(response.json())
|
||||
assert len(results) > 0
|
||||
assert results[0]["score"] > 0.5
|
||||
assert (not get_settings().USE_RERANKING) or results[0]["score"] > 0.5
|
||||
assert any(upload_string == result["content"] for result in results)
|
||||
|
||||
|
||||
|
184
core/tests/unit/test_cache.py
Normal file
184
core/tests/unit/test_cache.py
Normal file
@ -0,0 +1,184 @@
|
||||
import pytest
|
||||
from core.models.documents import Document
|
||||
from core.cache.llama_cache import LlamaCache
|
||||
from core.models.completion import CompletionResponse
|
||||
|
||||
# TEST_MODEL = "QuantFactory/Llama3.2-3B-Enigma-GGUF"
|
||||
TEST_MODEL = "QuantFactory/Dolphin3.0-Llama3.2-1B-GGUF"
|
||||
# TEST_MODEL = "unsloth/DeepSeek-R1-Distill-Qwen-1.5B-GGUF"
|
||||
TEST_GGUF_FILE = "*Q4_K_S.gguf"
|
||||
# TEST_GGUF_FILE = "*Q4_K_M.gguf"
|
||||
|
||||
|
||||
def get_test_document():
|
||||
"""Load the example.txt file as a test document."""
|
||||
# test_file = Path(__file__).parent.parent / "assets" / "example.txt"
|
||||
# with open(test_file, "r") as f:
|
||||
# content = f.read()
|
||||
|
||||
content = """
|
||||
In darkest hours of coding's fierce domain,
|
||||
Where bugs lurk deep in shadows, hard to find,
|
||||
Each error message brings fresh waves of pain,
|
||||
As stack traces drive madness through the mind.
|
||||
|
||||
Through endless loops of print statements we wade,
|
||||
Debug flags raised like torches in the night,
|
||||
While segfaults mock each careful plan we made,
|
||||
And race conditions laugh at our plight.
|
||||
|
||||
O rubber duck, my silent debugging friend,
|
||||
Your plastic gaze holds wisdom yet untold,
|
||||
As line by line we trace paths without end,
|
||||
Seeking that elusive bug of gold.
|
||||
|
||||
Yet hope remains while coffee still flows strong,
|
||||
Through debugging hell, we'll debug right from wrong.
|
||||
""".strip()
|
||||
|
||||
return Document(
|
||||
external_id="alice_ch1",
|
||||
owner={"id": "test_user", "name": "Test User"},
|
||||
content_type="text/plain",
|
||||
system_metadata={
|
||||
"content": content,
|
||||
"title": "Alice in Wonderland - Chapter 1",
|
||||
"source": "test_document",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llama_cache():
|
||||
"""Create a LlamaCache instance with the test document."""
|
||||
doc = get_test_document()
|
||||
cache = LlamaCache(
|
||||
name="test_cache", model=TEST_MODEL, gguf_file=TEST_GGUF_FILE, filters={}, docs=[doc]
|
||||
)
|
||||
return cache
|
||||
|
||||
|
||||
def test_basic_rag_capabilities(llama_cache):
|
||||
"""Test that the cache can answer basic questions about the document content."""
|
||||
# Test question about whether ingestion is actually happening
|
||||
response = llama_cache.query(
|
||||
"Summarize the content of the document. Please respond in a single sentence. Summary: "
|
||||
)
|
||||
assert isinstance(response, CompletionResponse)
|
||||
# assert "alice" in response.completion.lower()
|
||||
|
||||
# # Test question about a specific detail
|
||||
# response = llama_cache.query(
|
||||
# "What did Alice see the White Rabbit do with its watch? Please respond in a single sentence. Answer: "
|
||||
# )
|
||||
# assert isinstance(response, CompletionResponse)
|
||||
# # assert "waistcoat-pocket" in response.completion.lower() or "looked at it" in response.completion.lower()
|
||||
|
||||
# # Test question about character description
|
||||
# response = llama_cache.query(
|
||||
# "How did Alice's size change during the story? Please respond in a single sentence. Answer: "
|
||||
# )
|
||||
# assert isinstance(response, CompletionResponse)
|
||||
# # assert any(phrase in response.completion.lower() for phrase in ["grew larger", "grew smaller", "nine feet", "telescope"])
|
||||
|
||||
# # Test question about plot elements
|
||||
# response = llama_cache.query(
|
||||
# "What was written on the bottle Alice found? Please respond in a single sentence. Answer: "
|
||||
# )
|
||||
# assert isinstance(response, CompletionResponse)
|
||||
# # assert "drink me" in response.completion.lower()
|
||||
|
||||
|
||||
# def test_cache_memory_persistence(llama_cache):
|
||||
# """Test that the cache maintains context across multiple queries."""
|
||||
|
||||
# # First query to establish context
|
||||
# llama_cache.query(
|
||||
# "What was Alice doing before she saw the White Rabbit? Please respond in a single sentence. Answer: "
|
||||
# )
|
||||
|
||||
# # Follow-up query that requires remembering previous context
|
||||
# response = llama_cache.query(
|
||||
# "What book was her sister reading? Please respond in a single sentence. Answer: "
|
||||
# )
|
||||
# assert isinstance(response, CompletionResponse)
|
||||
# # assert "no pictures" in response.completion.lower() or "conversations" in response.completion.lower()
|
||||
|
||||
|
||||
def test_adding_new_documents(llama_cache):
|
||||
"""Test that the cache can incorporate new documents into its knowledge."""
|
||||
|
||||
# Create a new document with additional content
|
||||
new_doc = Document(
|
||||
external_id="alice_ch2",
|
||||
owner={"id": "test_user", "name": "Test User"},
|
||||
content_type="text/plain",
|
||||
system_metadata={
|
||||
"content": "Alice found herself in a pool of tears. She met a Mouse swimming in the pool.",
|
||||
"title": "Alice in Wonderland - Additional Content",
|
||||
"source": "test_document",
|
||||
},
|
||||
)
|
||||
|
||||
# Add the new document
|
||||
success = llama_cache.add_docs([new_doc])
|
||||
assert success
|
||||
|
||||
# Query about the new content
|
||||
response = llama_cache.query(
|
||||
"What did Alice find in the pool of tears? Please respond in a single sentence. Answer: "
|
||||
)
|
||||
assert isinstance(response, CompletionResponse)
|
||||
assert "mouse" in response.completion.lower()
|
||||
|
||||
|
||||
def test_cache_state_persistence():
|
||||
"""Test that the cache state can be saved and loaded."""
|
||||
|
||||
# Create initial cache
|
||||
doc = get_test_document()
|
||||
original_cache = LlamaCache(
|
||||
name="test_cache", model=TEST_MODEL, gguf_file=TEST_GGUF_FILE, filters={}, docs=[doc]
|
||||
)
|
||||
|
||||
# Get the state
|
||||
state_bytes = original_cache.saveable_state
|
||||
# Save state bytes to temporary file
|
||||
import tempfile
|
||||
import os
|
||||
import pickle
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
cache_file = os.path.join(temp_dir, "cache.pkl")
|
||||
|
||||
# Save to file
|
||||
with open(cache_file, "wb") as f:
|
||||
pickle.dump(state_bytes, f)
|
||||
|
||||
# Load from file
|
||||
with open(cache_file, "rb") as f:
|
||||
loaded_state_bytes = pickle.load(f)
|
||||
|
||||
# # Verify state bytes match
|
||||
# assert state_bytes == loaded_state_bytes
|
||||
# state_bytes = loaded_state_bytes # Use loaded bytes for rest of test
|
||||
|
||||
# Create new cache from state
|
||||
loaded_cache = LlamaCache.from_bytes(
|
||||
name="test_cache",
|
||||
cache_bytes=loaded_state_bytes,
|
||||
metadata={
|
||||
"model": TEST_MODEL,
|
||||
"model_file": TEST_GGUF_FILE,
|
||||
"filters": {},
|
||||
"docs": [doc.model_dump_json()],
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the loaded cache works
|
||||
response = loaded_cache.query(
|
||||
"Summarize the content of the document. Please respond in a single sentence. Summary: "
|
||||
)
|
||||
assert isinstance(response, CompletionResponse)
|
||||
assert "coding" in response.completion.lower() or "debug" in response.completion.lower()
|
||||
# assert "bottle" in response.completion.lower() and "drink me" in response.completion.lower()
|
@ -37,6 +37,8 @@ def sample_chunks() -> List[DocumentChunk]:
def reranker():
    """Fixture to create and reuse a flag reranker instance"""
    settings = get_settings()
    if not settings.USE_RERANKING:
        pytest.skip("Reranker is disabled in settings")
    return FlagReranker(
        model_name=settings.RERANKER_MODEL,
        device=settings.RERANKER_DEVICE,
@ -49,6 +51,8 @@ def reranker():
@pytest.mark.asyncio
async def test_reranker_relevance(reranker, sample_chunks):
    """Test that reranker improves relevance for programming-related query"""
    if not get_settings().USE_RERANKING:
        pytest.skip("Reranker is disabled in settings")
    print("\n=== Testing Reranker Relevance ===")
    query = "What is Python programming language?"

@ -70,6 +74,8 @@ async def test_reranker_relevance(reranker, sample_chunks):
@pytest.mark.asyncio
async def test_reranker_score_distribution(reranker, sample_chunks):
    """Test that reranker produces reasonable score distribution"""
    if not get_settings().USE_RERANKING:
        pytest.skip("Reranker is disabled in settings")
    print("\n=== Testing Score Distribution ===")
    query = "Tell me about machine learning and data science"

@ -95,6 +101,8 @@ async def test_reranker_score_distribution(reranker, sample_chunks):
@pytest.mark.asyncio
async def test_reranker_batch_scoring(reranker):
    """Test that reranker can handle multiple queries/passages efficiently"""
    if not get_settings().USE_RERANKING:
        pytest.skip("Reranker is disabled in settings")
    print("\n=== Testing Batch Scoring ===")
    texts = [
        "Python is a programming language",
@ -118,6 +126,8 @@ async def test_reranker_batch_scoring(reranker):
@pytest.mark.asyncio
async def test_reranker_empty_and_edge_cases(reranker, sample_chunks):
    """Test reranker behavior with empty or edge case inputs"""
    if not get_settings().USE_RERANKING:
        pytest.skip("Reranker is disabled in settings")
    print("\n=== Testing Edge Cases ===")

    # Empty chunks list
@ -147,6 +157,8 @@ async def test_reranker_empty_and_edge_cases(reranker, sample_chunks):
@pytest.mark.asyncio
async def test_reranker_consistency(reranker, sample_chunks):
    """Test that reranker produces consistent results for same input"""
    if not get_settings().USE_RERANKING:
        pytest.skip("Reranker is disabled in settings")
    print("\n=== Testing Consistency ===")
    query = "What is Python programming?"

11
dockerfile
@ -9,6 +9,8 @@ WORKDIR /app
# Install build dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    cmake \
    python3-dev \
    && rm -rf /var/lib/apt/lists/*

@ -16,6 +18,9 @@ RUN apt-get update && apt-get install -y \
COPY requirements.txt .
RUN pip install --no-cache-dir --user -r requirements.txt

# Download NLTK data
RUN python -m nltk.downloader -d /usr/local/share/nltk_data punkt averaged_perceptron_tagger

# Production stage
FROM python:3.12.5-slim

@ -31,11 +36,17 @@ RUN apt-get update && apt-get install -y \
    tesseract-ocr \
    postgresql-client \
    poppler-utils \
    gcc \
    g++ \
    cmake \
    python3-dev \
    && rm -rf /var/lib/apt/lists/*

# Copy installed packages from builder
COPY --from=builder /root/.local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
COPY --from=builder /root/.local/bin /usr/local/bin
# Copy NLTK data from builder
COPY --from=builder /usr/local/share/nltk_data /usr/local/share/nltk_data

# Create necessary directories
RUN mkdir -p storage logs

12
init.sql
@ -31,7 +31,15 @@ CREATE TABLE IF NOT EXISTS vector_embeddings (
    created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);

-- Create caches table
CREATE TABLE IF NOT EXISTS caches (
    name TEXT PRIMARY KEY,
    metadata JSON NOT NULL,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);

-- Create vector index
CREATE INDEX IF NOT EXISTS vector_idx
ON vector_embeddings USING ivfflat (embedding vector_l2_ops)
WITH (lists = 100);

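
The commit stores per-cache metadata in this new `caches` table (the same DDL is repeated in the setup script below). For illustration only, a minimal sketch of writing and reading a row with psycopg 3, which the requirements pin; the connection string and helper names are assumptions, not part of the commit.

```python
import json

import psycopg  # the project pins psycopg[binary] in requirements.txt


def upsert_cache_metadata(conninfo: str, name: str, metadata: dict) -> None:
    # Insert the cache row, or refresh its metadata and updated_at if it already exists.
    with psycopg.connect(conninfo) as conn, conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO caches (name, metadata)
            VALUES (%s, %s)
            ON CONFLICT (name) DO UPDATE
            SET metadata = EXCLUDED.metadata,
                updated_at = CURRENT_TIMESTAMP
            """,
            (name, json.dumps(metadata)),
        )


def get_cache_metadata(conninfo: str, name: str):
    # Return the stored metadata dict, or None if no cache with that name exists.
    with psycopg.connect(conninfo) as conn, conn.cursor() as cur:
        cur.execute("SELECT metadata FROM caches WHERE name = %s", (name,))
        row = cur.fetchone()
        return row[0] if row else None
```
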
@ -241,12 +241,23 @@ def setup_postgres():

        # Import and create all tables
        from core.database.postgres_database import Base
        from core.vector_store.pgvector_store import Base as VectorBase

        # Create regular tables first
        await conn.run_sync(Base.metadata.create_all)
        LOGGER.info("Created base PostgreSQL tables")

        # Create caches table
        create_caches_table = """
        CREATE TABLE IF NOT EXISTS caches (
            name TEXT PRIMARY KEY,
            metadata JSON NOT NULL,
            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
        )
        """
        await conn.execute(text(create_caches_table))
        LOGGER.info("Created caches table")

        # Get vector dimensions from config
        dimensions = CONFIG["embedding"]["dimensions"]

@ -278,8 +289,8 @@ def setup_postgres():
        LOGGER.info("Created vector_embeddings table with vector column")

        # Create the vector index
        index_sql = f"""
        index_sql = """
        CREATE INDEX vector_idx
        ON vector_embeddings USING ivfflat (embedding vector_l2_ops)
        WITH (lists = 100);
        """

@ -293,3 +293,4 @@ zlib-state==0.1.9
pgvector==0.2.5
psycopg[binary]==3.1.18
psycopg-binary==3.1.18
llama-cpp-python==0.3.5

@ -13,4 +13,4 @@ __all__ = [
    "IngestTextRequest",
]

__version__ = "0.1.8"
__version__ = "0.2.0"

@ -16,6 +16,30 @@ from .models import (
)


class AsyncCache:
    def __init__(self, db: "AsyncDataBridge", name: str):
        self._db = db
        self._name = name

    async def update(self) -> bool:
        response = await self._db._request("POST", f"cache/{self._name}/update")
        return response.get("success", False)

    async def add_docs(self, docs: List[str]) -> bool:
        response = await self._db._request("POST", f"cache/{self._name}/add_docs", {"docs": docs})
        return response.get("success", False)

    async def query(
        self, query: str, max_tokens: Optional[int] = None, temperature: Optional[float] = None
    ) -> CompletionResponse:
        response = await self._db._request(
            "POST",
            f"cache/{self._name}/query",
            {"query": query, "max_tokens": max_tokens, "temperature": temperature},
        )
        return CompletionResponse(**response)


class AsyncDataBridge:
    """
    DataBridge client for document operations.
@ -345,6 +369,72 @@ class AsyncDataBridge:
        response = await self._request("GET", f"documents/{document_id}")
        return Document(**response)

    async def create_cache(
        self,
        name: str,
        model: str,
        gguf_file: str,
        filters: Optional[Dict[str, Any]] = None,
        docs: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Create a new cache with specified configuration.

        Args:
            name: Name of the cache to create
            model: Name of the model to use (e.g. "llama2")
            gguf_file: Name of the GGUF file to use for the model
            filters: Optional metadata filters to determine which documents to include. These filters will be applied in addition to any specific docs provided.
            docs: Optional list of specific document IDs to include. These docs will be included in addition to any documents matching the filters.

        Returns:
            Dict[str, Any]: Created cache configuration

        Example:
            ```python
            # This will include both:
            # 1. Any documents with category="programming"
            # 2. The specific documents "doc1" and "doc2" (regardless of their category)
            cache = await db.create_cache(
                name="programming_cache",
                model="llama2",
                gguf_file="llama-2-7b-chat.Q4_K_M.gguf",
                filters={"category": "programming"},
                docs=["doc1", "doc2"]
            )
            ```
        """
        request = {
            "name": name,
            "model": model,
            "gguf_file": gguf_file,
            "filters": filters,
            "docs": docs,
        }

        response = await self._request("POST", "cache/create", request)
        return response

    async def get_cache(self, name: str) -> AsyncCache:
        """
        Get a cache by name.

        Args:
            name: Name of the cache to retrieve

        Returns:
            cache: A cache object that is used to interact with the cache.

        Example:
            ```python
            cache = await db.get_cache("programming_cache")
            ```
        """
        response = await self._request("GET", f"cache/{name}")
        if response.get("exists", False):
            return AsyncCache(self, name)
        raise ValueError(f"Cache '{name}' not found")

    async def close(self):
        """Close the HTTP client"""
        await self._client.aclose()

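
Putting the async pieces above together, here is a minimal end-to-end sketch (not from the commit): the connection URI, document IDs, and package import path are placeholders, and it assumes a running server that exposes the cache endpoints added in this change.

```python
import asyncio

from databridge import AsyncDataBridge  # import path assumed from the "databridge-client" package


async def main() -> None:
    db = AsyncDataBridge("databridge://owner:token@localhost:8000")  # placeholder URI
    # Build a cache over every document tagged category="programming", plus one explicit doc.
    await db.create_cache(
        name="programming_cache",
        model="llama2",
        gguf_file="llama-2-7b-chat.Q4_K_M.gguf",
        filters={"category": "programming"},
        docs=["doc1"],  # placeholder document ID
    )
    cache = await db.get_cache("programming_cache")
    await cache.add_docs(["doc2"])  # fold in another document by ID
    await cache.update()            # refresh the cache state after ingesting new documents
    answer = await cache.query("Summarize the programming docs in one sentence.", max_tokens=128)
    print(answer.completion)
    await db.close()


asyncio.run(main())
```
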
@ -16,6 +16,31 @@ from .models import (
)


class Cache:
    def __init__(self, db: "DataBridge", name: str):
        self._db = db
        self._name = name

    def update(self) -> bool:
        response = self._db._request("POST", f"cache/{self._name}/update")
        return response.get("success", False)

    def add_docs(self, docs: List[str]) -> bool:
        response = self._db._request("POST", f"cache/{self._name}/add_docs", {"docs": docs})
        return response.get("success", False)

    def query(
        self, query: str, max_tokens: Optional[int] = None, temperature: Optional[float] = None
    ) -> CompletionResponse:
        response = self._db._request(
            "POST",
            f"cache/{self._name}/query",
            params={"query": query, "max_tokens": max_tokens, "temperature": temperature},
            data="",
        )
        return CompletionResponse(**response)


class DataBridge:
    """
    DataBridge client for document operations.
@ -71,6 +96,7 @@ class DataBridge:
        endpoint: str,
        data: Optional[Dict[str, Any]] = None,
        files: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Make HTTP request"""
        headers = {}
@ -88,6 +114,7 @@ class DataBridge:
            data=data if files else None,
            headers=headers,
            timeout=self._timeout,
            params=params,
        )
        response.raise_for_status()
        return response.json()
@ -334,6 +361,70 @@ class DataBridge:
        response = self._request("GET", f"documents/{document_id}")
        return Document(**response)

    def create_cache(
        self,
        name: str,
        model: str,
        gguf_file: str,
        filters: Optional[Dict[str, Any]] = None,
        docs: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Create a new cache with specified configuration.

        Args:
            name: Name of the cache to create
            model: Name of the model to use (e.g. "llama2")
            gguf_file: Name of the GGUF file to use for the model
            filters: Optional metadata filters to determine which documents to include. These filters will be applied in addition to any specific docs provided.
            docs: Optional list of specific document IDs to include. These docs will be included in addition to any documents matching the filters.

        Returns:
            Dict[str, Any]: Created cache configuration

        Example:
            ```python
            # This will include both:
            # 1. Any documents with category="programming"
            # 2. The specific documents "doc1" and "doc2" (regardless of their category)
            cache = db.create_cache(
                name="programming_cache",
                model="llama2",
                gguf_file="llama-2-7b-chat.Q4_K_M.gguf",
                filters={"category": "programming"},
                docs=["doc1", "doc2"]
            )
            ```
        """
        # Build query parameters for name, model and gguf_file
        params = {"name": name, "model": model, "gguf_file": gguf_file}

        # Build request body for filters and docs
        request = {"filters": filters, "docs": docs}

        response = self._request("POST", "cache/create", request, params=params)
        return response

    def get_cache(self, name: str) -> Cache:
        """
        Get a cache by name.

        Args:
            name: Name of the cache to retrieve

        Returns:
            cache: A cache object that is used to interact with the cache.

        Example:
            ```python
            cache = db.get_cache("programming_cache")
            ```
        """
        response = self._request("GET", f"cache/{name}")
        if response.get("exists", False):
            return Cache(self, name)
        raise ValueError(f"Cache '{name}' not found")

    def close(self):
        """Close the HTTP session"""
        self._session.close()

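
The synchronous client mirrors the async one, with one difference visible above: `create_cache` sends `name`, `model`, and `gguf_file` as query parameters and the filters/docs in the JSON body. A brief usage sketch, with a placeholder URI and document IDs and an assumed import path:

```python
from databridge import DataBridge  # import path assumed from the "databridge-client" package

db = DataBridge("databridge://owner:token@localhost:8000")  # placeholder URI
db.create_cache(
    name="programming_cache",
    model="llama2",
    gguf_file="llama-2-7b-chat.Q4_K_M.gguf",
    filters={"category": "programming"},
)
cache = db.get_cache("programming_cache")
cache.add_docs(["doc2"])   # placeholder document ID
cache.update()
print(cache.query("What do these documents cover?").completion)
db.close()
```
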
@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "databridge-client"
version = "0.1.8"
version = "0.2.0"
authors = [
    { name = "DataBridge", email = "databridgesuperuser@gmail.com" },
]

48
shell.py
@ -118,11 +118,59 @@ class DB:
        doc = self._client.get_document(document_id)
        return doc.model_dump()

    def create_cache(
        self,
        name: str,
        model: str,
        gguf_file: str,
        filters: dict = None,
        docs: list = None,
    ) -> dict:
        """Create a new cache with specified configuration"""
        response = self._client.create_cache(
            name=name,
            model=model,
            gguf_file=gguf_file,
            filters=filters or {},
            docs=docs,
        )
        return response

    def get_cache(self, name: str) -> "Cache":
        """Get a cache by name"""
        return self._client.get_cache(name)

    def close(self):
        """Close the client connection"""
        self._client.close()


class Cache:
    def __init__(self, db: DB, name: str):
        self._db = db
        self._name = name
        self._client_cache = db._client.get_cache(name)

    def update(self) -> bool:
        """Update the cache"""
        return self._client_cache.update()

    def add_docs(self, docs: list) -> bool:
        """Add documents to the cache"""
        return self._client_cache.add_docs(docs)

    def query(
        self, query: str, max_tokens: int = None, temperature: float = None
    ) -> dict:
        """Query the cache"""
        response = self._client_cache.query(
            query=query,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return response.model_dump()


if __name__ == "__main__":
    uri = sys.argv[1] if len(sys.argv) > 1 else None
    db = DB(uri)

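
For interactive use, the `DB` and `Cache` wrappers above can drive the same flow from the shell. A hypothetical session (the URI, cache name, and document IDs are placeholders, and it assumes the shell was started as `python shell.py <databridge-uri>` with `db` bound as shown above):

```python
# Inside the interactive shell started by shell.py
db.create_cache(
    name="notes_cache",
    model="llama2",
    gguf_file="llama-2-7b-chat.Q4_K_M.gguf",
    filters={"category": "notes"},
)
cache = Cache(db, "notes_cache")  # shell-level wrapper defined above
cache.add_docs(["doc42"])         # placeholder document ID
cache.update()
print(cache.query("Give a one-sentence summary.")["completion"])
```
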