morphik-core/core/cache/llama_cache.py

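"""LlamaCache: a document cache built on the llama.cpp KV cache.

The cache loads a GGUF model via llama-cpp-python, evaluates a system prompt
containing the provided documents once, and snapshots the resulting KV-cache
state. Later queries restore that snapshot instead of re-processing the
documents, and the state can be pickled for persistence and reloaded with
`from_bytes`.
"""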
import json
import logging
import pickle
from typing import Any, Dict, List

from llama_cpp import Llama

from core.cache.base_cache import BaseCache
from core.models.completion import CompletionResponse
from core.models.documents import Document

logger = logging.getLogger(__name__)

INITIAL_SYSTEM_PROMPT = """<|im_start|>system
You are a helpful AI assistant with access to provided documents. Your role is to:
1. Answer questions accurately based on the documents provided
2. Stay focused on the document content and avoid speculation
3. Admit when you don't have enough information to answer
4. Be clear and concise in your responses
5. Use direct quotes from documents when relevant
Provided documents: {documents}
<|im_end|>
""".strip()
ADD_DOC_SYSTEM_PROMPT = """<|im_start|>system
I'm adding some additional documents for your reference:
{documents}
Please incorporate this new information along with what you already know from previous documents while maintaining the same guidelines for responses.
<|im_end|>
""".strip()
QUERY_PROMPT = """<|im_start|>user
{query}
<|im_end|>
<|im_start|>assistant
""".strip()
class LlamaCache(BaseCache):
    def __init__(
        self,
        name: str,
        model: str,
        gguf_file: str,
        filters: Dict[str, Any],
        docs: List[Document],
        **kwargs,
    ):
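        """Create a cache: load the GGUF model and pre-fill the KV cache with `docs`."""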
logger.info(f"Initializing LlamaCache with name={name}, model={model}")
# cache related
self.name = name
self.model = model
self.filters = filters
self.docs = docs
# llama specific
self.gguf_file = gguf_file
self.n_gpu_layers = kwargs.get("n_gpu_layers", -1)
logger.info(f"Using {self.n_gpu_layers} GPU layers")
# late init (when we call _initialize)
self.llama = None
self.state = None
self.cached_tokens = 0
self._initialize(model, gguf_file, docs)
logger.info("LlamaCache initialization complete")

    def _initialize(self, model: str, gguf_file: str, docs: List[Document]) -> None:
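        """Load the GGUF model, evaluate the document system prompt, and snapshot the KV cache."""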
logger.info(f"Loading Llama model from {model} with file {gguf_file}")
try:
# Set a reasonable default context size (32K tokens)
default_ctx_size = 32768
self.llama = Llama.from_pretrained(
repo_id=model,
filename=gguf_file,
n_gpu_layers=self.n_gpu_layers,
n_ctx=default_ctx_size,
verbose=False, # Enable verbose mode for better error reporting
)
logger.info("Model loaded successfully")
# Format and tokenize system prompt
documents = "\n".join(doc.system_metadata.get("content", "") for doc in docs)
system_prompt = INITIAL_SYSTEM_PROMPT.format(documents=documents)
logger.info(f"Built system prompt: {system_prompt[:200]}...")
try:
tokens = self.llama.tokenize(system_prompt.encode())
logger.info(f"System prompt tokenized to {len(tokens)} tokens")
# Process tokens to build KV cache
logger.info("Evaluating system prompt")
self.llama.eval(tokens)
logger.info("Saving initial KV cache state")
self.state = self.llama.save_state()
self.cached_tokens = len(tokens)
logger.info(f"Initial KV cache built with {self.cached_tokens} tokens")
except Exception as e:
logger.error(f"Error during prompt processing: {str(e)}")
raise ValueError(f"Failed to process system prompt: {str(e)}")
except Exception as e:
logger.error(f"Failed to initialize Llama model: {str(e)}")
raise ValueError(f"Failed to initialize Llama model: {str(e)}")

    def add_docs(self, docs: List[Document]) -> bool:
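        """Append additional documents to the cached context and update the saved KV state."""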
logger.info(f"Adding {len(docs)} new documents to cache")
documents = "\n".join(doc.system_metadata.get("content", "") for doc in docs)
system_prompt = ADD_DOC_SYSTEM_PROMPT.format(documents=documents)
# Tokenize and process
new_tokens = self.llama.tokenize(system_prompt.encode())
self.llama.eval(new_tokens)
self.state = self.llama.save_state()
self.cached_tokens += len(new_tokens)
logger.info(f"Added {len(new_tokens)} tokens, total: {self.cached_tokens}")
return True

    def query(self, query: str) -> CompletionResponse:
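        """Answer `query` against the cached documents by restoring the saved KV-cache state."""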
        # Format query with proper chat template
        formatted_query = QUERY_PROMPT.format(query=query)
        logger.info(f"Processing query: {formatted_query}")

        # Reset and load cached state
        self.llama.reset()
        self.llama.load_state(self.state)
        logger.info(f"Loaded state with {self.state.n_tokens} tokens")

        # Tokenize and process query
        query_tokens = self.llama.tokenize(formatted_query.encode())
        self.llama.eval(query_tokens)
        logger.info(f"Evaluated query tokens: {query_tokens}")

        # Generate response
        output_tokens = []
        for token in self.llama.generate(tokens=[], reset=False):
            # Stop before the end-of-sequence token so it is not included in the decoded output
            if token == self.llama.token_eos():
                break
            output_tokens.append(token)

        # Decode and return
        completion = self.llama.detokenize(output_tokens).decode()
        logger.info(f"Generated completion: {completion}")
        return CompletionResponse(
            completion=completion,
            usage={"prompt_tokens": self.cached_tokens, "completion_tokens": len(output_tokens)},
        )

    @property
    def saveable_state(self) -> bytes:
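        """The pickled llama.cpp KV-cache state, suitable for persistence and `from_bytes`."""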
logger.info("Serializing cache state")
state_bytes = pickle.dumps(self.state)
logger.info(f"Serialized state size: {len(state_bytes)} bytes")
return state_bytes

    @classmethod
    def from_bytes(cls, name: str, cache_bytes: bytes, metadata: Dict[str, Any], **kwargs) -> "LlamaCache":
        """Load a cache from its serialized state.

        Args:
            name: Name of the cache
            cache_bytes: Pickled state bytes
            metadata: Cache metadata including model info
            **kwargs: Additional arguments

        Returns:
            LlamaCache: Loaded cache instance
        """
logger.info(f"Loading cache from bytes with name={name}")
logger.info(f"Cache metadata: {metadata}")
# Create new instance with metadata
# logger.info(f"Docs: {metadata['docs']}")
docs = [json.loads(doc) for doc in metadata["docs"]]
# time.sleep(10)
cache = cls(
name=name,
model=metadata["model"],
gguf_file=metadata["model_file"],
filters=metadata["filters"],
docs=[Document(**doc) for doc in docs],
)
# Load the saved state
logger.info(f"Loading saved KV cache state of size {len(cache_bytes)} bytes")
cache.state = pickle.loads(cache_bytes)
cache.llama.load_state(cache.state)
logger.info("Cache successfully loaded from bytes")
return cache
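

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). The Hugging Face repo id, GGUF
    # filename, and the Document fields below are assumptions for demonstration,
    # not values shipped with this module; substitute a real model repo and your
    # own documents before running.
    example_doc = Document(
        external_id="example-doc",
        system_metadata={"content": "Morphik keeps document context in a llama.cpp KV cache."},
    )
    cache = LlamaCache(
        name="example-cache",
        model="Qwen/Qwen2-0.5B-Instruct-GGUF",
        gguf_file="qwen2-0_5b-instruct-q4_k_m.gguf",
        filters={},
        docs=[example_doc],
    )
    response = cache.query("What does Morphik keep in the KV cache?")
    print(response.completion)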