# Hugging Face cache implementation.

from pathlib import Path
from typing import List, Optional, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

from core.cache.base_cache import BaseCache
from core.models.completion import CompletionRequest, CompletionResponse


class HuggingFaceCache(BaseCache):
    """Hugging Face cache implementation for cache-augmented generation."""

    def __init__(
        self,
        cache_path: Path,
        model_name: str = "distilgpt2",
        device: str = "cpu",
        default_max_new_tokens: int = 100,
        use_fp16: bool = False,
    ):
        """Initialize the HuggingFace cache.

        Args:
            cache_path: Path to store cache files
            model_name: Name of the HuggingFace model to use
            device: Device to run the model on (e.g. "cpu", "cuda", "mps")
            default_max_new_tokens: Default maximum number of new tokens to generate
            use_fp16: Whether to use FP16 precision
        """
        super().__init__()
        self.cache_path = cache_path
        self.model_name = model_name
        self.device = device
        self.default_max_new_tokens = default_max_new_tokens
        self.use_fp16 = use_fp16

        # Initialize tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Configure model loading based on device
        model_kwargs = {"low_cpu_mem_usage": True}

        if device == "cpu":
            # For CPU, use standard loading
            model_kwargs.update({"torch_dtype": torch.float32})
            self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs).to(device)
        else:
            # For GPU/MPS, use automatic device mapping and optional FP16
            model_kwargs.update({"device_map": "auto", "torch_dtype": torch.float16 if use_fp16 else torch.float32})
            self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

        # KV cache state, populated by ingest() or load_cache()
        self.kv_cache = None
        self.origin_len = None

    def get_kv_cache(self, prompt: str) -> DynamicCache:
        """Build KV cache from prompt"""
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        cache = DynamicCache()

        with torch.no_grad():
            _ = self.model(input_ids=input_ids, past_key_values=cache, use_cache=True)
        return cache

    def clean_up_cache(self, cache: DynamicCache, origin_len: int):
        """Trim the cache back to `origin_len`, removing tokens appended after ingestion"""
        for i in range(len(cache.key_cache)):
            # Cached tensors are shaped (batch, num_heads, seq_len, head_dim); truncate along seq_len.
            cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
            cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]

    def generate(self, input_ids: torch.Tensor, past_key_values, max_new_tokens: Optional[int] = None) -> torch.Tensor:
        """Generate text greedily using the model and the provided KV cache"""
        device = next(self.model.parameters()).device
        origin_len = input_ids.shape[-1]
        input_ids = input_ids.to(device)
        output_ids = input_ids.clone()
        next_token = input_ids

        with torch.no_grad():
            for _ in range(max_new_tokens or self.default_max_new_tokens):
                out = self.model(input_ids=next_token, past_key_values=past_key_values, use_cache=True)
                logits = out.logits[:, -1, :]
                # Greedy decoding: pick the highest-probability next token
                token = torch.argmax(logits, dim=-1, keepdim=True)
                output_ids = torch.cat([output_ids, token], dim=-1)
                past_key_values = out.past_key_values
                next_token = token.to(device)

                if self.model.config.eos_token_id is not None and token.item() == self.model.config.eos_token_id:
                    break

        # Return only the newly generated tokens, excluding the prompt
        return output_ids[:, origin_len:]

    async def ingest(self, docs: List[str]) -> bool:
        """Ingest documents into cache"""
        try:
            # Create system prompt with documents
            system_prompt = f"""
<|system|>
You are an assistant who provides concise factual answers.
<|user|>
Context:
{' '.join(docs)}
Question:
""".strip()

            # Build the cache
            input_ids = self.tokenizer(system_prompt, return_tensors="pt").input_ids.to(self.device)
            self.kv_cache = DynamicCache()

            with torch.no_grad():
                # First run to get the cache shape
                outputs = self.model(input_ids=input_ids, use_cache=True)
                # Initialize cache with empty tensors of the right shape
                n_layers = len(outputs.past_key_values)
                batch_size = input_ids.shape[0]

                # Handle different model architectures
                if hasattr(self.model.config, "num_key_value_heads"):
                    # Models with grouped query attention (GQA) like Llama
                    n_kv_heads = self.model.config.num_key_value_heads
                    head_dim = self.model.config.head_dim
                elif hasattr(self.model.config, "n_head"):
                    # GPT-style models
                    n_kv_heads = self.model.config.n_head
                    head_dim = self.model.config.n_embd // self.model.config.n_head
                elif hasattr(self.model.config, "num_attention_heads"):
                    # OPT-style models
                    n_kv_heads = self.model.config.num_attention_heads
                    head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
                else:
                    raise ValueError(f"Unsupported model architecture: {self.model.config.model_type}")

                seq_len = input_ids.shape[1]

                for i in range(n_layers):
                    key_shape = (batch_size, n_kv_heads, seq_len, head_dim)
                    value_shape = key_shape
                    self.kv_cache.key_cache.append(torch.zeros(key_shape, device=self.device))
                    self.kv_cache.value_cache.append(torch.zeros(value_shape, device=self.device))

                # Now run with the initialized cache
                outputs = self.model(input_ids=input_ids, past_key_values=self.kv_cache, use_cache=True)
                # Update cache with actual values
                self.kv_cache.key_cache = [layer[0] for layer in outputs.past_key_values]
                self.kv_cache.value_cache = [layer[1] for layer in outputs.past_key_values]
                self.origin_len = self.kv_cache.key_cache[0].shape[-2]
            return True
        except Exception as e:
            print(f"Error ingesting documents: {e}")
            return False

    async def update(self, new_doc: str) -> bool:
        """Update cache with new document"""
        try:
            if self.kv_cache is None:
                return await self.ingest([new_doc])

            # Clean up existing cache
            self.clean_up_cache(self.kv_cache, self.origin_len)

            # Add new document to cache
            input_ids = self.tokenizer(new_doc + "\n", return_tensors="pt").input_ids.to(self.device)

            # First run to get the cache shape
            outputs = self.model(input_ids=input_ids, use_cache=True)
            # Initialize cache with empty tensors of the right shape
            n_layers = len(outputs.past_key_values)
            batch_size = input_ids.shape[0]

            # Handle different model architectures
            if hasattr(self.model.config, "num_key_value_heads"):
                # Models with grouped query attention (GQA) like Llama
                n_kv_heads = self.model.config.num_key_value_heads
                head_dim = self.model.config.head_dim
            elif hasattr(self.model.config, "n_head"):
                # GPT-style models
                n_kv_heads = self.model.config.n_head
                head_dim = self.model.config.n_embd // self.model.config.n_head
            elif hasattr(self.model.config, "num_attention_heads"):
                # OPT-style models
                n_kv_heads = self.model.config.num_attention_heads
                head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
            else:
                raise ValueError(f"Unsupported model architecture: {self.model.config.model_type}")

            seq_len = input_ids.shape[1]

            # Create a new cache for the update
            new_cache = DynamicCache()
            for i in range(n_layers):
                key_shape = (batch_size, n_kv_heads, seq_len, head_dim)
                value_shape = key_shape
                new_cache.key_cache.append(torch.zeros(key_shape, device=self.device))
                new_cache.value_cache.append(torch.zeros(value_shape, device=self.device))

            # Run with the initialized cache
            outputs = self.model(input_ids=input_ids, past_key_values=new_cache, use_cache=True)
            # Update cache with actual values
            self.kv_cache.key_cache = [layer[0] for layer in outputs.past_key_values]
            self.kv_cache.value_cache = [layer[1] for layer in outputs.past_key_values]
            return True
        except Exception as e:
            print(f"Error updating cache: {e}")
            return False

    async def complete(self, request: CompletionRequest) -> CompletionResponse:
        """Generate completion using cache-augmented generation"""
        try:
            if self.kv_cache is None:
                raise ValueError("Cache not initialized. Please ingest documents first.")

            # Clean up cache
            self.clean_up_cache(self.kv_cache, self.origin_len)

            # Generate completion
            input_ids = self.tokenizer(request.query + "\n", return_tensors="pt").input_ids.to(self.device)
            gen_ids = self.generate(input_ids, self.kv_cache, max_new_tokens=request.max_tokens)
            completion = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True)

            # Calculate token usage
            usage = {
                "prompt_tokens": len(input_ids[0]),
                "completion_tokens": len(gen_ids[0]),
                "total_tokens": len(input_ids[0]) + len(gen_ids[0]),
            }

            return CompletionResponse(completion=completion, usage=usage)
        except Exception as e:
            print(f"Error generating completion: {e}")
            return CompletionResponse(
                completion=f"Error: {str(e)}",
                usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
            )

    def save_cache(self) -> Path:
        """Save the KV cache to disk"""
        if self.kv_cache is None:
            raise ValueError("No cache to save")

        cache_dir = self.cache_path / "kv_cache"
        cache_dir.mkdir(parents=True, exist_ok=True)

        # Save key and value caches
        cache_data = {
            "key_cache": self.kv_cache.key_cache,
            "value_cache": self.kv_cache.value_cache,
            "origin_len": self.origin_len,
        }
        cache_path = cache_dir / "cache.pt"
        torch.save(cache_data, cache_path)
        return cache_path

    def load_cache(self, cache_path: Union[str, Path]) -> None:
        """Load KV cache from disk"""
        cache_path = Path(cache_path)
        if not cache_path.exists():
            raise FileNotFoundError(f"Cache file not found at {cache_path}")

        cache_data = torch.load(cache_path, map_location=self.device)

        self.kv_cache = DynamicCache()
        self.kv_cache.key_cache = cache_data["key_cache"]
        self.kv_cache.value_cache = cache_data["value_cache"]
        self.origin_len = cache_data["origin_len"]
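

# Illustrative usage sketch, not part of the cache implementation itself: it assumes the
# `core` package is importable (e.g. run via `python -m` from the project root), that
# CompletionRequest accepts `query` and `max_tokens` keyword arguments, and that
# CompletionResponse exposes `completion`/`usage` attributes, matching how complete() uses them above.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        cache = HuggingFaceCache(cache_path=Path("/tmp/hf_cache_demo"), model_name="distilgpt2")

        # Build the KV cache from a small document set, then answer a question against it.
        await cache.ingest(["Paris is the capital of France."])
        response = await cache.complete(
            CompletionRequest(query="What is the capital of France?", max_tokens=20)
        )
        print(response.completion, response.usage)

        # Persist the cache and reload it to exercise the save/load round trip.
        saved_path = cache.save_cache()
        cache.load_cache(saved_path)

    asyncio.run(_demo())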