Add hosted tier limits, cloud uri gen (#59)

This commit is contained in:
Adityavardhan Agrawal 2025-03-27 17:30:02 -07:00 committed by GitHub
parent f0c44cb8ea
commit 7eb5887d2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 1062 additions and 4 deletions

View File

@ -11,7 +11,8 @@ import logging
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from core.completion.openai_completion import OpenAICompletionModel from core.completion.openai_completion import OpenAICompletionModel
from core.embedding.ollama_embedding_model import OllamaEmbeddingModel from core.embedding.ollama_embedding_model import OllamaEmbeddingModel
from core.models.request import RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, BatchIngestResponse from core.limits_utils import check_and_increment_limits
from core.models.request import GenerateUriRequest, RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, BatchIngestResponse
from core.models.completion import ChunkSource, CompletionResponse from core.models.completion import ChunkSource, CompletionResponse
from core.models.documents import Document, DocumentResult, ChunkResult from core.models.documents import Document, DocumentResult, ChunkResult
from core.models.graph import Graph from core.models.graph import Graph
@ -236,6 +237,7 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext:
entity_type=EntityType(settings.dev_entity_type), entity_type=EntityType(settings.dev_entity_type),
entity_id=settings.dev_entity_id, entity_id=settings.dev_entity_id,
permissions=set(settings.dev_permissions), permissions=set(settings.dev_permissions),
user_id=settings.dev_entity_id, # In dev mode, entity_id is also the user_id
) )
# Normal token verification flow # Normal token verification flow
@ -255,11 +257,17 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext:
if datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC): if datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC):
raise HTTPException(status_code=401, detail="Token expired") raise HTTPException(status_code=401, detail="Token expired")
# Support both "type" and "entity_type" fields for compatibility
entity_type_field = payload.get("type") or payload.get("entity_type")
if not entity_type_field:
raise HTTPException(status_code=401, detail="Missing entity type in token")
return AuthContext( return AuthContext(
entity_type=EntityType(payload["type"]), entity_type=EntityType(entity_type_field),
entity_id=payload["entity_id"], entity_id=payload["entity_id"],
app_id=payload.get("app_id"), app_id=payload.get("app_id"),
permissions=set(payload.get("permissions", ["read"])), permissions=set(payload.get("permissions", ["read"])),
user_id=payload.get("user_id", payload["entity_id"]), # Use user_id if available, fallback to entity_id
) )
except jwt.InvalidTokenError as e: except jwt.InvalidTokenError as e:
raise HTTPException(status_code=401, detail=str(e)) raise HTTPException(status_code=401, detail=str(e))
@ -569,6 +577,11 @@ async def query_completion(
to enhance retrieval by finding relevant entities and their connected documents. to enhance retrieval by finding relevant entities and their connected documents.
""" """
try: try:
# Check query limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "query", 1)
async with telemetry.track_operation( async with telemetry.track_operation(
operation_type="query", operation_type="query",
user_id=auth.entity_id, user_id=auth.entity_id,
@ -689,6 +702,7 @@ async def update_document_text(
except PermissionError as e: except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e)) raise HTTPException(status_code=403, detail=str(e))
@app.post("/documents/{document_id}/update_file", response_model=Document) @app.post("/documents/{document_id}/update_file", response_model=Document)
async def update_document_file( async def update_document_file(
document_id: str, document_id: str,
@ -856,6 +870,11 @@ async def create_cache(
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Create a new cache with specified configuration.""" """Create a new cache with specified configuration."""
try: try:
# Check cache creation limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "cache", 1)
async with telemetry.track_operation( async with telemetry.track_operation(
operation_type="create_cache", operation_type="create_cache",
user_id=auth.entity_id, user_id=auth.entity_id,
@ -955,6 +974,11 @@ async def query_cache(
) -> CompletionResponse: ) -> CompletionResponse:
"""Query the cache with a prompt.""" """Query the cache with a prompt."""
try: try:
# Check cache query limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "cache_query", 1)
async with telemetry.track_operation( async with telemetry.track_operation(
operation_type="query_cache", operation_type="query_cache",
user_id=auth.entity_id, user_id=auth.entity_id,
@ -994,6 +1018,11 @@ async def create_graph(
Graph: The created graph object Graph: The created graph object
""" """
try: try:
# Check graph creation limits if in cloud mode
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
await check_and_increment_limits(auth, "graph", 1)
async with telemetry.track_operation( async with telemetry.track_operation(
operation_type="create_graph", operation_type="create_graph",
user_id=auth.entity_id, user_id=auth.entity_id,
@ -1109,3 +1138,147 @@ async def generate_local_uri(
except Exception as e: except Exception as e:
logger.error(f"Error generating local URI: {e}") logger.error(f"Error generating local URI: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/cloud/generate_uri", include_in_schema=True)
async def generate_cloud_uri(
request: GenerateUriRequest,
authorization: str = Header(None),
) -> Dict[str, str]:
"""Generate a URI for cloud hosted applications."""
try:
app_id = request.app_id
name = request.name
user_id = request.user_id
expiry_days = request.expiry_days
logger.info(f"Generating cloud URI for app_id={app_id}, name={name}, user_id={user_id}")
# Verify authorization header before proceeding
if not authorization:
logger.warning("Missing authorization header")
raise HTTPException(
status_code=401,
detail="Missing authorization header",
headers={"WWW-Authenticate": "Bearer"},
)
# Verify the token is valid
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = authorization[7:] # Remove "Bearer "
try:
# Decode the token to ensure it's valid
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
# Only allow users to create apps for themselves (or admin)
token_user_id = payload.get("user_id")
logger.debug(f"Token user ID: {token_user_id}")
logger.debug(f"User ID: {user_id}")
if not (token_user_id == user_id or "admin" in payload.get("permissions", [])):
raise HTTPException(
status_code=403,
detail="You can only create apps for your own account unless you have admin permissions"
)
except jwt.InvalidTokenError as e:
raise HTTPException(status_code=401, detail=str(e))
# Import UserService here to avoid circular imports
from core.services.user_service import UserService
user_service = UserService()
# Initialize user service if needed
await user_service.initialize()
# Clean name
name = name.replace(" ", "_").lower()
# Check if the user is within app limit and generate URI
uri = await user_service.generate_cloud_uri(user_id, app_id, name, expiry_days)
if not uri:
logger.info("Application limit reached for this account tier with user_id: %s", user_id)
raise HTTPException(
status_code=403,
detail="Application limit reached for this account tier"
)
return {"uri": uri, "app_id": app_id}
except HTTPException:
# Re-raise HTTP exceptions
raise
except Exception as e:
logger.error(f"Error generating cloud URI: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/user/upgrade", include_in_schema=True)
async def upgrade_user_tier(
user_id: str,
tier: str,
custom_limits: Optional[Dict[str, Any]] = None,
authorization: str = Header(None),
) -> Dict[str, Any]:
"""Upgrade a user to a higher tier."""
try:
# Verify admin authorization
if not authorization:
raise HTTPException(
status_code=401,
detail="Missing authorization header",
headers={"WWW-Authenticate": "Bearer"},
)
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = authorization[7:] # Remove "Bearer "
try:
# Decode token
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
# Only allow admins to upgrade users
if "admin" not in payload.get("permissions", []):
raise HTTPException(
status_code=403,
detail="Admin permission required"
)
except jwt.InvalidTokenError as e:
raise HTTPException(status_code=401, detail=str(e))
# Validate tier
from core.models.tiers import AccountTier
try:
account_tier = AccountTier(tier)
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid tier: {tier}")
# Upgrade user
from core.services.user_service import UserService
user_service = UserService()
# Initialize user service
await user_service.initialize()
# Update user tier
success = await user_service.update_user_tier(user_id, tier, custom_limits)
if not success:
raise HTTPException(
status_code=404,
detail="User not found or upgrade failed"
)
return {
"user_id": user_id,
"tier": tier,
"message": f"User upgraded to {tier} tier"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error upgrading user tier: {e}")
raise HTTPException(status_code=500, detail=str(e))

View File

@ -91,6 +91,9 @@ class Settings(BaseSettings):
# Colpali configuration # Colpali configuration
ENABLE_COLPALI: bool ENABLE_COLPALI: bool
# Mode configuration
MODE: Literal["cloud", "self_hosted"] = "cloud"
# Telemetry configuration # Telemetry configuration
TELEMETRY_ENABLED: bool = True TELEMETRY_ENABLED: bool = True
HONEYCOMB_ENABLED: bool = True HONEYCOMB_ENABLED: bool = True
@ -291,6 +294,7 @@ def get_settings() -> Settings:
# load databridge config # load databridge config
databridge_config = { databridge_config = {
"ENABLE_COLPALI": config["databridge"]["enable_colpali"], "ENABLE_COLPALI": config["databridge"]["enable_colpali"],
"MODE": config["databridge"].get("mode", "cloud"), # Default to "cloud" mode
} }
# load graph config # load graph config

View File

@ -494,6 +494,15 @@ class PostgresDatabase(BaseDatabase):
if auth.entity_type == "DEVELOPER" and auth.app_id: if auth.entity_type == "DEVELOPER" and auth.app_id:
# Add app-specific access for developers # Add app-specific access for developers
filters.append(f"access_control->'app_access' ? '{auth.app_id}'") filters.append(f"access_control->'app_access' ? '{auth.app_id}'")
# Add user_id filter in cloud mode
if auth.user_id:
from core.config import get_settings
settings = get_settings()
if settings.MODE == "cloud":
# Filter by user_id in access_control
filters.append(f"access_control->>'user_id' = '{auth.user_id}'")
return " OR ".join(filters) return " OR ".join(filters)

View File

@ -0,0 +1,359 @@
import json
from typing import List, Optional, Dict, Any
from datetime import datetime, UTC, timedelta
import logging
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import Column, String, Index, select, text
from sqlalchemy.dialects.postgresql import JSONB
logger = logging.getLogger(__name__)
Base = declarative_base()
class UserLimitsModel(Base):
    """SQLAlchemy model for user limits data.

    One row per user; tier limits themselves live in code (core.models.tiers),
    only the user's tier name, overrides, and usage counters are stored here.
    """
    __tablename__ = "user_limits"
    user_id = Column(String, primary_key=True)
    tier = Column(String, nullable=False)  # FREE, PRO, CUSTOM, SELF_HOSTED
    custom_limits = Column(JSONB, nullable=True)  # per-account overrides, CUSTOM tier only
    usage = Column(JSONB, default=dict)  # Holds all usage counters
    app_ids = Column(JSONB, default=list)  # List of app IDs registered by this user
    # Timestamps are ISO-8601 strings, not TIMESTAMP columns — comparisons
    # happen in Python, never in SQL.
    created_at = Column(String)  # ISO format string
    updated_at = Column(String)  # ISO format string
    # Create indexes
    __table_args__ = (Index("idx_user_tier", "tier"),)
class UserLimitsDatabase:
    """Database operations for user limits.

    Async persistence layer for the ``user_limits`` table. Usage counters are
    kept inside the ``usage`` JSONB column and timestamps are ISO-8601
    strings. NOTE(review): register_app uses raw SQL with PostgreSQL JSONB
    operators (`?`, `||`), so this class is Postgres-only — confirm before
    pointing it at another backend.
    """

    def __init__(self, uri: str):
        """Initialize database connection.

        Args:
            uri: SQLAlchemy async database URI.
        """
        self.engine = create_async_engine(uri)
        self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
        # Tables are created lazily by initialize(), not here.
        self._initialized = False

    async def initialize(self) -> bool:
        """Initialize database tables and indexes.

        Returns:
            True if the tables are ready, False if creation failed.
        """
        # Idempotent: skip the DDL round-trip after the first success.
        if self._initialized:
            return True
        try:
            logger.info("Initializing user limits database tables...")
            # Create tables if they don't exist
            async with self.engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)
            self._initialized = True
            logger.info("User limits database tables initialized successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to initialize user limits database: {e}")
            return False

    async def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
        """
        Get user limits for a user.

        Args:
            user_id: The user ID to get limits for

        Returns:
            Dict with user limits if found, None otherwise
        """
        async with self.async_session() as session:
            result = await session.execute(
                select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
            )
            user_limits = result.scalars().first()
            if not user_limits:
                return None
            # Flatten the ORM row into a plain dict so callers never hold
            # session-bound state.
            return {
                "user_id": user_limits.user_id,
                "tier": user_limits.tier,
                "custom_limits": user_limits.custom_limits,
                "usage": user_limits.usage,
                "app_ids": user_limits.app_ids,
                "created_at": user_limits.created_at,
                "updated_at": user_limits.updated_at,
            }

    async def create_user_limits(self, user_id: str, tier: str = "free") -> bool:
        """
        Create user limits record.

        Idempotent: returns True without changes if the record already exists.

        Args:
            user_id: The user ID
            tier: Initial tier (defaults to "free")

        Returns:
            True if successful, False otherwise
        """
        try:
            now = datetime.now(UTC).isoformat()
            async with self.async_session() as session:
                # Check if already exists
                result = await session.execute(
                    select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
                )
                if result.scalars().first():
                    return True  # Already exists
                # Create new record with properly initialized JSONB columns
                # Create JSON strings and parse them for consistency
                usage_json = json.dumps({
                    "storage_file_count": 0,
                    "storage_size_bytes": 0,
                    "hourly_query_count": 0,
                    "hourly_query_reset": now,
                    "monthly_query_count": 0,
                    "monthly_query_reset": now,
                    "hourly_ingest_count": 0,
                    "hourly_ingest_reset": now,
                    "monthly_ingest_count": 0,
                    "monthly_ingest_reset": now,
                    "graph_count": 0,
                    "cache_count": 0,
                })
                app_ids_json = json.dumps([])  # Empty array but as JSON string
                # Create the model with the JSON parsed
                user_limits = UserLimitsModel(
                    user_id=user_id,
                    tier=tier,
                    usage=json.loads(usage_json),
                    app_ids=json.loads(app_ids_json),
                    created_at=now,
                    updated_at=now,
                )
                session.add(user_limits)
                await session.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to create user limits: {e}")
            return False

    async def update_user_tier(
        self, user_id: str, tier: str, custom_limits: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update user tier and custom limits.

        Args:
            user_id: The user ID
            tier: New tier
            custom_limits: Optional custom limits for CUSTOM tier

        Returns:
            True if successful, False otherwise
        """
        try:
            now = datetime.now(UTC).isoformat()
            async with self.async_session() as session:
                result = await session.execute(
                    select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
                )
                user_limits = result.scalars().first()
                if not user_limits:
                    return False
                # NOTE: custom_limits is overwritten even when None, so a tier
                # change clears any previously stored overrides.
                user_limits.tier = tier
                user_limits.custom_limits = custom_limits
                user_limits.updated_at = now
                await session.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to update user tier: {e}")
            return False

    async def register_app(self, user_id: str, app_id: str) -> bool:
        """
        Register an app for a user.

        Appending is done in SQL so it is atomic and duplicate-safe.

        Args:
            user_id: The user ID
            app_id: The app ID to register

        Returns:
            True if successful, False otherwise
        """
        try:
            now = datetime.now(UTC).isoformat()
            async with self.async_session() as session:
                # First check if user exists
                result = await session.execute(
                    select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
                )
                user_limits = result.scalars().first()
                if not user_limits:
                    logger.error(f"User {user_id} not found in register_app")
                    return False
                # Use raw SQL with jsonb_array_append to update the app_ids array
                # This is the most reliable way to append to a JSONB array in PostgreSQL
                query = text(
                    """
                    UPDATE user_limits
                    SET
                        app_ids = CASE
                            WHEN NOT (app_ids ? :app_id) -- Check if app_id is not in the array
                            THEN app_ids || :app_id_json -- Append it if not present
                            ELSE app_ids -- Keep it unchanged if already present
                        END,
                        updated_at = :now
                    WHERE user_id = :user_id
                    RETURNING app_ids;
                    """
                )
                # Execute the query
                result = await session.execute(
                    query,
                    {
                        "app_id": app_id,  # For the check
                        "app_id_json": f'["{app_id}"]',  # JSON array format for appending
                        "now": now,
                        "user_id": user_id
                    }
                )
                # Log the result for debugging
                updated_app_ids = result.scalar()
                logger.info(f"Updated app_ids for user {user_id}: {updated_app_ids}")
                await session.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to register app: {e}")
            return False

    async def update_usage(self, user_id: str, usage_type: str, increment: int = 1) -> bool:
        """
        Update usage counter for a user.

        Windowed counters ("query", "ingest") reset when their hour/30-day
        window has elapsed; the storage/graph/cache counters are monotonic.
        Unknown usage_type values still succeed but change no counter.

        Args:
            user_id: The user ID
            usage_type: Type of usage to update
            increment: Value to increment by

        Returns:
            True if successful, False otherwise
        """
        try:
            now = datetime.now(UTC)
            now_iso = now.isoformat()
            async with self.async_session() as session:
                result = await session.execute(
                    select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
                )
                user_limits = result.scalars().first()
                if not user_limits:
                    return False
                # Create a new dictionary to force SQLAlchemy to detect the change
                usage = dict(user_limits.usage) if user_limits.usage else {}
                # Handle different usage types
                if usage_type == "query":
                    # Check hourly reset
                    hourly_reset_str = usage.get("hourly_query_reset", "")
                    if hourly_reset_str:
                        hourly_reset = datetime.fromisoformat(hourly_reset_str)
                        # Window elapsed: restart the counter at this increment.
                        if now > hourly_reset + timedelta(hours=1):
                            usage["hourly_query_count"] = increment
                            usage["hourly_query_reset"] = now_iso
                        else:
                            usage["hourly_query_count"] = usage.get("hourly_query_count", 0) + increment
                    else:
                        # No reset marker yet: start a fresh window now.
                        usage["hourly_query_count"] = increment
                        usage["hourly_query_reset"] = now_iso
                    # Check monthly reset
                    monthly_reset_str = usage.get("monthly_query_reset", "")
                    if monthly_reset_str:
                        monthly_reset = datetime.fromisoformat(monthly_reset_str)
                        if now > monthly_reset + timedelta(days=30):
                            usage["monthly_query_count"] = increment
                            usage["monthly_query_reset"] = now_iso
                        else:
                            usage["monthly_query_count"] = usage.get("monthly_query_count", 0) + increment
                    else:
                        usage["monthly_query_count"] = increment
                        usage["monthly_query_reset"] = now_iso
                elif usage_type == "ingest":
                    # Similar pattern for ingest
                    hourly_reset_str = usage.get("hourly_ingest_reset", "")
                    if hourly_reset_str:
                        hourly_reset = datetime.fromisoformat(hourly_reset_str)
                        if now > hourly_reset + timedelta(hours=1):
                            usage["hourly_ingest_count"] = increment
                            usage["hourly_ingest_reset"] = now_iso
                        else:
                            usage["hourly_ingest_count"] = usage.get("hourly_ingest_count", 0) + increment
                    else:
                        usage["hourly_ingest_count"] = increment
                        usage["hourly_ingest_reset"] = now_iso
                    monthly_reset_str = usage.get("monthly_ingest_reset", "")
                    if monthly_reset_str:
                        monthly_reset = datetime.fromisoformat(monthly_reset_str)
                        if now > monthly_reset + timedelta(days=30):
                            usage["monthly_ingest_count"] = increment
                            usage["monthly_ingest_reset"] = now_iso
                        else:
                            usage["monthly_ingest_count"] = usage.get("monthly_ingest_count", 0) + increment
                    else:
                        usage["monthly_ingest_count"] = increment
                        usage["monthly_ingest_reset"] = now_iso
                elif usage_type == "storage_file":
                    usage["storage_file_count"] = usage.get("storage_file_count", 0) + increment
                elif usage_type == "storage_size":
                    usage["storage_size_bytes"] = usage.get("storage_size_bytes", 0) + increment
                elif usage_type == "graph":
                    usage["graph_count"] = usage.get("graph_count", 0) + increment
                elif usage_type == "cache":
                    usage["cache_count"] = usage.get("cache_count", 0) + increment
                # Force SQLAlchemy to recognize the change by assigning a new dict
                user_limits.usage = usage
                user_limits.updated_at = now_iso
                # Explicitly mark as modified
                session.add(user_limits)
                # Log the updated usage for debugging
                logger.info(f"Updated usage for user {user_id}, type: {usage_type}, value: {increment}")
                logger.info(f"New usage values: {usage}")
                logger.info(f"About to commit: user_id={user_id}, usage={user_limits.usage}")
                # Commit and flush to ensure changes are written
                await session.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to update usage: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return False

71
core/limits_utils.py Normal file
View File

@ -0,0 +1,71 @@
import logging
# Initialize logger
logger = logging.getLogger(__name__)
async def check_and_increment_limits(auth, limit_type: str, value: int = 1) -> None:
    """
    Verify that a user may perform an operation, then record the usage.

    Args:
        auth: Authentication context carrying the user_id to meter
        limit_type: Kind of limit to check (query, ingest, storage_file,
            storage_size, graph, cache)
        value: Amount to charge against the limit (e.g. file size for storage_size)

    Raises:
        HTTPException: 429 when the user's tier limit is exceeded
    """
    from fastapi import HTTPException
    from core.config import get_settings

    cfg = get_settings()

    # Self-hosted deployments are never metered.
    if cfg.MODE == "self_hosted":
        return

    # Without a user_id there is nothing to meter against.
    if not auth.user_id:
        logger.warning("User ID not available in auth context, skipping limit check")
        return

    # Imported here to avoid circular imports at module load time.
    from core.services.user_service import UserService

    user_service = UserService()
    await user_service.initialize()

    is_allowed = await user_service.check_limit(auth.user_id, limit_type, value)
    if not is_allowed:
        # Look up the tier so the error message can name it.
        limits_info = await user_service.get_user_limits(auth.user_id)
        tier = limits_info.get("tier", "unknown") if limits_info else "unknown"

        messages = {
            "query": f"Query limit exceeded for your {tier} tier. Please upgrade or try again later.",
            "ingest": f"Ingest limit exceeded for your {tier} tier. Please upgrade or try again later.",
            "storage_file": f"Storage file count limit exceeded for your {tier} tier. Please delete some files or upgrade.",
            "storage_size": f"Storage size limit exceeded for your {tier} tier. Please delete some files or upgrade.",
            "graph": f"Graph creation limit exceeded for your {tier} tier. Please upgrade to create more graphs.",
            "cache": f"Cache creation limit exceeded for your {tier} tier. Please upgrade to create more caches.",
            "cache_query": f"Cache query limit exceeded for your {tier} tier. Please upgrade or try again later.",
        }
        raise HTTPException(
            status_code=429,
            detail=messages.get(
                limit_type, f"Limit exceeded for your {tier} tier. Please upgrade or contact support."
            ),
        )

    # Usage recording is best-effort: a failure here must not fail the request.
    try:
        await user_service.record_usage(auth.user_id, limit_type, value)
    except Exception as e:
        logger.error(f"Failed to record usage: {e}")

View File

@ -16,3 +16,4 @@ class AuthContext(BaseModel):
app_id: Optional[str] = None # uuid, only for developers app_id: Optional[str] = None # uuid, only for developers
# TODO: remove permissions, not required here. # TODO: remove permissions, not required here.
permissions: Set[str] = {"read"} permissions: Set[str] = {"read"}
user_id: Optional[str] = None # ID of the user who owns the app/entity

View File

@ -57,3 +57,11 @@ class BatchIngestResponse(BaseModel):
"""Response model for batch ingestion""" """Response model for batch ingestion"""
documents: List[Document] documents: List[Document]
errors: List[Dict[str, str]] errors: List[Dict[str, str]]
class GenerateUriRequest(BaseModel):
    """Request model for generating a cloud URI"""
    app_id: str = Field(..., description="ID of the application")
    name: str = Field(..., description="Name of the application")
    user_id: str = Field(..., description="ID of the user who owns the app")
    # Issued tokens default to a 30-day lifetime.
    expiry_days: int = Field(default=30, description="Number of days until the token expires")

117
core/models/tiers.py Normal file
View File

@ -0,0 +1,117 @@
from enum import Enum
from typing import Any, Dict, Optional


class AccountTier(str, Enum):
    """Available account tiers."""

    FREE = "free"
    PRO = "pro"
    CUSTOM = "custom"
    SELF_HOSTED = "self_hosted"


# Tier limits definition - organized by API endpoint usage.
# This table is module-level shared state: get_tier_limits() hands out copies
# so callers can never mutate it accidentally.
TIER_LIMITS = {
    AccountTier.FREE: {
        # Application limits
        "app_limit": 1,  # Maximum number of applications
        # Storage limits
        "storage_file_limit": 100,  # Maximum number of files in storage
        "storage_size_limit_gb": 1,  # Maximum storage size in GB
        "hourly_ingest_limit": 30,  # Maximum file/text ingests per hour
        "monthly_ingest_limit": 100,  # Maximum file/text ingests per month
        # Query limits
        "hourly_query_limit": 30,  # Maximum queries per hour
        "monthly_query_limit": 1000,  # Maximum queries per month
        # Graph limits
        "graph_creation_limit": 3,  # Maximum number of graphs
        "hourly_graph_query_limit": 20,  # Maximum graph queries per hour
        "monthly_graph_query_limit": 200,  # Maximum graph queries per month
        # Cache limits (caching is disabled on the free tier)
        "cache_creation_limit": 0,  # Maximum number of caches
        "hourly_cache_query_limit": 0,  # Maximum cache queries per hour
        "monthly_cache_query_limit": 0,  # Maximum cache queries per month
    },
    AccountTier.PRO: {
        # Application limits
        "app_limit": 5,  # Maximum number of applications
        # Storage limits
        "storage_file_limit": 1000,  # Maximum number of files in storage
        "storage_size_limit_gb": 10,  # Maximum storage size in GB
        "hourly_ingest_limit": 100,  # Maximum file/text ingests per hour
        "monthly_ingest_limit": 3000,  # Maximum file/text ingests per month
        # Query limits
        "hourly_query_limit": 100,  # Maximum queries per hour
        "monthly_query_limit": 10000,  # Maximum queries per month
        # Graph limits
        "graph_creation_limit": 10,  # Maximum number of graphs
        "hourly_graph_query_limit": 50,  # Maximum graph queries per hour
        "monthly_graph_query_limit": 1000,  # Maximum graph queries per month
        # Cache limits
        "cache_creation_limit": 5,  # Maximum number of caches
        "hourly_cache_query_limit": 200,  # Maximum cache queries per hour
        "monthly_cache_query_limit": 5000,  # Maximum cache queries per month
    },
    AccountTier.CUSTOM: {
        # Custom tier limits are set on a per-account basis.
        # These are default values that will be overridden.
        # Application limits
        "app_limit": 10,  # Maximum number of applications
        # Storage limits
        "storage_file_limit": 10000,  # Maximum number of files in storage
        "storage_size_limit_gb": 100,  # Maximum storage size in GB
        "hourly_ingest_limit": 500,  # Maximum file/text ingests per hour
        "monthly_ingest_limit": 15000,  # Maximum file/text ingests per month
        # Query limits
        "hourly_query_limit": 500,  # Maximum queries per hour
        "monthly_query_limit": 50000,  # Maximum queries per month
        # Graph limits
        "graph_creation_limit": 50,  # Maximum number of graphs
        "hourly_graph_query_limit": 200,  # Maximum graph queries per hour
        "monthly_graph_query_limit": 10000,  # Maximum graph queries per month
        # Cache limits
        "cache_creation_limit": 20,  # Maximum number of caches
        "hourly_cache_query_limit": 1000,  # Maximum cache queries per hour
        "monthly_cache_query_limit": 50000,  # Maximum cache queries per month
    },
    AccountTier.SELF_HOSTED: {
        # Self-hosted has no limits: every ceiling is infinite.
        # Application limits
        "app_limit": float("inf"),  # Maximum number of applications
        # Storage limits
        "storage_file_limit": float("inf"),  # Maximum number of files in storage
        "storage_size_limit_gb": float("inf"),  # Maximum storage size in GB
        "hourly_ingest_limit": float("inf"),  # Maximum file/text ingests per hour
        "monthly_ingest_limit": float("inf"),  # Maximum file/text ingests per month
        # Query limits
        "hourly_query_limit": float("inf"),  # Maximum queries per hour
        "monthly_query_limit": float("inf"),  # Maximum queries per month
        # Graph limits
        "graph_creation_limit": float("inf"),  # Maximum number of graphs
        "hourly_graph_query_limit": float("inf"),  # Maximum graph queries per hour
        "monthly_graph_query_limit": float("inf"),  # Maximum graph queries per month
        # Cache limits
        "cache_creation_limit": float("inf"),  # Maximum number of caches
        "hourly_cache_query_limit": float("inf"),  # Maximum cache queries per hour
        "monthly_cache_query_limit": float("inf"),  # Maximum cache queries per month
    },
}


def get_tier_limits(tier: AccountTier, custom_limits: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Get limits for a specific account tier.

    Args:
        tier: The account tier
        custom_limits: Optional custom limits for CUSTOM tier; merged over the
            CUSTOM tier defaults, overriding any key present in both

    Returns:
        Dict of limits for the specified tier. A fresh dict is returned on
        every call, so caller-side mutation never leaks into TIER_LIMITS.
    """
    if tier == AccountTier.CUSTOM and custom_limits:
        # Merge default custom limits with the provided custom limits
        # (dict-unpacking already produces a new dict).
        return {**TIER_LIMITS[tier], **custom_limits}
    # Defensive copy: protect the shared module-level table from mutation.
    return dict(TIER_LIMITS[tier])

View File

@ -0,0 +1,55 @@
from typing import Dict, Any, Optional
from datetime import datetime, UTC
from pydantic import BaseModel, Field
from .tiers import AccountTier
class UserUsage(BaseModel):
    """Tracks user's actual usage of the system.

    Counter names mirror the keys stored in the user_limits ``usage`` JSONB
    column; each ``*_reset`` field marks when its rolling window began.
    """
    # Storage usage
    storage_file_count: int = 0
    storage_size_bytes: int = 0
    # Query usage - hourly
    hourly_query_count: int = 0
    hourly_query_reset: Optional[datetime] = None
    # Query usage - monthly
    monthly_query_count: int = 0
    monthly_query_reset: Optional[datetime] = None
    # Ingest usage - hourly
    hourly_ingest_count: int = 0
    hourly_ingest_reset: Optional[datetime] = None
    # Ingest usage - monthly
    monthly_ingest_count: int = 0
    monthly_ingest_reset: Optional[datetime] = None
    # Graph usage
    graph_count: int = 0
    hourly_graph_query_count: int = 0
    hourly_graph_query_reset: Optional[datetime] = None
    monthly_graph_query_count: int = 0
    monthly_graph_query_reset: Optional[datetime] = None
    # Cache usage
    cache_count: int = 0
    hourly_cache_query_count: int = 0
    hourly_cache_query_reset: Optional[datetime] = None
    monthly_cache_query_count: int = 0
    monthly_cache_query_reset: Optional[datetime] = None
class UserLimits(BaseModel):
    """Stores user tier and usage information."""
    user_id: str
    # New users start on the FREE tier.
    tier: AccountTier = AccountTier.FREE
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    usage: UserUsage = Field(default_factory=UserUsage)
    # Per-account overrides; only meaningful for the CUSTOM tier.
    custom_limits: Optional[Dict[str, Any]] = None
    # IDs of the apps this user has registered.
    app_ids: list[str] = Field(default_factory=list)

View File

@ -329,6 +329,15 @@ class DocumentService:
logger.error(f"User {auth.entity_id} does not have write permission") logger.error(f"User {auth.entity_id} does not have write permission")
raise PermissionError("User does not have write permission") raise PermissionError("User does not have write permission")
# First check ingest limits if in cloud mode
from core.config import get_settings
settings = get_settings()
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding
from core.api import check_and_increment_limits
await check_and_increment_limits(auth, "ingest", 1)
doc = Document( doc = Document(
content_type="text/plain", content_type="text/plain",
filename=filename, filename=filename,
@ -338,6 +347,7 @@ class DocumentService:
"readers": [auth.entity_id], "readers": [auth.entity_id],
"writers": [auth.entity_id], "writers": [auth.entity_id],
"admins": [auth.entity_id], "admins": [auth.entity_id],
"user_id": [auth.user_id] if auth.user_id else [], # Add user_id to access control for filtering (as a list)
}, },
) )
logger.info(f"Created text document record with ID {doc.external_id}") logger.info(f"Created text document record with ID {doc.external_id}")
@ -404,6 +414,20 @@ class DocumentService:
# Read file content # Read file content
file_content = await file.read() file_content = await file.read()
file_size = len(file_content) # Get file size in bytes for limit checking
# Check limits before doing any expensive processing
from core.config import get_settings
settings = get_settings()
if settings.MODE == "cloud" and auth.user_id:
# Check limits before proceeding with parsing
from core.api import check_and_increment_limits
await check_and_increment_limits(auth, "ingest", 1)
await check_and_increment_limits(auth, "storage_file", 1)
await check_and_increment_limits(auth, "storage_size", file_size)
# Now proceed with parsing and processing the file
file_type = filetype.guess(file_content) file_type = filetype.guess(file_content)
# Set default mime type for cases where filetype.guess returns None # Set default mime type for cases where filetype.guess returns None
@ -438,8 +462,7 @@ class DocumentService:
if modified_text: if modified_text:
text = modified_text text = modified_text
logger.info("Updated text with modified content from rules") logger.info("Updated text with modified content from rules")
# Create document record
doc = Document( doc = Document(
content_type=mime_type, content_type=mime_type,
filename=file.filename, filename=file.filename,
@ -449,6 +472,7 @@ class DocumentService:
"readers": [auth.entity_id], "readers": [auth.entity_id],
"writers": [auth.entity_id], "writers": [auth.entity_id],
"admins": [auth.entity_id], "admins": [auth.entity_id],
"user_id": [auth.user_id] if auth.user_id else [], # Add user_id to access control for filtering (as a list)
}, },
additional_metadata=additional_metadata, additional_metadata=additional_metadata,
) )

View File

@ -0,0 +1,236 @@
import logging
from typing import Dict, Any, Optional
from datetime import datetime, UTC, timedelta
import jwt
from ..models.tiers import AccountTier, get_tier_limits
from ..models.auth import AuthContext
from ..config import get_settings
from ..database.user_limits_db import UserLimitsDatabase
logger = logging.getLogger(__name__)
class UserService:
"""Service for managing user limits and usage."""
def __init__(self):
"""Initialize the UserService."""
self.settings = get_settings()
self.db = UserLimitsDatabase(uri=self.settings.POSTGRES_URI)
async def initialize(self) -> bool:
"""Initialize database tables."""
return await self.db.initialize()
async def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Get user limits information."""
return await self.db.get_user_limits(user_id)
async def create_user(self, user_id: str) -> bool:
"""Create a new user with FREE tier."""
return await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
async def update_user_tier(
self, user_id: str, tier: str, custom_limits: Optional[Dict[str, Any]] = None
) -> bool:
"""Update user tier and custom limits."""
return await self.db.update_user_tier(user_id, tier, custom_limits)
async def register_app(self, user_id: str, app_id: str) -> bool:
"""
Register an app for a user.
Creates user limits record if it doesn't exist.
"""
# First check if user limits exist
user_limits = await self.db.get_user_limits(user_id)
# If user limits don't exist, create them first
if not user_limits:
logger.info(f"Creating user limits for user {user_id}")
success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
if not success:
logger.error(f"Failed to create user limits for user {user_id}")
return False
# Now register the app
return await self.db.register_app(user_id, app_id)
async def check_limit(self, user_id: str, limit_type: str, value: int = 1) -> bool:
"""
Check if a user's operation is within limits when given value is considered.
Args:
user_id: The user ID to check
limit_type: Type of limit (query, ingest, graph, cache, etc.)
value: Value to check (e.g., file size for storage)
Returns:
True if within limits, False if exceeded
"""
# Skip limit checking for self-hosted mode
if self.settings.MODE == "self_hosted":
return True
# Get user limits
user_data = await self.db.get_user_limits(user_id)
if not user_data:
# Create user limits if they don't exist
logger.info(f"User {user_id} not found when checking limits - creating limits record")
success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
if not success:
logger.error(f"Failed to create user limits for user {user_id}")
return False
# Fetch the newly created limits
user_data = await self.db.get_user_limits(user_id)
if not user_data:
logger.error(f"Failed to retrieve newly created user limits for user {user_id}")
return False
# Get tier limits
tier = user_data.get("tier", AccountTier.FREE)
tier_limits = get_tier_limits(tier, user_data.get("custom_limits"))
# Get current usage
usage = user_data.get("usage", {})
# Check specific limit type
if limit_type == "query":
hourly_limit = tier_limits.get("hourly_query_limit", 0)
monthly_limit = tier_limits.get("monthly_query_limit", 0)
hourly_usage = usage.get("hourly_query_count", 0)
monthly_usage = usage.get("monthly_query_count", 0)
return hourly_usage + value <= hourly_limit and monthly_usage + value <= monthly_limit
elif limit_type == "ingest":
hourly_limit = tier_limits.get("hourly_ingest_limit", 0)
monthly_limit = tier_limits.get("monthly_ingest_limit", 0)
hourly_usage = usage.get("hourly_ingest_count", 0)
monthly_usage = usage.get("monthly_ingest_count", 0)
return hourly_usage + value <= hourly_limit and monthly_usage + value <= monthly_limit
elif limit_type == "storage_file":
file_limit = tier_limits.get("storage_file_limit", 0)
file_count = usage.get("storage_file_count", 0)
return file_count + value <= file_limit
elif limit_type == "storage_size":
size_limit_bytes = tier_limits.get("storage_size_limit_gb", 0) * 1024 * 1024 * 1024
size_usage = usage.get("storage_size_bytes", 0)
return size_usage + value <= size_limit_bytes
elif limit_type == "graph":
graph_limit = tier_limits.get("graph_creation_limit", 0)
graph_count = usage.get("graph_count", 0)
return graph_count + value <= graph_limit
elif limit_type == "cache":
cache_limit = tier_limits.get("cache_creation_limit", 0)
cache_count = usage.get("cache_count", 0)
return cache_count + value <= cache_limit
return True
async def record_usage(self, user_id: str, usage_type: str, increment: int = 1) -> bool:
"""
Record usage for a user.
Args:
user_id: The user ID
usage_type: Type of usage (query, ingest, storage_file, storage_size, etc.)
increment: Value to increment by
Returns:
True if successful, False otherwise
"""
# Skip usage recording for self-hosted mode
if self.settings.MODE == "self_hosted":
return True
# Check if user limits exist, create if they don't
user_data = await self.db.get_user_limits(user_id)
if not user_data:
logger.info(f"Creating user limits for user {user_id} during usage recording")
success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
if not success:
logger.error(f"Failed to create user limits for user {user_id}")
return False
return await self.db.update_usage(user_id, usage_type, increment)
async def generate_cloud_uri(
self, user_id: str, app_id: str, name: str, expiry_days: int = 30
) -> Optional[str]:
"""
Generate a cloud URI for an app.
Args:
user_id: The user ID
app_id: The app ID
name: App name for display purposes
expiry_days: Number of days until token expires
Returns:
URI string with embedded token, or None if failed
"""
# Get user limits to check app limit
user_limits = await self.db.get_user_limits(user_id)
# If user doesn't exist yet, create them
if not user_limits:
await self.create_user(user_id)
user_limits = await self.db.get_user_limits(user_id)
if not user_limits:
logger.error(f"Failed to create user limits for user {user_id}")
return None
# Get tier limits to enforce app limit
tier = user_limits.get("tier", AccountTier.FREE)
tier_limits = get_tier_limits(tier, user_limits.get("custom_limits"))
app_limit = tier_limits.get("app_limit", 1) # Default to 1 if not specified
current_apps = user_limits.get("app_ids", [])
# Skip the limit check if app is already registered
if app_id not in current_apps:
# Check if user has reached app limit
if len(current_apps) >= app_limit:
logger.info(f"User {user_id} has reached app limit ({app_limit}) for tier {tier}")
return None
# Register the app
success = await self.register_app(user_id, app_id)
if not success:
logger.info(f"Failed to register app {app_id} for user {user_id}")
return None
# Create token payload
payload = {
"user_id": user_id,
"app_id": app_id,
"name": name,
"permissions": ["read", "write"],
"exp": int((datetime.now(UTC) + timedelta(days=expiry_days)).timestamp()),
"type": "developer",
"entity_id": user_id,
}
# Generate token
token = jwt.encode(
payload, self.settings.JWT_SECRET_KEY, algorithm=self.settings.JWT_ALGORITHM
)
# Generate URI with API domain
api_domain = getattr(self.settings, "API_DOMAIN", "api.databridge.ai")
uri = f"databridge://{name}:{token}@{api_domain}/{app_id}"
return uri

View File

@ -100,6 +100,7 @@ batch_size = 4096
[databridge] [databridge]
enable_colpali = true enable_colpali = true
mode = "self_hosted" # "cloud" or "self_hosted"
[graph] [graph]
provider = "ollama" provider = "ollama"