Add hosted tier limits, cloud uri gen (#59)

Adityavardhan Agrawal 2025-03-27 17:30:02 -07:00 committed by GitHub
parent f0c44cb8ea
commit 7eb5887d2f
12 changed files with 1062 additions and 4 deletions


@ -11,7 +11,8 @@ import logging
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from core.completion.openai_completion import OpenAICompletionModel
from core.embedding.ollama_embedding_model import OllamaEmbeddingModel
from core.models.request import RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, BatchIngestResponse
from core.limits_utils import check_and_increment_limits
from core.models.request import GenerateUriRequest, RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, BatchIngestResponse
from core.models.completion import ChunkSource, CompletionResponse
from core.models.documents import Document, DocumentResult, ChunkResult
from core.models.graph import Graph
@ -236,6 +237,7 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext:
entity_type=EntityType(settings.dev_entity_type),
entity_id=settings.dev_entity_id,
permissions=set(settings.dev_permissions),
user_id=settings.dev_entity_id, # In dev mode, entity_id is also the user_id
)
# Normal token verification flow
@ -255,11 +257,17 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext:
if datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC):
raise HTTPException(status_code=401, detail="Token expired")
# Support both "type" and "entity_type" fields for compatibility
entity_type_field = payload.get("type") or payload.get("entity_type")
if not entity_type_field:
raise HTTPException(status_code=401, detail="Missing entity type in token")
return AuthContext(
entity_type=EntityType(payload["type"]),
entity_type=EntityType(entity_type_field),
entity_id=payload["entity_id"],
app_id=payload.get("app_id"),
permissions=set(payload.get("permissions", ["read"])),
user_id=payload.get("user_id", payload["entity_id"]), # Use user_id if available, fallback to entity_id
)
except jwt.InvalidTokenError as e:
raise HTTPException(status_code=401, detail=str(e))
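For reference, a minimal sketch of a token payload the updated verify_token accepts; the secret, algorithm, and identifier values here are illustrative, not taken from the repository:
import jwt
from datetime import datetime, timedelta, UTC

# Illustrative payload; in the server the key and algorithm come from settings.JWT_SECRET_KEY / settings.JWT_ALGORITHM
payload = {
    "type": "developer",          # "entity_type" is accepted as an alternative field name
    "entity_id": "dev_123",
    "app_id": "app_456",
    "permissions": ["read", "write"],
    "user_id": "user_789",        # optional; verify_token falls back to entity_id when absent
    "exp": int((datetime.now(UTC) + timedelta(days=30)).timestamp()),
}
token = jwt.encode(payload, "example-secret", algorithm="HS256")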
@ -569,6 +577,11 @@ async def query_completion(
to enhance retrieval by finding relevant entities and their connected documents.
"""
try:
# Check query limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "query", 1)
async with telemetry.track_operation(
operation_type="query",
user_id=auth.entity_id,
@ -689,6 +702,7 @@ async def update_document_text(
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
@app.post("/documents/{document_id}/update_file", response_model=Document)
async def update_document_file(
document_id: str,
@ -856,6 +870,11 @@ async def create_cache(
) -> Dict[str, Any]:
"""Create a new cache with specified configuration."""
try:
# Check cache creation limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "cache", 1)
async with telemetry.track_operation(
operation_type="create_cache",
user_id=auth.entity_id,
@ -955,6 +974,11 @@ async def query_cache(
) -> CompletionResponse:
"""Query the cache with a prompt."""
try:
# Check cache query limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "cache_query", 1)
async with telemetry.track_operation(
operation_type="query_cache",
user_id=auth.entity_id,
@ -994,6 +1018,11 @@ async def create_graph(
Graph: The created graph object
"""
try:
# Check graph creation limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "graph", 1)
async with telemetry.track_operation(
operation_type="create_graph",
user_id=auth.entity_id,
@ -1109,3 +1138,147 @@ async def generate_local_uri(
except Exception as e:
logger.error(f"Error generating local URI: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/cloud/generate_uri", include_in_schema=True)
async def generate_cloud_uri(
request: GenerateUriRequest,
authorization: str = Header(None),
) -> Dict[str, str]:
"""Generate a URI for cloud hosted applications."""
try:
app_id = request.app_id
name = request.name
user_id = request.user_id
expiry_days = request.expiry_days
logger.info(f"Generating cloud URI for app_id={app_id}, name={name}, user_id={user_id}")
# Verify authorization header before proceeding
if not authorization:
logger.warning("Missing authorization header")
raise HTTPException(
status_code=401,
detail="Missing authorization header",
headers={"WWW-Authenticate": "Bearer"},
)
# Verify the token is valid
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = authorization[7:] # Remove "Bearer "
try:
# Decode the token to ensure it's valid
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
# Only allow users to create apps for themselves (or admin)
token_user_id = payload.get("user_id")
logger.debug(f"Token user ID: {token_user_id}")
logger.debug(f"User ID: {user_id}")
if not (token_user_id == user_id or "admin" in payload.get("permissions", [])):
raise HTTPException(
status_code=403,
detail="You can only create apps for your own account unless you have admin permissions"
)
except jwt.InvalidTokenError as e:
raise HTTPException(status_code=401, detail=str(e))
# Import UserService here to avoid circular imports
from core.services.user_service import UserService
user_service = UserService()
# Initialize user service if needed
await user_service.initialize()
# Clean name
name = name.replace(" ", "_").lower()
# Check if the user is within app limit and generate URI
uri = await user_service.generate_cloud_uri(user_id, app_id, name, expiry_days)
if not uri:
logger.info("Application limit reached for this account tier with user_id: %s", user_id)
raise HTTPException(
status_code=403,
detail="Application limit reached for this account tier"
)
return {"uri": uri, "app_id": app_id}
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
logger.error(f"Error generating cloud URI: {e}")
raise HTTPException(status_code=500, detail=str(e))
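A minimal sketch of calling the new endpoint with httpx; the base URL and bearer token are placeholders for a real deployment:
import httpx

# Hypothetical base URL and token; the caller's token must belong to user_789 or carry admin permissions
resp = httpx.post(
    "http://localhost:8000/cloud/generate_uri",
    json={"app_id": "app_456", "name": "My App", "user_id": "user_789", "expiry_days": 30},
    headers={"Authorization": "Bearer <jwt-for-user_789>"},
)
resp.raise_for_status()
print(resp.json())  # {"uri": "databridge://my_app:<token>@<api-domain>/app_456", "app_id": "app_456"}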
@app.post("/user/upgrade", include_in_schema=True)
async def upgrade_user_tier(
user_id: str,
tier: str,
custom_limits: Optional[Dict[str, Any]] = None,
authorization: str = Header(None),
) -> Dict[str, Any]:
"""Upgrade a user to a higher tier."""
try:
# Verify admin authorization
if not authorization:
raise HTTPException(
status_code=401,
detail="Missing authorization header",
headers={"WWW-Authenticate": "Bearer"},
)
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = authorization[7:] # Remove "Bearer "
try:
# Decode token
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
# Only allow admins to upgrade users
if "admin" not in payload.get("permissions", []):
raise HTTPException(
status_code=403,
detail="Admin permission required"
)
except jwt.InvalidTokenError as e:
raise HTTPException(status_code=401, detail=str(e))
# Validate tier
from core.models.tiers import AccountTier
try:
account_tier = AccountTier(tier)
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid tier: {tier}")
# Upgrade user
from core.services.user_service import UserService
user_service = UserService()
# Initialize user service
await user_service.initialize()
# Update user tier
success = await user_service.update_user_tier(user_id, tier, custom_limits)
if not success:
raise HTTPException(
status_code=404,
detail="User not found or upgrade failed"
)
return {
"user_id": user_id,
"tier": tier,
"message": f"User upgraded to {tier} tier"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error upgrading user tier: {e}")
raise HTTPException(status_code=500, detail=str(e))
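Because user_id and tier are declared as plain string parameters, FastAPI reads them from the query string, while custom_limits (a dict) would arrive as the JSON body; a sketch of an admin call, with placeholder URL and token:
import httpx

# Hypothetical URL and admin token
resp = httpx.post(
    "http://localhost:8000/user/upgrade",
    params={"user_id": "user_789", "tier": "pro"},
    headers={"Authorization": "Bearer <admin-jwt>"},
)
resp.raise_for_status()
print(resp.json())  # {"user_id": "user_789", "tier": "pro", "message": "User upgraded to pro tier"}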


@ -91,6 +91,9 @@ class Settings(BaseSettings):
# Colpali configuration
ENABLE_COLPALI: bool
# Mode configuration
MODE: Literal["cloud", "self_hosted"] = "cloud"
# Telemetry configuration
TELEMETRY_ENABLED: bool = True
HONEYCOMB_ENABLED: bool = True
@ -291,6 +294,7 @@ def get_settings() -> Settings:
# load databridge config
databridge_config = {
"ENABLE_COLPALI": config["databridge"]["enable_colpali"],
"MODE": config["databridge"].get("mode", "cloud"), # Default to "cloud" mode
}
# load graph config


@ -494,6 +494,15 @@ class PostgresDatabase(BaseDatabase):
if auth.entity_type == "DEVELOPER" and auth.app_id:
# Add app-specific access for developers
filters.append(f"access_control->'app_access' ? '{auth.app_id}'")
# Add user_id filter in cloud mode
if auth.user_id:
from core.config import get_settings
settings = get_settings()
if settings.MODE == "cloud":
# Filter by user_id in access_control
filters.append(f"access_control->>'user_id' = '{auth.user_id}'")
return " OR ".join(filters)


@ -0,0 +1,359 @@
import json
from typing import List, Optional, Dict, Any
from datetime import datetime, UTC, timedelta
import logging
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import Column, String, Index, select, text
from sqlalchemy.dialects.postgresql import JSONB
logger = logging.getLogger(__name__)
Base = declarative_base()
class UserLimitsModel(Base):
"""SQLAlchemy model for user limits data."""
__tablename__ = "user_limits"
user_id = Column(String, primary_key=True)
tier = Column(String, nullable=False) # FREE, PRO, CUSTOM, SELF_HOSTED
custom_limits = Column(JSONB, nullable=True)
usage = Column(JSONB, default=dict) # Holds all usage counters
app_ids = Column(JSONB, default=list) # List of app IDs registered by this user
created_at = Column(String) # ISO format string
updated_at = Column(String) # ISO format string
# Create indexes
__table_args__ = (Index("idx_user_tier", "tier"),)
class UserLimitsDatabase:
"""Database operations for user limits."""
def __init__(self, uri: str):
"""Initialize database connection."""
self.engine = create_async_engine(uri)
self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
self._initialized = False
async def initialize(self) -> bool:
"""Initialize database tables and indexes."""
if self._initialized:
return True
try:
logger.info("Initializing user limits database tables...")
# Create tables if they don't exist
async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
self._initialized = True
logger.info("User limits database tables initialized successfully")
return True
except Exception as e:
logger.error(f"Failed to initialize user limits database: {e}")
return False
async def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
"""
Get user limits for a user.
Args:
user_id: The user ID to get limits for
Returns:
Dict with user limits if found, None otherwise
"""
async with self.async_session() as session:
result = await session.execute(
select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
)
user_limits = result.scalars().first()
if not user_limits:
return None
return {
"user_id": user_limits.user_id,
"tier": user_limits.tier,
"custom_limits": user_limits.custom_limits,
"usage": user_limits.usage,
"app_ids": user_limits.app_ids,
"created_at": user_limits.created_at,
"updated_at": user_limits.updated_at,
}
async def create_user_limits(self, user_id: str, tier: str = "free") -> bool:
"""
Create user limits record.
Args:
user_id: The user ID
tier: Initial tier (defaults to "free")
Returns:
True if successful, False otherwise
"""
try:
now = datetime.now(UTC).isoformat()
async with self.async_session() as session:
# Check if already exists
result = await session.execute(
select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
)
if result.scalars().first():
return True # Already exists
# Create new record with properly initialized JSONB columns
# Create JSON strings and parse them for consistency
usage_json = json.dumps({
"storage_file_count": 0,
"storage_size_bytes": 0,
"hourly_query_count": 0,
"hourly_query_reset": now,
"monthly_query_count": 0,
"monthly_query_reset": now,
"hourly_ingest_count": 0,
"hourly_ingest_reset": now,
"monthly_ingest_count": 0,
"monthly_ingest_reset": now,
"graph_count": 0,
"cache_count": 0,
})
app_ids_json = json.dumps([]) # Empty array but as JSON string
# Create the model with the JSON parsed
user_limits = UserLimitsModel(
user_id=user_id,
tier=tier,
usage=json.loads(usage_json),
app_ids=json.loads(app_ids_json),
created_at=now,
updated_at=now,
)
session.add(user_limits)
await session.commit()
return True
except Exception as e:
logger.error(f"Failed to create user limits: {e}")
return False
async def update_user_tier(
self, user_id: str, tier: str, custom_limits: Optional[Dict[str, Any]] = None
) -> bool:
"""
Update user tier and custom limits.
Args:
user_id: The user ID
tier: New tier
custom_limits: Optional custom limits for CUSTOM tier
Returns:
True if successful, False otherwise
"""
try:
now = datetime.now(UTC).isoformat()
async with self.async_session() as session:
result = await session.execute(
select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
)
user_limits = result.scalars().first()
if not user_limits:
return False
user_limits.tier = tier
user_limits.custom_limits = custom_limits
user_limits.updated_at = now
await session.commit()
return True
except Exception as e:
logger.error(f"Failed to update user tier: {e}")
return False
async def register_app(self, user_id: str, app_id: str) -> bool:
"""
Register an app for a user.
Args:
user_id: The user ID
app_id: The app ID to register
Returns:
True if successful, False otherwise
"""
try:
now = datetime.now(UTC).isoformat()
async with self.async_session() as session:
# First check if user exists
result = await session.execute(
select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
)
user_limits = result.scalars().first()
if not user_limits:
logger.error(f"User {user_id} not found in register_app")
return False
# Use raw SQL with the JSONB "?" (containment) and "||" (concatenation) operators to update the app_ids array
# This keeps the conditional append in a single UPDATE statement
query = text(
"""
UPDATE user_limits
SET
app_ids = CASE
WHEN NOT (app_ids ? :app_id) -- Check if app_id is not in the array
THEN app_ids || :app_id_json -- Append it if not present
ELSE app_ids -- Keep it unchanged if already present
END,
updated_at = :now
WHERE user_id = :user_id
RETURNING app_ids;
"""
)
# Execute the query
result = await session.execute(
query,
{
"app_id": app_id, # For the check
"app_id_json": f'["{app_id}"]', # JSON array format for appending
"now": now,
"user_id": user_id
}
)
# Log the result for debugging
updated_app_ids = result.scalar()
logger.info(f"Updated app_ids for user {user_id}: {updated_app_ids}")
await session.commit()
return True
except Exception as e:
logger.error(f"Failed to register app: {e}")
return False
async def update_usage(self, user_id: str, usage_type: str, increment: int = 1) -> bool:
"""
Update usage counter for a user.
Args:
user_id: The user ID
usage_type: Type of usage to update
increment: Value to increment by
Returns:
True if successful, False otherwise
"""
try:
now = datetime.now(UTC)
now_iso = now.isoformat()
async with self.async_session() as session:
result = await session.execute(
select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
)
user_limits = result.scalars().first()
if not user_limits:
return False
# Create a new dictionary to force SQLAlchemy to detect the change
usage = dict(user_limits.usage) if user_limits.usage else {}
# Handle different usage types
if usage_type == "query":
# Check hourly reset
hourly_reset_str = usage.get("hourly_query_reset", "")
if hourly_reset_str:
hourly_reset = datetime.fromisoformat(hourly_reset_str)
if now > hourly_reset + timedelta(hours=1):
usage["hourly_query_count"] = increment
usage["hourly_query_reset"] = now_iso
else:
usage["hourly_query_count"] = usage.get("hourly_query_count", 0) + increment
else:
usage["hourly_query_count"] = increment
usage["hourly_query_reset"] = now_iso
# Check monthly reset
monthly_reset_str = usage.get("monthly_query_reset", "")
if monthly_reset_str:
monthly_reset = datetime.fromisoformat(monthly_reset_str)
if now > monthly_reset + timedelta(days=30):
usage["monthly_query_count"] = increment
usage["monthly_query_reset"] = now_iso
else:
usage["monthly_query_count"] = usage.get("monthly_query_count", 0) + increment
else:
usage["monthly_query_count"] = increment
usage["monthly_query_reset"] = now_iso
elif usage_type == "ingest":
# Similar pattern for ingest
hourly_reset_str = usage.get("hourly_ingest_reset", "")
if hourly_reset_str:
hourly_reset = datetime.fromisoformat(hourly_reset_str)
if now > hourly_reset + timedelta(hours=1):
usage["hourly_ingest_count"] = increment
usage["hourly_ingest_reset"] = now_iso
else:
usage["hourly_ingest_count"] = usage.get("hourly_ingest_count", 0) + increment
else:
usage["hourly_ingest_count"] = increment
usage["hourly_ingest_reset"] = now_iso
monthly_reset_str = usage.get("monthly_ingest_reset", "")
if monthly_reset_str:
monthly_reset = datetime.fromisoformat(monthly_reset_str)
if now > monthly_reset + timedelta(days=30):
usage["monthly_ingest_count"] = increment
usage["monthly_ingest_reset"] = now_iso
else:
usage["monthly_ingest_count"] = usage.get("monthly_ingest_count", 0) + increment
else:
usage["monthly_ingest_count"] = increment
usage["monthly_ingest_reset"] = now_iso
elif usage_type == "storage_file":
usage["storage_file_count"] = usage.get("storage_file_count", 0) + increment
elif usage_type == "storage_size":
usage["storage_size_bytes"] = usage.get("storage_size_bytes", 0) + increment
elif usage_type == "graph":
usage["graph_count"] = usage.get("graph_count", 0) + increment
elif usage_type == "cache":
usage["cache_count"] = usage.get("cache_count", 0) + increment
# Force SQLAlchemy to recognize the change by assigning a new dict
user_limits.usage = usage
user_limits.updated_at = now_iso
# Explicitly mark as modified
session.add(user_limits)
# Log the updated usage for debugging
logger.info(f"Updated usage for user {user_id}, type: {usage_type}, value: {increment}")
logger.info(f"New usage values: {usage}")
logger.info(f"About to commit: user_id={user_id}, usage={user_limits.usage}")
# Commit the session so the changes are written
await session.commit()
return True
except Exception as e:
logger.error(f"Failed to update usage: {e}")
import traceback
logger.error(traceback.format_exc())
return False
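A minimal sketch of exercising the new database layer directly (the import path follows the usage in user_service.py below; the connection string is a placeholder and must use an async driver):
import asyncio
from core.database.user_limits_db import UserLimitsDatabase

async def main():
    # Placeholder URI; any asyncpg-backed Postgres instance will do
    db = UserLimitsDatabase("postgresql+asyncpg://user:pass@localhost:5432/databridge")
    await db.initialize()
    await db.create_user_limits("user_789", tier="free")
    await db.update_usage("user_789", "query", 1)
    limits = await db.get_user_limits("user_789")
    print(limits["tier"], limits["usage"]["hourly_query_count"])  # free 1

asyncio.run(main())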

core/limits_utils.py

@ -0,0 +1,71 @@
import logging
# Initialize logger
logger = logging.getLogger(__name__)
async def check_and_increment_limits(auth, limit_type: str, value: int = 1) -> None:
"""
Check if the user is within limits for an operation and increment usage.
Args:
auth: Authentication context with user_id
limit_type: Type of limit to check (query, ingest, storage_file, storage_size, graph, cache)
value: Value to check against limit (e.g., file size for storage_size)
Raises:
HTTPException: If the user exceeds limits
"""
from fastapi import HTTPException
from core.config import get_settings
settings = get_settings()
# Skip limit checking in self-hosted mode
if settings.MODE == "self_hosted":
return
# Check if user_id is available
if not auth.user_id:
logger.warning("User ID not available in auth context, skipping limit check")
return
# Initialize user service
from core.services.user_service import UserService
user_service = UserService()
await user_service.initialize()
# Check if user is within limits
within_limits = await user_service.check_limit(auth.user_id, limit_type, value)
if not within_limits:
# Get tier information for better error message
user_data = await user_service.get_user_limits(auth.user_id)
tier = user_data.get("tier", "unknown") if user_data else "unknown"
# Map limit types to appropriate error messages
limit_type_messages = {
"query": f"Query limit exceeded for your {tier} tier. Please upgrade or try again later.",
"ingest": f"Ingest limit exceeded for your {tier} tier. Please upgrade or try again later.",
"storage_file": f"Storage file count limit exceeded for your {tier} tier. Please delete some files or upgrade.",
"storage_size": f"Storage size limit exceeded for your {tier} tier. Please delete some files or upgrade.",
"graph": f"Graph creation limit exceeded for your {tier} tier. Please upgrade to create more graphs.",
"cache": f"Cache creation limit exceeded for your {tier} tier. Please upgrade to create more caches.",
"cache_query": f"Cache query limit exceeded for your {tier} tier. Please upgrade or try again later.",
}
# Get message for the limit type or use default message
detail = limit_type_messages.get(
limit_type, f"Limit exceeded for your {tier} tier. Please upgrade or contact support."
)
# Raise the exception with appropriate message
raise HTTPException(status_code=429, detail=detail)
# Record usage asynchronously
try:
await user_service.record_usage(auth.user_id, limit_type, value)
except Exception as e:
# Just log if recording usage fails, don't fail the operation
logger.error(f"Failed to record usage: {e}")


@ -16,3 +16,4 @@ class AuthContext(BaseModel):
app_id: Optional[str] = None # uuid, only for developers
# TODO: remove permissions, not required here.
permissions: Set[str] = {"read"}
user_id: Optional[str] = None # ID of the user who owns the app/entity


@ -57,3 +57,11 @@ class BatchIngestResponse(BaseModel):
"""Response model for batch ingestion"""
documents: List[Document]
errors: List[Dict[str, str]]
class GenerateUriRequest(BaseModel):
"""Request model for generating a cloud URI"""
app_id: str = Field(..., description="ID of the application")
name: str = Field(..., description="Name of the application")
user_id: str = Field(..., description="ID of the user who owns the app")
expiry_days: int = Field(default=30, description="Number of days until the token expires")
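A quick sketch of the model in use; the values are illustrative:
from core.models.request import GenerateUriRequest

req = GenerateUriRequest(app_id="app_456", name="My App", user_id="user_789")
print(req.expiry_days)  # 30, the default when not supplied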

core/models/tiers.py

@ -0,0 +1,117 @@
from enum import Enum
from typing import Dict, Any
class AccountTier(str, Enum):
"""Available account tiers."""
FREE = "free"
PRO = "pro"
CUSTOM = "custom"
SELF_HOSTED = "self_hosted"
# Tier limits definition - organized by API endpoint usage
TIER_LIMITS = {
AccountTier.FREE: {
# Application limits
"app_limit": 1, # Maximum number of applications
# Storage limits
"storage_file_limit": 100, # Maximum number of files in storage
"storage_size_limit_gb": 1, # Maximum storage size in GB
"hourly_ingest_limit": 30, # Maximum file/text ingests per hour
"monthly_ingest_limit": 100, # Maximum file/text ingests per month
# Query limits
"hourly_query_limit": 30, # Maximum queries per hour
"monthly_query_limit": 1000, # Maximum queries per month
# Graph limits
"graph_creation_limit": 3, # Maximum number of graphs
"hourly_graph_query_limit": 20, # Maximum graph queries per hour
"monthly_graph_query_limit": 200, # Maximum graph queries per month
# Cache limits
"cache_creation_limit": 0, # Maximum number of caches
"hourly_cache_query_limit": 0, # Maximum cache queries per hour
"monthly_cache_query_limit": 0, # Maximum cache queries per month
},
AccountTier.PRO: {
# Application limits
"app_limit": 5, # Maximum number of applications
# Storage limits
"storage_file_limit": 1000, # Maximum number of files in storage
"storage_size_limit_gb": 10, # Maximum storage size in GB
"hourly_ingest_limit": 100, # Maximum file/text ingests per hour
"monthly_ingest_limit": 3000, # Maximum file/text ingests per month
# Query limits
"hourly_query_limit": 100, # Maximum queries per hour
"monthly_query_limit": 10000, # Maximum queries per month
# Graph limits
"graph_creation_limit": 10, # Maximum number of graphs
"hourly_graph_query_limit": 50, # Maximum graph queries per hour
"monthly_graph_query_limit": 1000, # Maximum graph queries per month
# Cache limits
"cache_creation_limit": 5, # Maximum number of caches
"hourly_cache_query_limit": 200, # Maximum cache queries per hour
"monthly_cache_query_limit": 5000, # Maximum cache queries per month
},
AccountTier.CUSTOM: {
# Custom tier limits are set on a per-account basis
# These are default values that will be overridden
# Application limits
"app_limit": 10, # Maximum number of applications
# Storage limits
"storage_file_limit": 10000, # Maximum number of files in storage
"storage_size_limit_gb": 100, # Maximum storage size in GB
"hourly_ingest_limit": 500, # Maximum file/text ingests per hour
"monthly_ingest_limit": 15000, # Maximum file/text ingests per month
# Query limits
"hourly_query_limit": 500, # Maximum queries per hour
"monthly_query_limit": 50000, # Maximum queries per month
# Graph limits
"graph_creation_limit": 50, # Maximum number of graphs
"hourly_graph_query_limit": 200, # Maximum graph queries per hour
"monthly_graph_query_limit": 10000, # Maximum graph queries per month
# Cache limits
"cache_creation_limit": 20, # Maximum number of caches
"hourly_cache_query_limit": 1000, # Maximum cache queries per hour
"monthly_cache_query_limit": 50000, # Maximum cache queries per month
},
AccountTier.SELF_HOSTED: {
# Self-hosted has no limits
# Application limits
"app_limit": float("inf"), # Maximum number of applications
# Storage limits
"storage_file_limit": float("inf"), # Maximum number of files in storage
"storage_size_limit_gb": float("inf"), # Maximum storage size in GB
"hourly_ingest_limit": float("inf"), # Maximum file/text ingests per hour
"monthly_ingest_limit": float("inf"), # Maximum file/text ingests per month
# Query limits
"hourly_query_limit": float("inf"), # Maximum queries per hour
"monthly_query_limit": float("inf"), # Maximum queries per month
# Graph limits
"graph_creation_limit": float("inf"), # Maximum number of graphs
"hourly_graph_query_limit": float("inf"), # Maximum graph queries per hour
"monthly_graph_query_limit": float("inf"), # Maximum graph queries per month
# Cache limits
"cache_creation_limit": float("inf"), # Maximum number of caches
"hourly_cache_query_limit": float("inf"), # Maximum cache queries per hour
"monthly_cache_query_limit": float("inf"), # Maximum cache queries per month
},
}
def get_tier_limits(tier: AccountTier, custom_limits: Dict[str, Any] = None) -> Dict[str, Any]:
"""
Get limits for a specific account tier.
Args:
tier: The account tier
custom_limits: Optional custom limits for CUSTOM tier
Returns:
Dict of limits for the specified tier
"""
if tier == AccountTier.CUSTOM and custom_limits:
# Merge default custom limits with the provided custom limits
return {**TIER_LIMITS[tier], **custom_limits}
return TIER_LIMITS[tier]
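A small sketch of how the lookup and the CUSTOM-tier merge behave:
from core.models.tiers import AccountTier, get_tier_limits

print(get_tier_limits(AccountTier.FREE)["app_limit"])  # 1

# CUSTOM starts from the defaults above and merges per-account overrides on top
custom = get_tier_limits(AccountTier.CUSTOM, {"app_limit": 25, "monthly_query_limit": 100000})
print(custom["app_limit"], custom["monthly_query_limit"])  # 25 100000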


@ -0,0 +1,55 @@
from typing import Dict, Any, Optional
from datetime import datetime, UTC
from pydantic import BaseModel, Field
from .tiers import AccountTier
class UserUsage(BaseModel):
"""Tracks user's actual usage of the system."""
# Storage usage
storage_file_count: int = 0
storage_size_bytes: int = 0
# Query usage - hourly
hourly_query_count: int = 0
hourly_query_reset: Optional[datetime] = None
# Query usage - monthly
monthly_query_count: int = 0
monthly_query_reset: Optional[datetime] = None
# Ingest usage - hourly
hourly_ingest_count: int = 0
hourly_ingest_reset: Optional[datetime] = None
# Ingest usage - monthly
monthly_ingest_count: int = 0
monthly_ingest_reset: Optional[datetime] = None
# Graph usage
graph_count: int = 0
hourly_graph_query_count: int = 0
hourly_graph_query_reset: Optional[datetime] = None
monthly_graph_query_count: int = 0
monthly_graph_query_reset: Optional[datetime] = None
# Cache usage
cache_count: int = 0
hourly_cache_query_count: int = 0
hourly_cache_query_reset: Optional[datetime] = None
monthly_cache_query_count: int = 0
monthly_cache_query_reset: Optional[datetime] = None
class UserLimits(BaseModel):
"""Stores user tier and usage information."""
user_id: str
tier: AccountTier = AccountTier.FREE
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
usage: UserUsage = Field(default_factory=UserUsage)
custom_limits: Optional[Dict[str, Any]] = None
app_ids: list[str] = Field(default_factory=list)
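A minimal sketch of the defaults these models give a new user; the import path is an assumption since the diff does not show this file's name:
from core.models.user_limits import UserLimits  # assumed module path
from core.models.tiers import AccountTier

limits = UserLimits(user_id="user_789")
print(limits.tier == AccountTier.FREE)   # True
print(limits.usage.hourly_query_count)   # 0
print(limits.app_ids)                    # []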


@ -329,6 +329,15 @@ class DocumentService:
logger.error(f"User {auth.entity_id} does not have write permission")
raise PermissionError("User does not have write permission")
# First check ingest limits if in cloud mode
from core.config import get_settings
settings = get_settings()
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
from core.api import check_and_increment_limits
await check_and_increment_limits(auth, "ingest", 1)
doc = Document(
content_type="text/plain",
filename=filename,
@ -338,6 +347,7 @@ class DocumentService:
"readers": [auth.entity_id],
"writers": [auth.entity_id],
"admins": [auth.entity_id],
"user_id": [auth.user_id] if auth.user_id else [], # Add user_id to access control for filtering (as a list)
},
)
logger.info(f"Created text document record with ID {doc.external_id}")
@ -404,6 +414,20 @@ class DocumentService:
# Read file content
file_content = await file.read()
file_size = len(file_content) # Get file size in bytes for limit checking
# Check limits before doing any expensive processing
from core.config import get_settings
settings = get_settings()
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding with parsing
from core.api import check_and_increment_limits
await check_and_increment_limits(auth, "ingest", 1)
await check_and_increment_limits(auth, "storage_file", 1)
await check_and_increment_limits(auth, "storage_size", file_size)
# Now proceed with parsing and processing the file
file_type = filetype.guess(file_content)
# Set default mime type for cases where filetype.guess returns None
@ -438,8 +462,7 @@ class DocumentService:
if modified_text:
text = modified_text
logger.info("Updated text with modified content from rules")
# Create document record
doc = Document(
content_type=mime_type,
filename=file.filename,
@ -449,6 +472,7 @@ class DocumentService:
"readers": [auth.entity_id],
"writers": [auth.entity_id],
"admins": [auth.entity_id],
"user_id": [auth.user_id] if auth.user_id else [], # Add user_id to access control for filtering (as a list)
},
additional_metadata=additional_metadata,
)


@ -0,0 +1,236 @@
import logging
from typing import Dict, Any, Optional
from datetime import datetime, UTC, timedelta
import jwt
from ..models.tiers import AccountTier, get_tier_limits
from ..models.auth import AuthContext
from ..config import get_settings
from ..database.user_limits_db import UserLimitsDatabase
logger = logging.getLogger(__name__)
class UserService:
"""Service for managing user limits and usage."""
def __init__(self):
"""Initialize the UserService."""
self.settings = get_settings()
self.db = UserLimitsDatabase(uri=self.settings.POSTGRES_URI)
async def initialize(self) -> bool:
"""Initialize database tables."""
return await self.db.initialize()
async def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Get user limits information."""
return await self.db.get_user_limits(user_id)
async def create_user(self, user_id: str) -> bool:
"""Create a new user with FREE tier."""
return await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
async def update_user_tier(
self, user_id: str, tier: str, custom_limits: Optional[Dict[str, Any]] = None
) -> bool:
"""Update user tier and custom limits."""
return await self.db.update_user_tier(user_id, tier, custom_limits)
async def register_app(self, user_id: str, app_id: str) -> bool:
"""
Register an app for a user.
Creates user limits record if it doesn't exist.
"""
# First check if user limits exist
user_limits = await self.db.get_user_limits(user_id)
# If user limits don't exist, create them first
if not user_limits:
logger.info(f"Creating user limits for user {user_id}")
success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
if not success:
logger.error(f"Failed to create user limits for user {user_id}")
return False
# Now register the app
return await self.db.register_app(user_id, app_id)
async def check_limit(self, user_id: str, limit_type: str, value: int = 1) -> bool:
"""
Check whether the user stays within limits for an operation of the given size.
Args:
user_id: The user ID to check
limit_type: Type of limit (query, ingest, graph, cache, etc.)
value: Value to check (e.g., file size for storage)
Returns:
True if within limits, False if exceeded
"""
# Skip limit checking for self-hosted mode
if self.settings.MODE == "self_hosted":
return True
# Get user limits
user_data = await self.db.get_user_limits(user_id)
if not user_data:
# Create user limits if they don't exist
logger.info(f"User {user_id} not found when checking limits - creating limits record")
success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
if not success:
logger.error(f"Failed to create user limits for user {user_id}")
return False
# Fetch the newly created limits
user_data = await self.db.get_user_limits(user_id)
if not user_data:
logger.error(f"Failed to retrieve newly created user limits for user {user_id}")
return False
# Get tier limits
tier = user_data.get("tier", AccountTier.FREE)
tier_limits = get_tier_limits(tier, user_data.get("custom_limits"))
# Get current usage
usage = user_data.get("usage", {})
# Check specific limit type
if limit_type == "query":
hourly_limit = tier_limits.get("hourly_query_limit", 0)
monthly_limit = tier_limits.get("monthly_query_limit", 0)
hourly_usage = usage.get("hourly_query_count", 0)
monthly_usage = usage.get("monthly_query_count", 0)
return hourly_usage + value <= hourly_limit and monthly_usage + value <= monthly_limit
elif limit_type == "ingest":
hourly_limit = tier_limits.get("hourly_ingest_limit", 0)
monthly_limit = tier_limits.get("monthly_ingest_limit", 0)
hourly_usage = usage.get("hourly_ingest_count", 0)
monthly_usage = usage.get("monthly_ingest_count", 0)
return hourly_usage + value <= hourly_limit and monthly_usage + value <= monthly_limit
elif limit_type == "storage_file":
file_limit = tier_limits.get("storage_file_limit", 0)
file_count = usage.get("storage_file_count", 0)
return file_count + value <= file_limit
elif limit_type == "storage_size":
size_limit_bytes = tier_limits.get("storage_size_limit_gb", 0) * 1024 * 1024 * 1024
size_usage = usage.get("storage_size_bytes", 0)
return size_usage + value <= size_limit_bytes
elif limit_type == "graph":
graph_limit = tier_limits.get("graph_creation_limit", 0)
graph_count = usage.get("graph_count", 0)
return graph_count + value <= graph_limit
elif limit_type == "cache":
cache_limit = tier_limits.get("cache_creation_limit", 0)
cache_count = usage.get("cache_count", 0)
return cache_count + value <= cache_limit
return True
async def record_usage(self, user_id: str, usage_type: str, increment: int = 1) -> bool:
"""
Record usage for a user.
Args:
user_id: The user ID
usage_type: Type of usage (query, ingest, storage_file, storage_size, etc.)
increment: Value to increment by
Returns:
True if successful, False otherwise
"""
# Skip usage recording for self-hosted mode
if self.settings.MODE == "self_hosted":
return True
# Check if user limits exist, create if they don't
user_data = await self.db.get_user_limits(user_id)
if not user_data:
logger.info(f"Creating user limits for user {user_id} during usage recording")
success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
if not success:
logger.error(f"Failed to create user limits for user {user_id}")
return False
return await self.db.update_usage(user_id, usage_type, increment)
async def generate_cloud_uri(
self, user_id: str, app_id: str, name: str, expiry_days: int = 30
) -> Optional[str]:
"""
Generate a cloud URI for an app.
Args:
user_id: The user ID
app_id: The app ID
name: App name for display purposes
expiry_days: Number of days until token expires
Returns:
URI string with embedded token, or None if failed
"""
# Get user limits to check app limit
user_limits = await self.db.get_user_limits(user_id)
# If user doesn't exist yet, create them
if not user_limits:
await self.create_user(user_id)
user_limits = await self.db.get_user_limits(user_id)
if not user_limits:
logger.error(f"Failed to create user limits for user {user_id}")
return None
# Get tier limits to enforce app limit
tier = user_limits.get("tier", AccountTier.FREE)
tier_limits = get_tier_limits(tier, user_limits.get("custom_limits"))
app_limit = tier_limits.get("app_limit", 1) # Default to 1 if not specified
current_apps = user_limits.get("app_ids", [])
# Skip the limit check if app is already registered
if app_id not in current_apps:
# Check if user has reached app limit
if len(current_apps) >= app_limit:
logger.info(f"User {user_id} has reached app limit ({app_limit}) for tier {tier}")
return None
# Register the app
success = await self.register_app(user_id, app_id)
if not success:
logger.info(f"Failed to register app {app_id} for user {user_id}")
return None
# Create token payload
payload = {
"user_id": user_id,
"app_id": app_id,
"name": name,
"permissions": ["read", "write"],
"exp": int((datetime.now(UTC) + timedelta(days=expiry_days)).timestamp()),
"type": "developer",
"entity_id": user_id,
}
# Generate token
token = jwt.encode(
payload, self.settings.JWT_SECRET_KEY, algorithm=self.settings.JWT_ALGORITHM
)
# Generate URI with API domain
api_domain = getattr(self.settings, "API_DOMAIN", "api.databridge.ai")
uri = f"databridge://{name}:{token}@{api_domain}/{app_id}"
return uri
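A sketch of driving this from application code; it assumes Postgres and the JWT settings are configured, and the IDs are placeholders:
import asyncio
from core.services.user_service import UserService

async def main():
    service = UserService()
    await service.initialize()
    # Returns None when the tier's app_limit would be exceeded or registration fails
    uri = await service.generate_cloud_uri("user_789", "app_456", "my_app", expiry_days=30)
    print(uri)  # e.g. databridge://my_app:<jwt>@api.databridge.ai/app_456

asyncio.run(main())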


@ -100,6 +100,7 @@ batch_size = 4096
[databridge]
enable_colpali = true
mode = "self_hosted" # "cloud" or "self_hosted"
[graph]
provider = "ollama"