mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
101 lines
4.1 KiB
Python
101 lines
4.1 KiB
Python
import logging
|
|
|
|
# Initialize logger
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def check_and_increment_limits(auth, limit_type: str, value: int = 1, document_id: str = None) -> None:
    """
    Enforce free-tier usage limits for an operation, then record the usage.

    Flow:
      1. In self-hosted mode nothing is checked or recorded.
      2. Without a user ID on the auth context the check is skipped (a warning is logged).
      3. The user's limits record is loaded, creating the user first if no record exists.
         If the record still cannot be loaded, an error is logged and the call returns
         without enforcing anything.
      4. Free-tier users are checked against their limit; any other tier is never blocked.
      5. Usage is recorded for every tier. A failure to record is logged, not raised.
         (Per the original documentation, recorded usage for paid tiers is metered
         through Stripe for billing; that happens inside the user service.)

    Args:
        auth: Authentication context; only its ``user_id`` attribute is read here.
        limit_type: Kind of limit being consumed. Dedicated error messages exist for
            query, ingest, storage_file, storage_size, graph, cache, cache_query and agent;
            any other value falls back to a generic message.
        value: Amount to check/record against the limit (e.g., file size for storage_size).
        document_id: Optional document ID forwarded to usage recording.

    Raises:
        HTTPException: With status 429 when a free-tier user is not within the limit.
    """
    from fastapi import HTTPException

    from core.config import get_settings
    from core.models.tiers import AccountTier

    # Self-hosted deployments are neither limited nor tracked.
    if get_settings().MODE == "self_hosted":
        return

    # Nothing to attribute usage to without a user.
    if not auth.user_id:
        logger.warning("User ID not available in auth context, skipping limit check")
        return

    # Imported only after the early exits above, mirroring the original ordering.
    from core.services.user_service import UserService

    service = UserService()
    await service.initialize()

    # Load the limits record; a missing record triggers user creation (defaults to free tier).
    limits_record = await service.get_user_limits(auth.user_id)
    if not limits_record:
        await service.create_user(auth.user_id)
        limits_record = await service.get_user_limits(auth.user_id)
        if not limits_record:
            logger.error(f"Failed to create user limits for user {auth.user_id}")
            return

    account_tier = limits_record.get("tier", AccountTier.FREE)
    exempt_from_limits = account_tier != AccountTier.FREE

    # Only free-tier users can be rejected; every other tier goes straight to recording.
    if not exempt_from_limits:
        allowed = await service.check_limit(auth.user_id, limit_type, value)
        if not allowed:
            messages_by_type = {
                "query": "Query limit exceeded for your free tier. Please upgrade to remove limits.",
                "ingest": "Ingest limit exceeded for your free tier. Please upgrade to remove limits.",
                "storage_file": (
                    "Storage file count limit exceeded for your free tier. "
                    "Please delete some files or upgrade to remove limits."
                ),
                "storage_size": (
                    "Storage size limit exceeded for your free tier. "
                    "Please delete some files or upgrade to remove limits."
                ),
                "graph": "Graph creation limit exceeded for your free tier. Please upgrade to remove limits.",
                "cache": "Cache creation limit exceeded for your free tier. Please upgrade to remove limits.",
                "cache_query": "Cache query limit exceeded for your free tier. Please upgrade to remove limits.",
                "agent": "Agent call limit exceeded for your free tier. Please upgrade to remove limits.",
            }
            fallback = "Limit exceeded for your free tier. Please upgrade to remove limits."
            raise HTTPException(status_code=429, detail=messages_by_type.get(limit_type, fallback))

    # Shared tail for all tiers: record the usage, but never fail the operation over it.
    try:
        await service.record_usage(auth.user_id, limit_type, value, document_id)
    except Exception as exc:
        logger.error(f"Failed to record usage: {exc}")
|