diff --git a/core/api.py b/core/api.py index 704f724..cba8c6b 100644 --- a/core/api.py +++ b/core/api.py @@ -11,7 +11,8 @@ import logging from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from core.completion.openai_completion import OpenAICompletionModel from core.embedding.ollama_embedding_model import OllamaEmbeddingModel -from core.models.request import RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, BatchIngestResponse +from core.limits_utils import check_and_increment_limits +from core.models.request import GenerateUriRequest, RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, BatchIngestResponse from core.models.completion import ChunkSource, CompletionResponse from core.models.documents import Document, DocumentResult, ChunkResult from core.models.graph import Graph @@ -236,6 +237,7 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext: entity_type=EntityType(settings.dev_entity_type), entity_id=settings.dev_entity_id, permissions=set(settings.dev_permissions), + user_id=settings.dev_entity_id, # In dev mode, entity_id is also the user_id ) # Normal token verification flow @@ -255,11 +257,17 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext: if datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC): raise HTTPException(status_code=401, detail="Token expired") + # Support both "type" and "entity_type" fields for compatibility + entity_type_field = payload.get("type") or payload.get("entity_type") + if not entity_type_field: + raise HTTPException(status_code=401, detail="Missing entity type in token") + return AuthContext( - entity_type=EntityType(payload["type"]), + entity_type=EntityType(entity_type_field), entity_id=payload["entity_id"], app_id=payload.get("app_id"), permissions=set(payload.get("permissions", ["read"])), + user_id=payload.get("user_id", payload["entity_id"]), # Use user_id if available, fallback to 
entity_id ) except jwt.InvalidTokenError as e: raise HTTPException(status_code=401, detail=str(e)) @@ -569,6 +577,11 @@ async def query_completion( to enhance retrieval by finding relevant entities and their connected documents. """ try: + # Check query limits if in cloud mode + if settings.MODE == "cloud" and auth.user_id: + # Check limits before proceeding + await check_and_increment_limits(auth, "query", 1) + async with telemetry.track_operation( operation_type="query", user_id=auth.entity_id, @@ -689,6 +702,7 @@ async def update_document_text( except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) + @app.post("/documents/{document_id}/update_file", response_model=Document) async def update_document_file( document_id: str, @@ -856,6 +870,11 @@ async def create_cache( ) -> Dict[str, Any]: """Create a new cache with specified configuration.""" try: + # Check cache creation limits if in cloud mode + if settings.MODE == "cloud" and auth.user_id: + # Check limits before proceeding + await check_and_increment_limits(auth, "cache", 1) + async with telemetry.track_operation( operation_type="create_cache", user_id=auth.entity_id, @@ -955,6 +974,11 @@ async def query_cache( ) -> CompletionResponse: """Query the cache with a prompt.""" try: + # Check cache query limits if in cloud mode + if settings.MODE == "cloud" and auth.user_id: + # Check limits before proceeding + await check_and_increment_limits(auth, "cache_query", 1) + async with telemetry.track_operation( operation_type="query_cache", user_id=auth.entity_id, @@ -994,6 +1018,11 @@ async def create_graph( Graph: The created graph object """ try: + # Check graph creation limits if in cloud mode + if settings.MODE == "cloud" and auth.user_id: + # Check limits before proceeding + await check_and_increment_limits(auth, "graph", 1) + async with telemetry.track_operation( operation_type="create_graph", user_id=auth.entity_id, @@ -1109,3 +1138,147 @@ async def generate_local_uri( except Exception 
as e: logger.error(f"Error generating local URI: {e}") raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/cloud/generate_uri", include_in_schema=True) +async def generate_cloud_uri( + request: GenerateUriRequest, + authorization: str = Header(None), +) -> Dict[str, str]: + """Generate a URI for cloud hosted applications.""" + try: + app_id = request.app_id + name = request.name + user_id = request.user_id + expiry_days = request.expiry_days + + logger.info(f"Generating cloud URI for app_id={app_id}, name={name}, user_id={user_id}") + + # Verify authorization header before proceeding + if not authorization: + logger.warning("Missing authorization header") + raise HTTPException( + status_code=401, + detail="Missing authorization header", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Verify the token is valid + if not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Invalid authorization header") + + token = authorization[7:] # Remove "Bearer " + + try: + # Decode the token to ensure it's valid + payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]) + + # Only allow users to create apps for themselves (or admin) + token_user_id = payload.get("user_id") + logger.debug(f"Token user ID: {token_user_id}") + logger.debug(f"User ID: {user_id}") + if not (token_user_id == user_id or "admin" in payload.get("permissions", [])): + raise HTTPException( + status_code=403, + detail="You can only create apps for your own account unless you have admin permissions" + ) + except jwt.InvalidTokenError as e: + raise HTTPException(status_code=401, detail=str(e)) + + # Import UserService here to avoid circular imports + from core.services.user_service import UserService + user_service = UserService() + + # Initialize user service if needed + await user_service.initialize() + + # Clean name + name = name.replace(" ", "_").lower() + + # Check if the user is within app limit and generate URI + uri 
= await user_service.generate_cloud_uri(user_id, app_id, name, expiry_days) + + if not uri: + logger.info("Application limit reached for this account tier with user_id: %s", user_id) + raise HTTPException( + status_code=403, + detail="Application limit reached for this account tier" + ) + + return {"uri": uri, "app_id": app_id} + except HTTPException: + # Re-raise HTTP exceptions + raise + except Exception as e: + logger.error(f"Error generating cloud URI: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/user/upgrade", include_in_schema=True) +async def upgrade_user_tier( + user_id: str, + tier: str, + custom_limits: Optional[Dict[str, Any]] = None, + authorization: str = Header(None), +) -> Dict[str, Any]: + """Upgrade a user to a higher tier.""" + try: + # Verify admin authorization + if not authorization: + raise HTTPException( + status_code=401, + detail="Missing authorization header", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Invalid authorization header") + + token = authorization[7:] # Remove "Bearer " + + try: + # Decode token + payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]) + + # Only allow admins to upgrade users + if "admin" not in payload.get("permissions", []): + raise HTTPException( + status_code=403, + detail="Admin permission required" + ) + except jwt.InvalidTokenError as e: + raise HTTPException(status_code=401, detail=str(e)) + + # Validate tier + from core.models.tiers import AccountTier + try: + account_tier = AccountTier(tier) + except ValueError: + raise HTTPException(status_code=400, detail=f"Invalid tier: {tier}") + + # Upgrade user + from core.services.user_service import UserService + user_service = UserService() + + # Initialize user service + await user_service.initialize() + + # Update user tier + success = await user_service.update_user_tier(user_id, tier, 
custom_limits) + + if not success: + raise HTTPException( + status_code=404, + detail="User not found or upgrade failed" + ) + + return { + "user_id": user_id, + "tier": tier, + "message": f"User upgraded to {tier} tier" + } + except HTTPException: + raise + except Exception as e: + logger.error(f"Error upgrading user tier: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/core/config.py b/core/config.py index 79b2e63..1b78fab 100644 --- a/core/config.py +++ b/core/config.py @@ -91,6 +91,9 @@ class Settings(BaseSettings): # Colpali configuration ENABLE_COLPALI: bool + # Mode configuration + MODE: Literal["cloud", "self_hosted"] = "cloud" + # Telemetry configuration TELEMETRY_ENABLED: bool = True HONEYCOMB_ENABLED: bool = True @@ -291,6 +294,7 @@ def get_settings() -> Settings: # load databridge config databridge_config = { "ENABLE_COLPALI": config["databridge"]["enable_colpali"], + "MODE": config["databridge"].get("mode", "cloud"), # Default to "cloud" mode } # load graph config diff --git a/core/database/postgres_database.py b/core/database/postgres_database.py index cea7617..2e0a780 100644 --- a/core/database/postgres_database.py +++ b/core/database/postgres_database.py @@ -494,6 +494,15 @@ class PostgresDatabase(BaseDatabase): if auth.entity_type == "DEVELOPER" and auth.app_id: # Add app-specific access for developers filters.append(f"access_control->'app_access' ? 
'{auth.app_id}'") + + # Add user_id filter in cloud mode + if auth.user_id: + from core.config import get_settings + settings = get_settings() + + if settings.MODE == "cloud": + # Filter by user_id in access_control + filters.append(f"access_control->>'user_id' = '{auth.user_id}'") return " OR ".join(filters) diff --git a/core/database/user_limits_db.py b/core/database/user_limits_db.py new file mode 100644 index 0000000..07cb756 --- /dev/null +++ b/core/database/user_limits_db.py @@ -0,0 +1,359 @@ +import json +from typing import List, Optional, Dict, Any +from datetime import datetime, UTC, timedelta +import logging +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy import Column, String, Index, select, text +from sqlalchemy.dialects.postgresql import JSONB + +logger = logging.getLogger(__name__) +Base = declarative_base() + + +class UserLimitsModel(Base): + """SQLAlchemy model for user limits data.""" + + __tablename__ = "user_limits" + + user_id = Column(String, primary_key=True) + tier = Column(String, nullable=False) # FREE, PRO, CUSTOM, SELF_HOSTED + custom_limits = Column(JSONB, nullable=True) + usage = Column(JSONB, default=dict) # Holds all usage counters + app_ids = Column(JSONB, default=list) # List of app IDs registered by this user + created_at = Column(String) # ISO format string + updated_at = Column(String) # ISO format string + + # Create indexes + __table_args__ = (Index("idx_user_tier", "tier"),) + + +class UserLimitsDatabase: + """Database operations for user limits.""" + + def __init__(self, uri: str): + """Initialize database connection.""" + self.engine = create_async_engine(uri) + self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False) + self._initialized = False + + async def initialize(self) -> bool: + """Initialize database tables and indexes.""" + if self._initialized: + return True + + try: + 
logger.info("Initializing user limits database tables...") + # Create tables if they don't exist + async with self.engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + self._initialized = True + logger.info("User limits database tables initialized successfully") + return True + except Exception as e: + logger.error(f"Failed to initialize user limits database: {e}") + return False + + async def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]: + """ + Get user limits for a user. + + Args: + user_id: The user ID to get limits for + + Returns: + Dict with user limits if found, None otherwise + """ + async with self.async_session() as session: + result = await session.execute( + select(UserLimitsModel).where(UserLimitsModel.user_id == user_id) + ) + user_limits = result.scalars().first() + + if not user_limits: + return None + + return { + "user_id": user_limits.user_id, + "tier": user_limits.tier, + "custom_limits": user_limits.custom_limits, + "usage": user_limits.usage, + "app_ids": user_limits.app_ids, + "created_at": user_limits.created_at, + "updated_at": user_limits.updated_at, + } + + async def create_user_limits(self, user_id: str, tier: str = "free") -> bool: + """ + Create user limits record. 
+ + Args: + user_id: The user ID + tier: Initial tier (defaults to "free") + + Returns: + True if successful, False otherwise + """ + try: + now = datetime.now(UTC).isoformat() + + async with self.async_session() as session: + # Check if already exists + result = await session.execute( + select(UserLimitsModel).where(UserLimitsModel.user_id == user_id) + ) + if result.scalars().first(): + return True # Already exists + + # Create new record with properly initialized JSONB columns + # Create JSON strings and parse them for consistency + usage_json = json.dumps({ + "storage_file_count": 0, + "storage_size_bytes": 0, + "hourly_query_count": 0, + "hourly_query_reset": now, + "monthly_query_count": 0, + "monthly_query_reset": now, + "hourly_ingest_count": 0, + "hourly_ingest_reset": now, + "monthly_ingest_count": 0, + "monthly_ingest_reset": now, + "graph_count": 0, + "cache_count": 0, + }) + app_ids_json = json.dumps([]) # Empty array but as JSON string + + # Create the model with the JSON parsed + user_limits = UserLimitsModel( + user_id=user_id, + tier=tier, + usage=json.loads(usage_json), + app_ids=json.loads(app_ids_json), + created_at=now, + updated_at=now, + ) + + session.add(user_limits) + await session.commit() + return True + except Exception as e: + logger.error(f"Failed to create user limits: {e}") + return False + + async def update_user_tier( + self, user_id: str, tier: str, custom_limits: Optional[Dict[str, Any]] = None + ) -> bool: + """ + Update user tier and custom limits. 
+ + Args: + user_id: The user ID + tier: New tier + custom_limits: Optional custom limits for CUSTOM tier + + Returns: + True if successful, False otherwise + """ + try: + now = datetime.now(UTC).isoformat() + + async with self.async_session() as session: + result = await session.execute( + select(UserLimitsModel).where(UserLimitsModel.user_id == user_id) + ) + user_limits = result.scalars().first() + + if not user_limits: + return False + + user_limits.tier = tier + user_limits.custom_limits = custom_limits + user_limits.updated_at = now + + await session.commit() + return True + except Exception as e: + logger.error(f"Failed to update user tier: {e}") + return False + + async def register_app(self, user_id: str, app_id: str) -> bool: + """ + Register an app for a user. + + Args: + user_id: The user ID + app_id: The app ID to register + + Returns: + True if successful, False otherwise + """ + try: + now = datetime.now(UTC).isoformat() + + async with self.async_session() as session: + # First check if user exists + result = await session.execute( + select(UserLimitsModel).where(UserLimitsModel.user_id == user_id) + ) + user_limits = result.scalars().first() + + if not user_limits: + logger.error(f"User {user_id} not found in register_app") + return False + + # Use raw SQL with jsonb_array_append to update the app_ids array + # This is the most reliable way to append to a JSONB array in PostgreSQL + query = text( + """ + UPDATE user_limits + SET + app_ids = CASE + WHEN NOT (app_ids ? 
:app_id) -- Check if app_id is not in the array + THEN app_ids || :app_id_json -- Append it if not present + ELSE app_ids -- Keep it unchanged if already present + END, + updated_at = :now + WHERE user_id = :user_id + RETURNING app_ids; + """ + ) + + # Execute the query + result = await session.execute( + query, + { + "app_id": app_id, # For the check + "app_id_json": f'["{app_id}"]', # JSON array format for appending + "now": now, + "user_id": user_id + } + ) + + # Log the result for debugging + updated_app_ids = result.scalar() + logger.info(f"Updated app_ids for user {user_id}: {updated_app_ids}") + + await session.commit() + return True + + except Exception as e: + logger.error(f"Failed to register app: {e}") + return False + + async def update_usage(self, user_id: str, usage_type: str, increment: int = 1) -> bool: + """ + Update usage counter for a user. + + Args: + user_id: The user ID + usage_type: Type of usage to update + increment: Value to increment by + + Returns: + True if successful, False otherwise + """ + try: + now = datetime.now(UTC) + now_iso = now.isoformat() + + async with self.async_session() as session: + result = await session.execute( + select(UserLimitsModel).where(UserLimitsModel.user_id == user_id) + ) + user_limits = result.scalars().first() + + if not user_limits: + return False + + # Create a new dictionary to force SQLAlchemy to detect the change + usage = dict(user_limits.usage) if user_limits.usage else {} + + # Handle different usage types + if usage_type == "query": + # Check hourly reset + hourly_reset_str = usage.get("hourly_query_reset", "") + if hourly_reset_str: + hourly_reset = datetime.fromisoformat(hourly_reset_str) + if now > hourly_reset + timedelta(hours=1): + usage["hourly_query_count"] = increment + usage["hourly_query_reset"] = now_iso + else: + usage["hourly_query_count"] = usage.get("hourly_query_count", 0) + increment + else: + usage["hourly_query_count"] = increment + usage["hourly_query_reset"] = now_iso + + # 
Check monthly reset + monthly_reset_str = usage.get("monthly_query_reset", "") + if monthly_reset_str: + monthly_reset = datetime.fromisoformat(monthly_reset_str) + if now > monthly_reset + timedelta(days=30): + usage["monthly_query_count"] = increment + usage["monthly_query_reset"] = now_iso + else: + usage["monthly_query_count"] = usage.get("monthly_query_count", 0) + increment + else: + usage["monthly_query_count"] = increment + usage["monthly_query_reset"] = now_iso + + elif usage_type == "ingest": + # Similar pattern for ingest + hourly_reset_str = usage.get("hourly_ingest_reset", "") + if hourly_reset_str: + hourly_reset = datetime.fromisoformat(hourly_reset_str) + if now > hourly_reset + timedelta(hours=1): + usage["hourly_ingest_count"] = increment + usage["hourly_ingest_reset"] = now_iso + else: + usage["hourly_ingest_count"] = usage.get("hourly_ingest_count", 0) + increment + else: + usage["hourly_ingest_count"] = increment + usage["hourly_ingest_reset"] = now_iso + + monthly_reset_str = usage.get("monthly_ingest_reset", "") + if monthly_reset_str: + monthly_reset = datetime.fromisoformat(monthly_reset_str) + if now > monthly_reset + timedelta(days=30): + usage["monthly_ingest_count"] = increment + usage["monthly_ingest_reset"] = now_iso + else: + usage["monthly_ingest_count"] = usage.get("monthly_ingest_count", 0) + increment + else: + usage["monthly_ingest_count"] = increment + usage["monthly_ingest_reset"] = now_iso + + elif usage_type == "storage_file": + usage["storage_file_count"] = usage.get("storage_file_count", 0) + increment + + elif usage_type == "storage_size": + usage["storage_size_bytes"] = usage.get("storage_size_bytes", 0) + increment + + elif usage_type == "graph": + usage["graph_count"] = usage.get("graph_count", 0) + increment + + elif usage_type == "cache": + usage["cache_count"] = usage.get("cache_count", 0) + increment + + # Force SQLAlchemy to recognize the change by assigning a new dict + user_limits.usage = usage + 
user_limits.updated_at = now_iso + + # Explicitly mark as modified + session.add(user_limits) + + # Log the updated usage for debugging + logger.info(f"Updated usage for user {user_id}, type: {usage_type}, value: {increment}") + logger.info(f"New usage values: {usage}") + logger.info(f"About to commit: user_id={user_id}, usage={user_limits.usage}") + + # Commit and flush to ensure changes are written + await session.commit() + + return True + + except Exception as e: + logger.error(f"Failed to update usage: {e}") + import traceback + logger.error(traceback.format_exc()) + return False diff --git a/core/limits_utils.py b/core/limits_utils.py new file mode 100644 index 0000000..8a868a5 --- /dev/null +++ b/core/limits_utils.py @@ -0,0 +1,71 @@ +import logging + +# Initialize logger +logger = logging.getLogger(__name__) + + +async def check_and_increment_limits(auth, limit_type: str, value: int = 1) -> None: + """ + Check if the user is within limits for an operation and increment usage. + + Args: + auth: Authentication context with user_id + limit_type: Type of limit to check (query, ingest, storage_file, storage_size, graph, cache) + value: Value to check against limit (e.g., file size for storage_size) + + Raises: + HTTPException: If the user exceeds limits + """ + from fastapi import HTTPException + from core.config import get_settings + + settings = get_settings() + + # Skip limit checking in self-hosted mode + if settings.MODE == "self_hosted": + return + + # Check if user_id is available + if not auth.user_id: + logger.warning("User ID not available in auth context, skipping limit check") + return + + # Initialize user service + from core.services.user_service import UserService + + user_service = UserService() + await user_service.initialize() + + # Check if user is within limits + within_limits = await user_service.check_limit(auth.user_id, limit_type, value) + + if not within_limits: + # Get tier information for better error message + user_data = await 
user_service.get_user_limits(auth.user_id) + tier = user_data.get("tier", "unknown") if user_data else "unknown" + + # Map limit types to appropriate error messages + limit_type_messages = { + "query": f"Query limit exceeded for your {tier} tier. Please upgrade or try again later.", + "ingest": f"Ingest limit exceeded for your {tier} tier. Please upgrade or try again later.", + "storage_file": f"Storage file count limit exceeded for your {tier} tier. Please delete some files or upgrade.", + "storage_size": f"Storage size limit exceeded for your {tier} tier. Please delete some files or upgrade.", + "graph": f"Graph creation limit exceeded for your {tier} tier. Please upgrade to create more graphs.", + "cache": f"Cache creation limit exceeded for your {tier} tier. Please upgrade to create more caches.", + "cache_query": f"Cache query limit exceeded for your {tier} tier. Please upgrade or try again later.", + } + + # Get message for the limit type or use default message + detail = limit_type_messages.get( + limit_type, f"Limit exceeded for your {tier} tier. Please upgrade or contact support." + ) + + # Raise the exception with appropriate message + raise HTTPException(status_code=429, detail=detail) + + # Record usage asynchronously + try: + await user_service.record_usage(auth.user_id, limit_type, value) + except Exception as e: + # Just log if recording usage fails, don't fail the operation + logger.error(f"Failed to record usage: {e}") diff --git a/core/models/auth.py b/core/models/auth.py index 9ca5a6a..42301fb 100644 --- a/core/models/auth.py +++ b/core/models/auth.py @@ -16,3 +16,4 @@ class AuthContext(BaseModel): app_id: Optional[str] = None # uuid, only for developers # TODO: remove permissions, not required here. 
permissions: Set[str] = {"read"} + user_id: Optional[str] = None # ID of the user who owns the app/entity diff --git a/core/models/request.py b/core/models/request.py index 2cde0f9..c56fefd 100644 --- a/core/models/request.py +++ b/core/models/request.py @@ -57,3 +57,11 @@ class BatchIngestResponse(BaseModel): """Response model for batch ingestion""" documents: List[Document] errors: List[Dict[str, str]] + + +class GenerateUriRequest(BaseModel): + """Request model for generating a cloud URI""" + app_id: str = Field(..., description="ID of the application") + name: str = Field(..., description="Name of the application") + user_id: str = Field(..., description="ID of the user who owns the app") + expiry_days: int = Field(default=30, description="Number of days until the token expires") diff --git a/core/models/tiers.py b/core/models/tiers.py new file mode 100644 index 0000000..abc05d8 --- /dev/null +++ b/core/models/tiers.py @@ -0,0 +1,117 @@ +from enum import Enum +from typing import Dict, Any + + +class AccountTier(str, Enum): + """Available account tiers.""" + + FREE = "free" + PRO = "pro" + CUSTOM = "custom" + SELF_HOSTED = "self_hosted" + + +# Tier limits definition - organized by API endpoint usage +TIER_LIMITS = { + AccountTier.FREE: { + # Application limits + "app_limit": 1, # Maximum number of applications + # Storage limits + "storage_file_limit": 100, # Maximum number of files in storage + "storage_size_limit_gb": 1, # Maximum storage size in GB + "hourly_ingest_limit": 30, # Maximum file/text ingests per hour + "monthly_ingest_limit": 100, # Maximum file/text ingests per month + # Query limits + "hourly_query_limit": 30, # Maximum queries per hour + "monthly_query_limit": 1000, # Maximum queries per month + # Graph limits + "graph_creation_limit": 3, # Maximum number of graphs + "hourly_graph_query_limit": 20, # Maximum graph queries per hour + "monthly_graph_query_limit": 200, # Maximum graph queries per month + # Cache limits + "cache_creation_limit": 
0, # Maximum number of caches + "hourly_cache_query_limit": 0, # Maximum cache queries per hour + "monthly_cache_query_limit": 0, # Maximum cache queries per month + }, + AccountTier.PRO: { + # Application limits + "app_limit": 5, # Maximum number of applications + # Storage limits + "storage_file_limit": 1000, # Maximum number of files in storage + "storage_size_limit_gb": 10, # Maximum storage size in GB + "hourly_ingest_limit": 100, # Maximum file/text ingests per hour + "monthly_ingest_limit": 3000, # Maximum file/text ingests per month + # Query limits + "hourly_query_limit": 100, # Maximum queries per hour + "monthly_query_limit": 10000, # Maximum queries per month + # Graph limits + "graph_creation_limit": 10, # Maximum number of graphs + "hourly_graph_query_limit": 50, # Maximum graph queries per hour + "monthly_graph_query_limit": 1000, # Maximum graph queries per month + # Cache limits + "cache_creation_limit": 5, # Maximum number of caches + "hourly_cache_query_limit": 200, # Maximum cache queries per hour + "monthly_cache_query_limit": 5000, # Maximum cache queries per month + }, + AccountTier.CUSTOM: { + # Custom tier limits are set on a per-account basis + # These are default values that will be overridden + # Application limits + "app_limit": 10, # Maximum number of applications + # Storage limits + "storage_file_limit": 10000, # Maximum number of files in storage + "storage_size_limit_gb": 100, # Maximum storage size in GB + "hourly_ingest_limit": 500, # Maximum file/text ingests per hour + "monthly_ingest_limit": 15000, # Maximum file/text ingests per month + # Query limits + "hourly_query_limit": 500, # Maximum queries per hour + "monthly_query_limit": 50000, # Maximum queries per month + # Graph limits + "graph_creation_limit": 50, # Maximum number of graphs + "hourly_graph_query_limit": 200, # Maximum graph queries per hour + "monthly_graph_query_limit": 10000, # Maximum graph queries per month + # Cache limits + "cache_creation_limit": 20, # 
Maximum number of caches + "hourly_cache_query_limit": 1000, # Maximum cache queries per hour + "monthly_cache_query_limit": 50000, # Maximum cache queries per month + }, + AccountTier.SELF_HOSTED: { + # Self-hosted has no limits + # Application limits + "app_limit": float("inf"), # Maximum number of applications + # Storage limits + "storage_file_limit": float("inf"), # Maximum number of files in storage + "storage_size_limit_gb": float("inf"), # Maximum storage size in GB + "hourly_ingest_limit": float("inf"), # Maximum file/text ingests per hour + "monthly_ingest_limit": float("inf"), # Maximum file/text ingests per month + # Query limits + "hourly_query_limit": float("inf"), # Maximum queries per hour + "monthly_query_limit": float("inf"), # Maximum queries per month + # Graph limits + "graph_creation_limit": float("inf"), # Maximum number of graphs + "hourly_graph_query_limit": float("inf"), # Maximum graph queries per hour + "monthly_graph_query_limit": float("inf"), # Maximum graph queries per month + # Cache limits + "cache_creation_limit": float("inf"), # Maximum number of caches + "hourly_cache_query_limit": float("inf"), # Maximum cache queries per hour + "monthly_cache_query_limit": float("inf"), # Maximum cache queries per month + }, +} + + +def get_tier_limits(tier: AccountTier, custom_limits: Dict[str, Any] = None) -> Dict[str, Any]: + """ + Get limits for a specific account tier. 
+ + Args: + tier: The account tier + custom_limits: Optional custom limits for CUSTOM tier + + Returns: + Dict of limits for the specified tier + """ + if tier == AccountTier.CUSTOM and custom_limits: + # Merge default custom limits with the provided custom limits + return {**TIER_LIMITS[tier], **custom_limits} + + return TIER_LIMITS[tier] diff --git a/core/models/user_limits.py b/core/models/user_limits.py new file mode 100644 index 0000000..20a6bc6 --- /dev/null +++ b/core/models/user_limits.py @@ -0,0 +1,55 @@ +from typing import Dict, Any, Optional +from datetime import datetime, UTC +from pydantic import BaseModel, Field + +from .tiers import AccountTier + + +class UserUsage(BaseModel): + """Tracks user's actual usage of the system.""" + + # Storage usage + storage_file_count: int = 0 + storage_size_bytes: int = 0 + + # Query usage - hourly + hourly_query_count: int = 0 + hourly_query_reset: Optional[datetime] = None + + # Query usage - monthly + monthly_query_count: int = 0 + monthly_query_reset: Optional[datetime] = None + + # Ingest usage - hourly + hourly_ingest_count: int = 0 + hourly_ingest_reset: Optional[datetime] = None + + # Ingest usage - monthly + monthly_ingest_count: int = 0 + monthly_ingest_reset: Optional[datetime] = None + + # Graph usage + graph_count: int = 0 + hourly_graph_query_count: int = 0 + hourly_graph_query_reset: Optional[datetime] = None + monthly_graph_query_count: int = 0 + monthly_graph_query_reset: Optional[datetime] = None + + # Cache usage + cache_count: int = 0 + hourly_cache_query_count: int = 0 + hourly_cache_query_reset: Optional[datetime] = None + monthly_cache_query_count: int = 0 + monthly_cache_query_reset: Optional[datetime] = None + + +class UserLimits(BaseModel): + """Stores user tier and usage information.""" + + user_id: str + tier: AccountTier = AccountTier.FREE + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + 
usage: UserUsage = Field(default_factory=UserUsage) + custom_limits: Optional[Dict[str, Any]] = None + app_ids: list[str] = Field(default_factory=list) diff --git a/core/services/document_service.py b/core/services/document_service.py index 5e52e9c..891349b 100644 --- a/core/services/document_service.py +++ b/core/services/document_service.py @@ -329,6 +329,15 @@ class DocumentService: logger.error(f"User {auth.entity_id} does not have write permission") raise PermissionError("User does not have write permission") + # First check ingest limits if in cloud mode + from core.config import get_settings + settings = get_settings() + + if settings.MODE == "cloud" and auth.user_id: + # Check limits before proceeding + from core.api import check_and_increment_limits + await check_and_increment_limits(auth, "ingest", 1) + doc = Document( content_type="text/plain", filename=filename, @@ -338,6 +347,7 @@ class DocumentService: "readers": [auth.entity_id], "writers": [auth.entity_id], "admins": [auth.entity_id], + "user_id": [auth.user_id] if auth.user_id else [], # Add user_id to access control for filtering (as a list) }, ) logger.info(f"Created text document record with ID {doc.external_id}") @@ -404,6 +414,20 @@ class DocumentService: # Read file content file_content = await file.read() + file_size = len(file_content) # Get file size in bytes for limit checking + + # Check limits before doing any expensive processing + from core.config import get_settings + settings = get_settings() + + if settings.MODE == "cloud" and auth.user_id: + # Check limits before proceeding with parsing + from core.api import check_and_increment_limits + await check_and_increment_limits(auth, "ingest", 1) + await check_and_increment_limits(auth, "storage_file", 1) + await check_and_increment_limits(auth, "storage_size", file_size) + + # Now proceed with parsing and processing the file file_type = filetype.guess(file_content) # Set default mime type for cases where filetype.guess returns None @@ 
class UserService:
    """Service for managing per-user tiers, usage limits and app registration.

    Persistence is delegated to ``UserLimitsDatabase``; this class layers
    tier-aware limit checks on top and generates the signed cloud URIs used
    by the hosted deployment.  In ``self_hosted`` mode every limit check and
    usage-recording call is a no-op that returns ``True``.
    """

    def __init__(self):
        """Initialize the UserService with app settings and the limits DB."""
        self.settings = get_settings()
        self.db = UserLimitsDatabase(uri=self.settings.POSTGRES_URI)

    async def initialize(self) -> bool:
        """Initialize database tables.  Returns True on success."""
        return await self.db.initialize()

    async def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Get the stored limits/usage record for *user_id*, or None if absent."""
        return await self.db.get_user_limits(user_id)

    async def create_user(self, user_id: str) -> bool:
        """Create a new user with FREE tier.  Returns True on success."""
        return await self.db.create_user_limits(user_id, tier=AccountTier.FREE)

    async def update_user_tier(
        self, user_id: str, tier: str, custom_limits: Optional[Dict[str, Any]] = None
    ) -> bool:
        """Update user tier and (for CUSTOM tier) custom limits."""
        return await self.db.update_user_tier(user_id, tier, custom_limits)

    async def register_app(self, user_id: str, app_id: str) -> bool:
        """
        Register an app for a user.

        Creates the user limits record first if it doesn't exist.

        Returns:
            True if the app was registered, False otherwise.
        """
        user_limits = await self.db.get_user_limits(user_id)

        # Lazily create the limits record so registration never fails just
        # because the user has not been seen before.
        if not user_limits:
            logger.info(f"Creating user limits for user {user_id}")
            success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
            if not success:
                logger.error(f"Failed to create user limits for user {user_id}")
                return False

        return await self.db.register_app(user_id, app_id)

    @staticmethod
    def _check_rate_window(
        tier_limits: Dict[str, Any],
        usage: Dict[str, Any],
        kind: str,
        value: int,
        default_limit: Optional[int] = 0,
    ) -> bool:
        """Check the hourly and monthly counters for *kind* against tier limits.

        Looks up ``hourly_{kind}_limit`` / ``monthly_{kind}_limit`` in
        *tier_limits* and ``hourly_{kind}_count`` / ``monthly_{kind}_count``
        in *usage*.  A limit of ``None`` (possible only when *default_limit*
        is ``None`` and the tier does not declare the key) means "unlimited".
        """
        hourly_limit = tier_limits.get(f"hourly_{kind}_limit", default_limit)
        monthly_limit = tier_limits.get(f"monthly_{kind}_limit", default_limit)
        if hourly_limit is None or monthly_limit is None:
            return True
        hourly_usage = usage.get(f"hourly_{kind}_count", 0)
        monthly_usage = usage.get(f"monthly_{kind}_count", 0)
        return hourly_usage + value <= hourly_limit and monthly_usage + value <= monthly_limit

    async def check_limit(self, user_id: str, limit_type: str, value: int = 1) -> bool:
        """
        Check if a user's operation is within limits when given value is considered.

        Args:
            user_id: The user ID to check
            limit_type: Type of limit (query, ingest, graph, cache,
                graph_query, cache_query, storage_file, storage_size)
            value: Value to check (e.g., file size in bytes for storage_size)

        Returns:
            True if within limits, False if exceeded
        """
        # Skip limit checking for self-hosted mode
        if self.settings.MODE == "self_hosted":
            return True

        # Get (or lazily create) the user's limits record
        user_data = await self.db.get_user_limits(user_id)
        if not user_data:
            logger.info(f"User {user_id} not found when checking limits - creating limits record")
            success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
            if not success:
                logger.error(f"Failed to create user limits for user {user_id}")
                return False
            user_data = await self.db.get_user_limits(user_id)
            if not user_data:
                logger.error(f"Failed to retrieve newly created user limits for user {user_id}")
                return False

        tier = user_data.get("tier", AccountTier.FREE)
        tier_limits = get_tier_limits(tier, user_data.get("custom_limits"))
        usage = user_data.get("usage", {})

        # Hourly/monthly windowed counters.  Missing tier keys default to 0
        # (deny) for these long-standing types, matching prior behavior.
        if limit_type in ("query", "ingest"):
            return self._check_rate_window(tier_limits, usage, limit_type, value)

        # BUGFIX: "graph_query" and "cache_query" are tracked in UserUsage
        # and checked by the API layer, but previously fell through to the
        # permissive catch-all and were never enforced.  They are enforced
        # only when the tier declares the corresponding limit keys; absent
        # keys mean unlimited, preserving the old behavior.
        # NOTE(review): key names ``hourly_graph_query_limit`` etc. are
        # assumed to mirror the usage counters — confirm against TIER_LIMITS.
        if limit_type in ("graph_query", "cache_query"):
            return self._check_rate_window(
                tier_limits, usage, limit_type, value, default_limit=None
            )

        # Absolute (non-windowed) totals.
        if limit_type == "storage_file":
            file_limit = tier_limits.get("storage_file_limit", 0)
            return usage.get("storage_file_count", 0) + value <= file_limit

        if limit_type == "storage_size":
            # Tier limit is expressed in GB; usage and *value* are bytes.
            size_limit_bytes = tier_limits.get("storage_size_limit_gb", 0) * 1024 * 1024 * 1024
            return usage.get("storage_size_bytes", 0) + value <= size_limit_bytes

        if limit_type == "graph":
            graph_limit = tier_limits.get("graph_creation_limit", 0)
            return usage.get("graph_count", 0) + value <= graph_limit

        if limit_type == "cache":
            cache_limit = tier_limits.get("cache_creation_limit", 0)
            return usage.get("cache_count", 0) + value <= cache_limit

        # Unknown limit types are not restricted.
        return True

    async def record_usage(self, user_id: str, usage_type: str, increment: int = 1) -> bool:
        """
        Record usage for a user.

        Args:
            user_id: The user ID
            usage_type: Type of usage (query, ingest, storage_file, storage_size, etc.)
            increment: Value to increment by

        Returns:
            True if successful, False otherwise
        """
        # Skip usage recording for self-hosted mode
        if self.settings.MODE == "self_hosted":
            return True

        # Check if user limits exist, create if they don't
        user_data = await self.db.get_user_limits(user_id)
        if not user_data:
            logger.info(f"Creating user limits for user {user_id} during usage recording")
            success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
            if not success:
                logger.error(f"Failed to create user limits for user {user_id}")
                return False

        return await self.db.update_usage(user_id, usage_type, increment)

    async def generate_cloud_uri(
        self, user_id: str, app_id: str, name: str, expiry_days: int = 30
    ) -> Optional[str]:
        """
        Generate a cloud URI for an app.

        Args:
            user_id: The user ID
            app_id: The app ID
            name: App name for display purposes
            expiry_days: Number of days until token expires

        Returns:
            URI string with embedded JWT, or None if the app limit is
            reached or registration fails.
        """
        # Get user limits to check the app limit; create the user on demand.
        user_limits = await self.db.get_user_limits(user_id)
        if not user_limits:
            await self.create_user(user_id)
            user_limits = await self.db.get_user_limits(user_id)
            if not user_limits:
                logger.error(f"Failed to create user limits for user {user_id}")
                return None

        # Get tier limits to enforce the app limit
        tier = user_limits.get("tier", AccountTier.FREE)
        tier_limits = get_tier_limits(tier, user_limits.get("custom_limits"))
        app_limit = tier_limits.get("app_limit", 1)  # Default to 1 if not specified

        current_apps = user_limits.get("app_ids", [])

        # Skip the limit check if the app is already registered
        if app_id not in current_apps:
            if len(current_apps) >= app_limit:
                logger.info(f"User {user_id} has reached app limit ({app_limit}) for tier {tier}")
                return None

            success = await self.register_app(user_id, app_id)
            if not success:
                # BUGFIX: registration failure was previously logged at info level.
                logger.error(f"Failed to register app {app_id} for user {user_id}")
                return None

        # Build the token payload; entity_id mirrors user_id so the token
        # verifier (core/api.py) can resolve both fields.
        payload = {
            "user_id": user_id,
            "app_id": app_id,
            "name": name,
            "permissions": ["read", "write"],
            "exp": int((datetime.now(UTC) + timedelta(days=expiry_days)).timestamp()),
            "type": "developer",
            "entity_id": user_id,
        }

        token = jwt.encode(
            payload, self.settings.JWT_SECRET_KEY, algorithm=self.settings.JWT_ALGORITHM
        )

        # Assemble the databridge:// URI pointing at the configured API domain.
        api_domain = getattr(self.settings, "API_DOMAIN", "api.databridge.ai")
        uri = f"databridge://{name}:{token}@{api_domain}/{app_id}"

        return uri