# Mirror of https://github.com/james-m-jordan/morphik-core.git
# Synced 2025-05-09 19:32:38 +00:00
# 237 lines, 8.8 KiB, Python
import logging
|
|
from typing import Dict, Any, Optional
|
|
from datetime import datetime, UTC, timedelta
|
|
import jwt
|
|
|
|
from ..models.tiers import AccountTier, get_tier_limits
|
|
from ..models.auth import AuthContext
|
|
from ..config import get_settings
|
|
from ..database.user_limits_db import UserLimitsDatabase
|
|
|
|
# Module-level logger, named after this module, used by UserService below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class UserService:
|
|
"""Service for managing user limits and usage."""
|
|
|
|
def __init__(self):
|
|
"""Initialize the UserService."""
|
|
self.settings = get_settings()
|
|
self.db = UserLimitsDatabase(uri=self.settings.POSTGRES_URI)
|
|
|
|
async def initialize(self) -> bool:
|
|
"""Initialize database tables."""
|
|
return await self.db.initialize()
|
|
|
|
async def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
|
|
"""Get user limits information."""
|
|
return await self.db.get_user_limits(user_id)
|
|
|
|
async def create_user(self, user_id: str) -> bool:
|
|
"""Create a new user with FREE tier."""
|
|
return await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
|
|
|
|
async def update_user_tier(
|
|
self, user_id: str, tier: str, custom_limits: Optional[Dict[str, Any]] = None
|
|
) -> bool:
|
|
"""Update user tier and custom limits."""
|
|
return await self.db.update_user_tier(user_id, tier, custom_limits)
|
|
|
|
async def register_app(self, user_id: str, app_id: str) -> bool:
|
|
"""
|
|
Register an app for a user.
|
|
|
|
Creates user limits record if it doesn't exist.
|
|
"""
|
|
# First check if user limits exist
|
|
user_limits = await self.db.get_user_limits(user_id)
|
|
|
|
# If user limits don't exist, create them first
|
|
if not user_limits:
|
|
logger.info(f"Creating user limits for user {user_id}")
|
|
success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
|
|
if not success:
|
|
logger.error(f"Failed to create user limits for user {user_id}")
|
|
return False
|
|
|
|
# Now register the app
|
|
return await self.db.register_app(user_id, app_id)
|
|
|
|
async def check_limit(self, user_id: str, limit_type: str, value: int = 1) -> bool:
|
|
"""
|
|
Check if a user's operation is within limits when given value is considered.
|
|
|
|
Args:
|
|
user_id: The user ID to check
|
|
limit_type: Type of limit (query, ingest, graph, cache, etc.)
|
|
value: Value to check (e.g., file size for storage)
|
|
|
|
Returns:
|
|
True if within limits, False if exceeded
|
|
"""
|
|
# Skip limit checking for self-hosted mode
|
|
if self.settings.MODE == "self_hosted":
|
|
return True
|
|
|
|
# Get user limits
|
|
user_data = await self.db.get_user_limits(user_id)
|
|
if not user_data:
|
|
# Create user limits if they don't exist
|
|
logger.info(f"User {user_id} not found when checking limits - creating limits record")
|
|
success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
|
|
if not success:
|
|
logger.error(f"Failed to create user limits for user {user_id}")
|
|
return False
|
|
# Fetch the newly created limits
|
|
user_data = await self.db.get_user_limits(user_id)
|
|
if not user_data:
|
|
logger.error(f"Failed to retrieve newly created user limits for user {user_id}")
|
|
return False
|
|
|
|
# Get tier limits
|
|
tier = user_data.get("tier", AccountTier.FREE)
|
|
tier_limits = get_tier_limits(tier, user_data.get("custom_limits"))
|
|
|
|
# Get current usage
|
|
usage = user_data.get("usage", {})
|
|
|
|
# Check specific limit type
|
|
if limit_type == "query":
|
|
hourly_limit = tier_limits.get("hourly_query_limit", 0)
|
|
monthly_limit = tier_limits.get("monthly_query_limit", 0)
|
|
|
|
hourly_usage = usage.get("hourly_query_count", 0)
|
|
monthly_usage = usage.get("monthly_query_count", 0)
|
|
|
|
return hourly_usage + value <= hourly_limit and monthly_usage + value <= monthly_limit
|
|
|
|
elif limit_type == "ingest":
|
|
hourly_limit = tier_limits.get("hourly_ingest_limit", 0)
|
|
monthly_limit = tier_limits.get("monthly_ingest_limit", 0)
|
|
|
|
hourly_usage = usage.get("hourly_ingest_count", 0)
|
|
monthly_usage = usage.get("monthly_ingest_count", 0)
|
|
|
|
return hourly_usage + value <= hourly_limit and monthly_usage + value <= monthly_limit
|
|
|
|
elif limit_type == "storage_file":
|
|
file_limit = tier_limits.get("storage_file_limit", 0)
|
|
file_count = usage.get("storage_file_count", 0)
|
|
|
|
return file_count + value <= file_limit
|
|
|
|
elif limit_type == "storage_size":
|
|
size_limit_bytes = tier_limits.get("storage_size_limit_gb", 0) * 1024 * 1024 * 1024
|
|
size_usage = usage.get("storage_size_bytes", 0)
|
|
|
|
return size_usage + value <= size_limit_bytes
|
|
|
|
elif limit_type == "graph":
|
|
graph_limit = tier_limits.get("graph_creation_limit", 0)
|
|
graph_count = usage.get("graph_count", 0)
|
|
|
|
return graph_count + value <= graph_limit
|
|
|
|
elif limit_type == "cache":
|
|
cache_limit = tier_limits.get("cache_creation_limit", 0)
|
|
cache_count = usage.get("cache_count", 0)
|
|
|
|
return cache_count + value <= cache_limit
|
|
|
|
return True
|
|
|
|
async def record_usage(self, user_id: str, usage_type: str, increment: int = 1) -> bool:
|
|
"""
|
|
Record usage for a user.
|
|
|
|
Args:
|
|
user_id: The user ID
|
|
usage_type: Type of usage (query, ingest, storage_file, storage_size, etc.)
|
|
increment: Value to increment by
|
|
|
|
Returns:
|
|
True if successful, False otherwise
|
|
"""
|
|
# Skip usage recording for self-hosted mode
|
|
if self.settings.MODE == "self_hosted":
|
|
return True
|
|
|
|
# Check if user limits exist, create if they don't
|
|
user_data = await self.db.get_user_limits(user_id)
|
|
if not user_data:
|
|
logger.info(f"Creating user limits for user {user_id} during usage recording")
|
|
success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
|
|
if not success:
|
|
logger.error(f"Failed to create user limits for user {user_id}")
|
|
return False
|
|
|
|
return await self.db.update_usage(user_id, usage_type, increment)
|
|
|
|
async def generate_cloud_uri(
|
|
self, user_id: str, app_id: str, name: str, expiry_days: int = 30
|
|
) -> Optional[str]:
|
|
"""
|
|
Generate a cloud URI for an app.
|
|
|
|
Args:
|
|
user_id: The user ID
|
|
app_id: The app ID
|
|
name: App name for display purposes
|
|
expiry_days: Number of days until token expires
|
|
|
|
Returns:
|
|
URI string with embedded token, or None if failed
|
|
"""
|
|
# Get user limits to check app limit
|
|
user_limits = await self.db.get_user_limits(user_id)
|
|
|
|
# If user doesn't exist yet, create them
|
|
if not user_limits:
|
|
await self.create_user(user_id)
|
|
user_limits = await self.db.get_user_limits(user_id)
|
|
if not user_limits:
|
|
logger.error(f"Failed to create user limits for user {user_id}")
|
|
return None
|
|
|
|
# Get tier limits to enforce app limit
|
|
tier = user_limits.get("tier", AccountTier.FREE)
|
|
tier_limits = get_tier_limits(tier, user_limits.get("custom_limits"))
|
|
app_limit = tier_limits.get("app_limit", 1) # Default to 1 if not specified
|
|
|
|
current_apps = user_limits.get("app_ids", [])
|
|
|
|
# Skip the limit check if app is already registered
|
|
if app_id not in current_apps:
|
|
# Check if user has reached app limit
|
|
if len(current_apps) >= app_limit:
|
|
logger.info(f"User {user_id} has reached app limit ({app_limit}) for tier {tier}")
|
|
return None
|
|
|
|
# Register the app
|
|
success = await self.register_app(user_id, app_id)
|
|
if not success:
|
|
logger.info(f"Failed to register app {app_id} for user {user_id}")
|
|
return None
|
|
|
|
# Create token payload
|
|
payload = {
|
|
"user_id": user_id,
|
|
"app_id": app_id,
|
|
"name": name,
|
|
"permissions": ["read", "write"],
|
|
"exp": int((datetime.now(UTC) + timedelta(days=expiry_days)).timestamp()),
|
|
"type": "developer",
|
|
"entity_id": user_id,
|
|
}
|
|
|
|
# Generate token
|
|
token = jwt.encode(
|
|
payload, self.settings.JWT_SECRET_KEY, algorithm=self.settings.JWT_ALGORITHM
|
|
)
|
|
|
|
# Generate URI with API domain
|
|
api_domain = getattr(self.settings, "API_DOMAIN", "api.morphik.ai")
|
|
uri = f"morphik://{name}:{token}@{api_domain}"
|
|
|
|
return uri
|