switch to decorator pattern for telemetry (#99)

This commit is contained in:
Arnav Agrawal 2025-04-19 16:13:51 -07:00 committed by GitHub
parent 09622cc3fc
commit c56f66349e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 1061 additions and 949 deletions

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, Callable, TypeVar, Union, cast
from dataclasses import dataclass
import threading
from collections import defaultdict
@ -11,6 +11,7 @@ from pathlib import Path
import uuid
import hashlib
import logging
import functools
from core.config import get_settings
@ -337,6 +338,101 @@ class UsageRecord:
metadata: Optional[Dict] = None
# Type variable for function return type
T = TypeVar('T')
class MetadataField:
"""Defines a metadata field to extract and how to extract it."""
def __init__(self, key: str, source: str, attr_name: Optional[str] = None,
default: Any = None, transform: Optional[Callable[[Any], Any]] = None):
"""
Initialize a metadata field definition.
Args:
key: The key to use in the metadata dictionary
source: The source of the data ('request', 'kwargs', etc.)
attr_name: The attribute name to extract (if None, uses key)
default: Default value if not found
transform: Optional function to transform the extracted value
"""
self.key = key
self.source = source
self.attr_name = attr_name or key
self.default = default
self.transform = transform
def extract(self, args: tuple, kwargs: dict) -> Any:
"""Extract the field value from args/kwargs based on configuration."""
value = self.default
if self.source == 'kwargs':
value = kwargs.get(self.attr_name, self.default)
elif self.source == 'request':
request = kwargs.get('request')
if request:
if hasattr(request, 'get') and callable(request.get):
value = request.get(self.attr_name, self.default)
else:
value = getattr(request, self.attr_name, self.default)
if self.transform and value is not None:
value = self.transform(value)
return value
class MetadataExtractor:
"""Base class for metadata extractors with common functionality."""
def __init__(self, fields: List[MetadataField] = None):
"""Initialize with a list of field definitions."""
self.fields = fields or []
def extract(self, args: tuple, kwargs: dict) -> dict:
"""Extract metadata using the field definitions."""
metadata = {}
for field in self.fields:
value = field.extract(args, kwargs)
if value is not None: # Only include non-None values
metadata[field.key] = value
return metadata
def __call__(self, *args, **kwargs) -> dict:
"""Make the extractor callable as an instance method."""
# If called as an instance method, the first arg will be the instance
# which we don't need for extraction, so we slice it off if there are any args
actual_args = args[1:] if len(args) > 0 else ()
return self.extract(actual_args, kwargs)
# Common transforms and utilities for metadata extraction
def parse_json(value, default=None):
"""Parse a JSON string safely, returning default on error."""
if not isinstance(value, str):
return default
try:
return json.loads(value)
except (json.JSONDecodeError, TypeError):
return default
def get_json_type(value):
"""Determine if a JSON value is a list or single object."""
return "list" if isinstance(value, list) else "single"
def get_list_len(value, default=0):
"""Get the length of a list safely."""
if value and isinstance(value, list):
return len(value)
return default
def is_not_none(value):
"""Check if a value is not None."""
return value is not None
class TelemetryService:
_instance = None
_lock = threading.Lock()
@ -449,6 +545,234 @@ class TelemetryService:
unit="ms",
)
# Initialize metadata extractors
self._setup_metadata_extractors()
def _setup_metadata_extractors(self):
"""Set up all the metadata extractors with their field definitions."""
# Common fields that appear in many requests
common_request_fields = [
MetadataField("use_colpali", "request"),
MetadataField("folder_name", "request"),
MetadataField("end_user_id", "request"),
]
retrieval_fields = common_request_fields + [
MetadataField("k", "request"),
MetadataField("min_score", "request"),
MetadataField("use_reranking", "request"),
]
# Set up all the metadata extractors
self.ingest_text_metadata = MetadataExtractor(
common_request_fields + [
MetadataField("metadata", "request", default={}),
MetadataField("rules", "request", default=[]),
]
)
self.ingest_file_metadata = MetadataExtractor([
MetadataField("filename", "kwargs", transform=lambda file: file.filename if file else None),
MetadataField("content_type", "kwargs", transform=lambda file: file.content_type if file else None),
MetadataField("metadata", "kwargs", transform=lambda v: parse_json(v, {})),
MetadataField("rules", "kwargs", transform=lambda v: parse_json(v, [])),
MetadataField("use_colpali", "kwargs"),
MetadataField("folder_name", "kwargs"),
MetadataField("end_user_id", "kwargs"),
])
self.batch_ingest_metadata = MetadataExtractor([
MetadataField("file_count", "kwargs", "files", transform=get_list_len),
MetadataField("metadata_type", "kwargs", "metadata",
transform=lambda v: get_json_type(parse_json(v, {}))),
MetadataField("rules_type", "kwargs", "rules",
transform=lambda v: "per_file" if isinstance(parse_json(v, []), list)
and parse_json(v, []) and isinstance(parse_json(v, [])[0], list)
else "shared"),
MetadataField("folder_name", "kwargs"),
MetadataField("end_user_id", "kwargs"),
])
self.retrieve_chunks_metadata = MetadataExtractor(retrieval_fields)
self.retrieve_docs_metadata = MetadataExtractor(retrieval_fields)
self.batch_documents_metadata = MetadataExtractor([
MetadataField("document_count", "request", transform=lambda req:
len(req.get("document_ids", [])) if req else 0),
MetadataField("folder_name", "request"),
MetadataField("end_user_id", "request"),
])
self.batch_chunks_metadata = MetadataExtractor([
MetadataField("chunk_count", "request", transform=lambda req:
len(req.get("sources", [])) if req else 0),
MetadataField("folder_name", "request"),
MetadataField("end_user_id", "request"),
MetadataField("use_colpali", "request"),
])
self.query_metadata = MetadataExtractor(
retrieval_fields + [
MetadataField("max_tokens", "request"),
MetadataField("temperature", "request"),
MetadataField("graph_name", "request"),
MetadataField("hop_depth", "request"),
MetadataField("include_paths", "request"),
MetadataField("has_prompt_overrides", "request", "prompt_overrides",
transform=lambda v: v is not None),
]
)
self.document_delete_metadata = MetadataExtractor([
MetadataField("document_id", "kwargs"),
])
self.document_update_text_metadata = MetadataExtractor([
MetadataField("document_id", "kwargs"),
MetadataField("update_strategy", "kwargs", default="add"),
MetadataField("use_colpali", "request"),
MetadataField("has_filename", "request", "filename", transform=is_not_none),
])
self.document_update_file_metadata = MetadataExtractor([
MetadataField("document_id", "kwargs"),
MetadataField("update_strategy", "kwargs", default="add"),
MetadataField("use_colpali", "kwargs"),
MetadataField("filename", "kwargs", transform=lambda file: file.filename if file else None),
MetadataField("content_type", "kwargs", transform=lambda file: file.content_type if file else None),
])
self.document_update_metadata_resolver = MetadataExtractor([
MetadataField("document_id", "kwargs"),
])
self.usage_stats_metadata = MetadataExtractor([])
self.recent_usage_metadata = MetadataExtractor([
MetadataField("operation_type", "kwargs"),
MetadataField("since", "kwargs", transform=lambda dt: dt.isoformat() if dt else None),
MetadataField("status", "kwargs"),
])
self.cache_create_metadata = MetadataExtractor([
MetadataField("name", "kwargs"),
MetadataField("model", "kwargs"),
MetadataField("gguf_file", "kwargs"),
MetadataField("filters", "kwargs"),
MetadataField("docs", "kwargs"),
])
self.cache_get_metadata = MetadataExtractor([
MetadataField("name", "kwargs"),
])
self.cache_update_metadata = self.cache_get_metadata
self.cache_add_docs_metadata = MetadataExtractor([
MetadataField("name", "kwargs"),
MetadataField("docs", "kwargs"),
])
self.cache_query_metadata = MetadataExtractor([
MetadataField("name", "kwargs"),
MetadataField("query", "kwargs"),
MetadataField("max_tokens", "kwargs"),
MetadataField("temperature", "kwargs"),
])
self.create_graph_metadata = MetadataExtractor([
MetadataField("name", "request"),
MetadataField("has_filters", "request", "filters", transform=is_not_none),
MetadataField("document_count", "request", "documents",
transform=lambda docs: len(docs) if docs else 0),
MetadataField("has_prompt_overrides", "request", "prompt_overrides", transform=is_not_none),
MetadataField("folder_name", "request"),
MetadataField("end_user_id", "request"),
])
self.get_graph_metadata = MetadataExtractor([
MetadataField("name", "kwargs"),
MetadataField("folder_name", "kwargs"),
MetadataField("end_user_id", "kwargs"),
])
self.list_graphs_metadata = MetadataExtractor([
MetadataField("folder_name", "kwargs"),
MetadataField("end_user_id", "kwargs"),
])
self.update_graph_metadata = MetadataExtractor([
MetadataField("name", "kwargs"),
MetadataField("has_additional_filters", "request", "additional_filters", transform=is_not_none),
MetadataField("additional_document_count", "request", "additional_documents",
transform=lambda docs: len(docs) if docs else 0),
MetadataField("has_prompt_overrides", "request", "prompt_overrides", transform=is_not_none),
MetadataField("folder_name", "request"),
MetadataField("end_user_id", "request"),
])
self.set_folder_rule_metadata = MetadataExtractor([
MetadataField("folder_id", "kwargs"),
MetadataField("apply_to_existing", "kwargs", default=True),
MetadataField("rule_count", "request", "rules",
transform=lambda rules: len(rules) if hasattr(rules, "__len__") else 0),
MetadataField("rule_types", "request", "rules",
transform=lambda rules: [rule.type for rule in rules] if hasattr(rules, "__iter__") else []),
])
def track(self, operation_type: Optional[str] = None, metadata_resolver: Optional[Callable] = None):
"""
Decorator for tracking API operations with telemetry.
Args:
operation_type: Type of operation or function name if None
metadata_resolver: Function that extracts metadata from the request/args/kwargs
"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
async def wrapper(*args, **kwargs):
# Extract auth from kwargs
auth = kwargs.get('auth')
if not auth:
# Try to find auth in positional arguments (unlikely, but possible)
for arg in args:
if hasattr(arg, 'entity_id') and hasattr(arg, 'permissions'):
auth = arg
break
# If we don't have auth, we can't track the operation
if not auth:
return await func(*args, **kwargs)
# Use function name if operation_type not provided
op_type = operation_type or func.__name__
# Generate metadata using resolver or create empty dict
meta = {}
if metadata_resolver:
meta = metadata_resolver(*args, **kwargs)
# Get approximate token count for text ingestion
tokens = 0
# Try to extract tokens for text ingestion
request = kwargs.get('request')
if request and hasattr(request, 'content') and isinstance(request.content, str):
tokens = len(request.content.split()) # Approximate token count
# Run the function within the telemetry context
async with self.track_operation(
operation_type=op_type,
user_id=auth.entity_id,
tokens_used=tokens,
metadata=meta,
) as span:
# Call the original function
result = await func(*args, **kwargs)
return result
return wrapper
return decorator
@asynccontextmanager
async def track_operation(
self,