mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
Add hosted tier limits, cloud uri gen (#59)
This commit is contained in:
parent
f0c44cb8ea
commit
7eb5887d2f
177
core/api.py
177
core/api.py
@ -11,7 +11,8 @@ import logging
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
from core.completion.openai_completion import OpenAICompletionModel
|
||||
from core.embedding.ollama_embedding_model import OllamaEmbeddingModel
|
||||
from core.models.request import RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, BatchIngestResponse
|
||||
from core.limits_utils import check_and_increment_limits
|
||||
from core.models.request import GenerateUriRequest, RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, BatchIngestResponse
|
||||
from core.models.completion import ChunkSource, CompletionResponse
|
||||
from core.models.documents import Document, DocumentResult, ChunkResult
|
||||
from core.models.graph import Graph
|
||||
@ -236,6 +237,7 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext:
|
||||
entity_type=EntityType(settings.dev_entity_type),
|
||||
entity_id=settings.dev_entity_id,
|
||||
permissions=set(settings.dev_permissions),
|
||||
user_id=settings.dev_entity_id, # In dev mode, entity_id is also the user_id
|
||||
)
|
||||
|
||||
# Normal token verification flow
|
||||
@ -255,11 +257,17 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext:
|
||||
if datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC):
|
||||
raise HTTPException(status_code=401, detail="Token expired")
|
||||
|
||||
# Support both "type" and "entity_type" fields for compatibility
|
||||
entity_type_field = payload.get("type") or payload.get("entity_type")
|
||||
if not entity_type_field:
|
||||
raise HTTPException(status_code=401, detail="Missing entity type in token")
|
||||
|
||||
return AuthContext(
|
||||
entity_type=EntityType(payload["type"]),
|
||||
entity_type=EntityType(entity_type_field),
|
||||
entity_id=payload["entity_id"],
|
||||
app_id=payload.get("app_id"),
|
||||
permissions=set(payload.get("permissions", ["read"])),
|
||||
user_id=payload.get("user_id", payload["entity_id"]), # Use user_id if available, fallback to entity_id
|
||||
)
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
@ -569,6 +577,11 @@ async def query_completion(
|
||||
to enhance retrieval by finding relevant entities and their connected documents.
|
||||
"""
|
||||
try:
|
||||
# Check query limits if in cloud mode
|
||||
if settings.MODE == "cloud" and auth.user_id:
|
||||
# Check limits before proceeding
|
||||
await check_and_increment_limits(auth, "query", 1)
|
||||
|
||||
async with telemetry.track_operation(
|
||||
operation_type="query",
|
||||
user_id=auth.entity_id,
|
||||
@ -689,6 +702,7 @@ async def update_document_text(
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/documents/{document_id}/update_file", response_model=Document)
|
||||
async def update_document_file(
|
||||
document_id: str,
|
||||
@ -856,6 +870,11 @@ async def create_cache(
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a new cache with specified configuration."""
|
||||
try:
|
||||
# Check cache creation limits if in cloud mode
|
||||
if settings.MODE == "cloud" and auth.user_id:
|
||||
# Check limits before proceeding
|
||||
await check_and_increment_limits(auth, "cache", 1)
|
||||
|
||||
async with telemetry.track_operation(
|
||||
operation_type="create_cache",
|
||||
user_id=auth.entity_id,
|
||||
@ -955,6 +974,11 @@ async def query_cache(
|
||||
) -> CompletionResponse:
|
||||
"""Query the cache with a prompt."""
|
||||
try:
|
||||
# Check cache query limits if in cloud mode
|
||||
if settings.MODE == "cloud" and auth.user_id:
|
||||
# Check limits before proceeding
|
||||
await check_and_increment_limits(auth, "cache_query", 1)
|
||||
|
||||
async with telemetry.track_operation(
|
||||
operation_type="query_cache",
|
||||
user_id=auth.entity_id,
|
||||
@ -994,6 +1018,11 @@ async def create_graph(
|
||||
Graph: The created graph object
|
||||
"""
|
||||
try:
|
||||
# Check graph creation limits if in cloud mode
|
||||
if settings.MODE == "cloud" and auth.user_id:
|
||||
# Check limits before proceeding
|
||||
await check_and_increment_limits(auth, "graph", 1)
|
||||
|
||||
async with telemetry.track_operation(
|
||||
operation_type="create_graph",
|
||||
user_id=auth.entity_id,
|
||||
@ -1109,3 +1138,147 @@ async def generate_local_uri(
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating local URI: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/cloud/generate_uri", include_in_schema=True)
|
||||
async def generate_cloud_uri(
|
||||
request: GenerateUriRequest,
|
||||
authorization: str = Header(None),
|
||||
) -> Dict[str, str]:
|
||||
"""Generate a URI for cloud hosted applications."""
|
||||
try:
|
||||
app_id = request.app_id
|
||||
name = request.name
|
||||
user_id = request.user_id
|
||||
expiry_days = request.expiry_days
|
||||
|
||||
logger.info(f"Generating cloud URI for app_id={app_id}, name={name}, user_id={user_id}")
|
||||
|
||||
# Verify authorization header before proceeding
|
||||
if not authorization:
|
||||
logger.warning("Missing authorization header")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing authorization header",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Verify the token is valid
|
||||
if not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Invalid authorization header")
|
||||
|
||||
token = authorization[7:] # Remove "Bearer "
|
||||
|
||||
try:
|
||||
# Decode the token to ensure it's valid
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
|
||||
# Only allow users to create apps for themselves (or admin)
|
||||
token_user_id = payload.get("user_id")
|
||||
logger.debug(f"Token user ID: {token_user_id}")
|
||||
logger.debug(f"User ID: {user_id}")
|
||||
if not (token_user_id == user_id or "admin" in payload.get("permissions", [])):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You can only create apps for your own account unless you have admin permissions"
|
||||
)
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
|
||||
# Import UserService here to avoid circular imports
|
||||
from core.services.user_service import UserService
|
||||
user_service = UserService()
|
||||
|
||||
# Initialize user service if needed
|
||||
await user_service.initialize()
|
||||
|
||||
# Clean name
|
||||
name = name.replace(" ", "_").lower()
|
||||
|
||||
# Check if the user is within app limit and generate URI
|
||||
uri = await user_service.generate_cloud_uri(user_id, app_id, name, expiry_days)
|
||||
|
||||
if not uri:
|
||||
logger.info("Application limit reached for this account tier with user_id: %s", user_id)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Application limit reached for this account tier"
|
||||
)
|
||||
|
||||
return {"uri": uri, "app_id": app_id}
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating cloud URI: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/user/upgrade", include_in_schema=True)
|
||||
async def upgrade_user_tier(
|
||||
user_id: str,
|
||||
tier: str,
|
||||
custom_limits: Optional[Dict[str, Any]] = None,
|
||||
authorization: str = Header(None),
|
||||
) -> Dict[str, Any]:
|
||||
"""Upgrade a user to a higher tier."""
|
||||
try:
|
||||
# Verify admin authorization
|
||||
if not authorization:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing authorization header",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Invalid authorization header")
|
||||
|
||||
token = authorization[7:] # Remove "Bearer "
|
||||
|
||||
try:
|
||||
# Decode token
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
|
||||
# Only allow admins to upgrade users
|
||||
if "admin" not in payload.get("permissions", []):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Admin permission required"
|
||||
)
|
||||
except jwt.InvalidTokenError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
|
||||
# Validate tier
|
||||
from core.models.tiers import AccountTier
|
||||
try:
|
||||
account_tier = AccountTier(tier)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid tier: {tier}")
|
||||
|
||||
# Upgrade user
|
||||
from core.services.user_service import UserService
|
||||
user_service = UserService()
|
||||
|
||||
# Initialize user service
|
||||
await user_service.initialize()
|
||||
|
||||
# Update user tier
|
||||
success = await user_service.update_user_tier(user_id, tier, custom_limits)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="User not found or upgrade failed"
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"tier": tier,
|
||||
"message": f"User upgraded to {tier} tier"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error upgrading user tier: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
@ -91,6 +91,9 @@ class Settings(BaseSettings):
|
||||
# Colpali configuration
|
||||
ENABLE_COLPALI: bool
|
||||
|
||||
# Mode configuration
|
||||
MODE: Literal["cloud", "self_hosted"] = "cloud"
|
||||
|
||||
# Telemetry configuration
|
||||
TELEMETRY_ENABLED: bool = True
|
||||
HONEYCOMB_ENABLED: bool = True
|
||||
@ -291,6 +294,7 @@ def get_settings() -> Settings:
|
||||
# load databridge config
|
||||
databridge_config = {
|
||||
"ENABLE_COLPALI": config["databridge"]["enable_colpali"],
|
||||
"MODE": config["databridge"].get("mode", "cloud"), # Default to "cloud" mode
|
||||
}
|
||||
|
||||
# load graph config
|
||||
|
@ -494,6 +494,15 @@ class PostgresDatabase(BaseDatabase):
|
||||
if auth.entity_type == "DEVELOPER" and auth.app_id:
|
||||
# Add app-specific access for developers
|
||||
filters.append(f"access_control->'app_access' ? '{auth.app_id}'")
|
||||
|
||||
# Add user_id filter in cloud mode
|
||||
if auth.user_id:
|
||||
from core.config import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
if settings.MODE == "cloud":
|
||||
# Filter by user_id in access_control
|
||||
filters.append(f"access_control->>'user_id' = '{auth.user_id}'")
|
||||
|
||||
return " OR ".join(filters)
|
||||
|
||||
|
359
core/database/user_limits_db.py
Normal file
359
core/database/user_limits_db.py
Normal file
@ -0,0 +1,359 @@
|
||||
import json
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime, UTC, timedelta
|
||||
import logging
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
from sqlalchemy import Column, String, Index, select, text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class UserLimitsModel(Base):
|
||||
"""SQLAlchemy model for user limits data."""
|
||||
|
||||
__tablename__ = "user_limits"
|
||||
|
||||
user_id = Column(String, primary_key=True)
|
||||
tier = Column(String, nullable=False) # FREE, PRO, CUSTOM, SELF_HOSTED
|
||||
custom_limits = Column(JSONB, nullable=True)
|
||||
usage = Column(JSONB, default=dict) # Holds all usage counters
|
||||
app_ids = Column(JSONB, default=list) # List of app IDs registered by this user
|
||||
created_at = Column(String) # ISO format string
|
||||
updated_at = Column(String) # ISO format string
|
||||
|
||||
# Create indexes
|
||||
__table_args__ = (Index("idx_user_tier", "tier"),)
|
||||
|
||||
|
||||
class UserLimitsDatabase:
|
||||
"""Database operations for user limits."""
|
||||
|
||||
def __init__(self, uri: str):
|
||||
"""Initialize database connection."""
|
||||
self.engine = create_async_engine(uri)
|
||||
self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize database tables and indexes."""
|
||||
if self._initialized:
|
||||
return True
|
||||
|
||||
try:
|
||||
logger.info("Initializing user limits database tables...")
|
||||
# Create tables if they don't exist
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("User limits database tables initialized successfully")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize user limits database: {e}")
|
||||
return False
|
||||
|
||||
async def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get user limits for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to get limits for
|
||||
|
||||
Returns:
|
||||
Dict with user limits if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
|
||||
)
|
||||
user_limits = result.scalars().first()
|
||||
|
||||
if not user_limits:
|
||||
return None
|
||||
|
||||
return {
|
||||
"user_id": user_limits.user_id,
|
||||
"tier": user_limits.tier,
|
||||
"custom_limits": user_limits.custom_limits,
|
||||
"usage": user_limits.usage,
|
||||
"app_ids": user_limits.app_ids,
|
||||
"created_at": user_limits.created_at,
|
||||
"updated_at": user_limits.updated_at,
|
||||
}
|
||||
|
||||
async def create_user_limits(self, user_id: str, tier: str = "free") -> bool:
|
||||
"""
|
||||
Create user limits record.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
tier: Initial tier (defaults to "free")
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
now = datetime.now(UTC).isoformat()
|
||||
|
||||
async with self.async_session() as session:
|
||||
# Check if already exists
|
||||
result = await session.execute(
|
||||
select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
|
||||
)
|
||||
if result.scalars().first():
|
||||
return True # Already exists
|
||||
|
||||
# Create new record with properly initialized JSONB columns
|
||||
# Create JSON strings and parse them for consistency
|
||||
usage_json = json.dumps({
|
||||
"storage_file_count": 0,
|
||||
"storage_size_bytes": 0,
|
||||
"hourly_query_count": 0,
|
||||
"hourly_query_reset": now,
|
||||
"monthly_query_count": 0,
|
||||
"monthly_query_reset": now,
|
||||
"hourly_ingest_count": 0,
|
||||
"hourly_ingest_reset": now,
|
||||
"monthly_ingest_count": 0,
|
||||
"monthly_ingest_reset": now,
|
||||
"graph_count": 0,
|
||||
"cache_count": 0,
|
||||
})
|
||||
app_ids_json = json.dumps([]) # Empty array but as JSON string
|
||||
|
||||
# Create the model with the JSON parsed
|
||||
user_limits = UserLimitsModel(
|
||||
user_id=user_id,
|
||||
tier=tier,
|
||||
usage=json.loads(usage_json),
|
||||
app_ids=json.loads(app_ids_json),
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
session.add(user_limits)
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create user limits: {e}")
|
||||
return False
|
||||
|
||||
async def update_user_tier(
|
||||
self, user_id: str, tier: str, custom_limits: Optional[Dict[str, Any]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Update user tier and custom limits.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
tier: New tier
|
||||
custom_limits: Optional custom limits for CUSTOM tier
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
now = datetime.now(UTC).isoformat()
|
||||
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
|
||||
)
|
||||
user_limits = result.scalars().first()
|
||||
|
||||
if not user_limits:
|
||||
return False
|
||||
|
||||
user_limits.tier = tier
|
||||
user_limits.custom_limits = custom_limits
|
||||
user_limits.updated_at = now
|
||||
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update user tier: {e}")
|
||||
return False
|
||||
|
||||
async def register_app(self, user_id: str, app_id: str) -> bool:
|
||||
"""
|
||||
Register an app for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
app_id: The app ID to register
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
now = datetime.now(UTC).isoformat()
|
||||
|
||||
async with self.async_session() as session:
|
||||
# First check if user exists
|
||||
result = await session.execute(
|
||||
select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
|
||||
)
|
||||
user_limits = result.scalars().first()
|
||||
|
||||
if not user_limits:
|
||||
logger.error(f"User {user_id} not found in register_app")
|
||||
return False
|
||||
|
||||
# Use raw SQL with jsonb_array_append to update the app_ids array
|
||||
# This is the most reliable way to append to a JSONB array in PostgreSQL
|
||||
query = text(
|
||||
"""
|
||||
UPDATE user_limits
|
||||
SET
|
||||
app_ids = CASE
|
||||
WHEN NOT (app_ids ? :app_id) -- Check if app_id is not in the array
|
||||
THEN app_ids || :app_id_json -- Append it if not present
|
||||
ELSE app_ids -- Keep it unchanged if already present
|
||||
END,
|
||||
updated_at = :now
|
||||
WHERE user_id = :user_id
|
||||
RETURNING app_ids;
|
||||
"""
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
result = await session.execute(
|
||||
query,
|
||||
{
|
||||
"app_id": app_id, # For the check
|
||||
"app_id_json": f'["{app_id}"]', # JSON array format for appending
|
||||
"now": now,
|
||||
"user_id": user_id
|
||||
}
|
||||
)
|
||||
|
||||
# Log the result for debugging
|
||||
updated_app_ids = result.scalar()
|
||||
logger.info(f"Updated app_ids for user {user_id}: {updated_app_ids}")
|
||||
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register app: {e}")
|
||||
return False
|
||||
|
||||
async def update_usage(self, user_id: str, usage_type: str, increment: int = 1) -> bool:
|
||||
"""
|
||||
Update usage counter for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
usage_type: Type of usage to update
|
||||
increment: Value to increment by
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
now = datetime.now(UTC)
|
||||
now_iso = now.isoformat()
|
||||
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
|
||||
)
|
||||
user_limits = result.scalars().first()
|
||||
|
||||
if not user_limits:
|
||||
return False
|
||||
|
||||
# Create a new dictionary to force SQLAlchemy to detect the change
|
||||
usage = dict(user_limits.usage) if user_limits.usage else {}
|
||||
|
||||
# Handle different usage types
|
||||
if usage_type == "query":
|
||||
# Check hourly reset
|
||||
hourly_reset_str = usage.get("hourly_query_reset", "")
|
||||
if hourly_reset_str:
|
||||
hourly_reset = datetime.fromisoformat(hourly_reset_str)
|
||||
if now > hourly_reset + timedelta(hours=1):
|
||||
usage["hourly_query_count"] = increment
|
||||
usage["hourly_query_reset"] = now_iso
|
||||
else:
|
||||
usage["hourly_query_count"] = usage.get("hourly_query_count", 0) + increment
|
||||
else:
|
||||
usage["hourly_query_count"] = increment
|
||||
usage["hourly_query_reset"] = now_iso
|
||||
|
||||
# Check monthly reset
|
||||
monthly_reset_str = usage.get("monthly_query_reset", "")
|
||||
if monthly_reset_str:
|
||||
monthly_reset = datetime.fromisoformat(monthly_reset_str)
|
||||
if now > monthly_reset + timedelta(days=30):
|
||||
usage["monthly_query_count"] = increment
|
||||
usage["monthly_query_reset"] = now_iso
|
||||
else:
|
||||
usage["monthly_query_count"] = usage.get("monthly_query_count", 0) + increment
|
||||
else:
|
||||
usage["monthly_query_count"] = increment
|
||||
usage["monthly_query_reset"] = now_iso
|
||||
|
||||
elif usage_type == "ingest":
|
||||
# Similar pattern for ingest
|
||||
hourly_reset_str = usage.get("hourly_ingest_reset", "")
|
||||
if hourly_reset_str:
|
||||
hourly_reset = datetime.fromisoformat(hourly_reset_str)
|
||||
if now > hourly_reset + timedelta(hours=1):
|
||||
usage["hourly_ingest_count"] = increment
|
||||
usage["hourly_ingest_reset"] = now_iso
|
||||
else:
|
||||
usage["hourly_ingest_count"] = usage.get("hourly_ingest_count", 0) + increment
|
||||
else:
|
||||
usage["hourly_ingest_count"] = increment
|
||||
usage["hourly_ingest_reset"] = now_iso
|
||||
|
||||
monthly_reset_str = usage.get("monthly_ingest_reset", "")
|
||||
if monthly_reset_str:
|
||||
monthly_reset = datetime.fromisoformat(monthly_reset_str)
|
||||
if now > monthly_reset + timedelta(days=30):
|
||||
usage["monthly_ingest_count"] = increment
|
||||
usage["monthly_ingest_reset"] = now_iso
|
||||
else:
|
||||
usage["monthly_ingest_count"] = usage.get("monthly_ingest_count", 0) + increment
|
||||
else:
|
||||
usage["monthly_ingest_count"] = increment
|
||||
usage["monthly_ingest_reset"] = now_iso
|
||||
|
||||
elif usage_type == "storage_file":
|
||||
usage["storage_file_count"] = usage.get("storage_file_count", 0) + increment
|
||||
|
||||
elif usage_type == "storage_size":
|
||||
usage["storage_size_bytes"] = usage.get("storage_size_bytes", 0) + increment
|
||||
|
||||
elif usage_type == "graph":
|
||||
usage["graph_count"] = usage.get("graph_count", 0) + increment
|
||||
|
||||
elif usage_type == "cache":
|
||||
usage["cache_count"] = usage.get("cache_count", 0) + increment
|
||||
|
||||
# Force SQLAlchemy to recognize the change by assigning a new dict
|
||||
user_limits.usage = usage
|
||||
user_limits.updated_at = now_iso
|
||||
|
||||
# Explicitly mark as modified
|
||||
session.add(user_limits)
|
||||
|
||||
# Log the updated usage for debugging
|
||||
logger.info(f"Updated usage for user {user_id}, type: {usage_type}, value: {increment}")
|
||||
logger.info(f"New usage values: {usage}")
|
||||
logger.info(f"About to commit: user_id={user_id}, usage={user_limits.usage}")
|
||||
|
||||
# Commit and flush to ensure changes are written
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update usage: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
71
core/limits_utils.py
Normal file
71
core/limits_utils.py
Normal file
@ -0,0 +1,71 @@
|
||||
import logging
|
||||
|
||||
# Initialize logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def check_and_increment_limits(auth, limit_type: str, value: int = 1) -> None:
|
||||
"""
|
||||
Check if the user is within limits for an operation and increment usage.
|
||||
|
||||
Args:
|
||||
auth: Authentication context with user_id
|
||||
limit_type: Type of limit to check (query, ingest, storage_file, storage_size, graph, cache)
|
||||
value: Value to check against limit (e.g., file size for storage_size)
|
||||
|
||||
Raises:
|
||||
HTTPException: If the user exceeds limits
|
||||
"""
|
||||
from fastapi import HTTPException
|
||||
from core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# Skip limit checking in self-hosted mode
|
||||
if settings.MODE == "self_hosted":
|
||||
return
|
||||
|
||||
# Check if user_id is available
|
||||
if not auth.user_id:
|
||||
logger.warning("User ID not available in auth context, skipping limit check")
|
||||
return
|
||||
|
||||
# Initialize user service
|
||||
from core.services.user_service import UserService
|
||||
|
||||
user_service = UserService()
|
||||
await user_service.initialize()
|
||||
|
||||
# Check if user is within limits
|
||||
within_limits = await user_service.check_limit(auth.user_id, limit_type, value)
|
||||
|
||||
if not within_limits:
|
||||
# Get tier information for better error message
|
||||
user_data = await user_service.get_user_limits(auth.user_id)
|
||||
tier = user_data.get("tier", "unknown") if user_data else "unknown"
|
||||
|
||||
# Map limit types to appropriate error messages
|
||||
limit_type_messages = {
|
||||
"query": f"Query limit exceeded for your {tier} tier. Please upgrade or try again later.",
|
||||
"ingest": f"Ingest limit exceeded for your {tier} tier. Please upgrade or try again later.",
|
||||
"storage_file": f"Storage file count limit exceeded for your {tier} tier. Please delete some files or upgrade.",
|
||||
"storage_size": f"Storage size limit exceeded for your {tier} tier. Please delete some files or upgrade.",
|
||||
"graph": f"Graph creation limit exceeded for your {tier} tier. Please upgrade to create more graphs.",
|
||||
"cache": f"Cache creation limit exceeded for your {tier} tier. Please upgrade to create more caches.",
|
||||
"cache_query": f"Cache query limit exceeded for your {tier} tier. Please upgrade or try again later.",
|
||||
}
|
||||
|
||||
# Get message for the limit type or use default message
|
||||
detail = limit_type_messages.get(
|
||||
limit_type, f"Limit exceeded for your {tier} tier. Please upgrade or contact support."
|
||||
)
|
||||
|
||||
# Raise the exception with appropriate message
|
||||
raise HTTPException(status_code=429, detail=detail)
|
||||
|
||||
# Record usage asynchronously
|
||||
try:
|
||||
await user_service.record_usage(auth.user_id, limit_type, value)
|
||||
except Exception as e:
|
||||
# Just log if recording usage fails, don't fail the operation
|
||||
logger.error(f"Failed to record usage: {e}")
|
@ -16,3 +16,4 @@ class AuthContext(BaseModel):
|
||||
app_id: Optional[str] = None # uuid, only for developers
|
||||
# TODO: remove permissions, not required here.
|
||||
permissions: Set[str] = {"read"}
|
||||
user_id: Optional[str] = None # ID of the user who owns the app/entity
|
||||
|
@ -57,3 +57,11 @@ class BatchIngestResponse(BaseModel):
|
||||
"""Response model for batch ingestion"""
|
||||
documents: List[Document]
|
||||
errors: List[Dict[str, str]]
|
||||
|
||||
|
||||
class GenerateUriRequest(BaseModel):
|
||||
"""Request model for generating a cloud URI"""
|
||||
app_id: str = Field(..., description="ID of the application")
|
||||
name: str = Field(..., description="Name of the application")
|
||||
user_id: str = Field(..., description="ID of the user who owns the app")
|
||||
expiry_days: int = Field(default=30, description="Number of days until the token expires")
|
||||
|
117
core/models/tiers.py
Normal file
117
core/models/tiers.py
Normal file
@ -0,0 +1,117 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class AccountTier(str, Enum):
|
||||
"""Available account tiers."""
|
||||
|
||||
FREE = "free"
|
||||
PRO = "pro"
|
||||
CUSTOM = "custom"
|
||||
SELF_HOSTED = "self_hosted"
|
||||
|
||||
|
||||
# Tier limits definition - organized by API endpoint usage
|
||||
TIER_LIMITS = {
|
||||
AccountTier.FREE: {
|
||||
# Application limits
|
||||
"app_limit": 1, # Maximum number of applications
|
||||
# Storage limits
|
||||
"storage_file_limit": 100, # Maximum number of files in storage
|
||||
"storage_size_limit_gb": 1, # Maximum storage size in GB
|
||||
"hourly_ingest_limit": 30, # Maximum file/text ingests per hour
|
||||
"monthly_ingest_limit": 100, # Maximum file/text ingests per month
|
||||
# Query limits
|
||||
"hourly_query_limit": 30, # Maximum queries per hour
|
||||
"monthly_query_limit": 1000, # Maximum queries per month
|
||||
# Graph limits
|
||||
"graph_creation_limit": 3, # Maximum number of graphs
|
||||
"hourly_graph_query_limit": 20, # Maximum graph queries per hour
|
||||
"monthly_graph_query_limit": 200, # Maximum graph queries per month
|
||||
# Cache limits
|
||||
"cache_creation_limit": 0, # Maximum number of caches
|
||||
"hourly_cache_query_limit": 0, # Maximum cache queries per hour
|
||||
"monthly_cache_query_limit": 0, # Maximum cache queries per month
|
||||
},
|
||||
AccountTier.PRO: {
|
||||
# Application limits
|
||||
"app_limit": 5, # Maximum number of applications
|
||||
# Storage limits
|
||||
"storage_file_limit": 1000, # Maximum number of files in storage
|
||||
"storage_size_limit_gb": 10, # Maximum storage size in GB
|
||||
"hourly_ingest_limit": 100, # Maximum file/text ingests per hour
|
||||
"monthly_ingest_limit": 3000, # Maximum file/text ingests per month
|
||||
# Query limits
|
||||
"hourly_query_limit": 100, # Maximum queries per hour
|
||||
"monthly_query_limit": 10000, # Maximum queries per month
|
||||
# Graph limits
|
||||
"graph_creation_limit": 10, # Maximum number of graphs
|
||||
"hourly_graph_query_limit": 50, # Maximum graph queries per hour
|
||||
"monthly_graph_query_limit": 1000, # Maximum graph queries per month
|
||||
# Cache limits
|
||||
"cache_creation_limit": 5, # Maximum number of caches
|
||||
"hourly_cache_query_limit": 200, # Maximum cache queries per hour
|
||||
"monthly_cache_query_limit": 5000, # Maximum cache queries per month
|
||||
},
|
||||
AccountTier.CUSTOM: {
|
||||
# Custom tier limits are set on a per-account basis
|
||||
# These are default values that will be overridden
|
||||
# Application limits
|
||||
"app_limit": 10, # Maximum number of applications
|
||||
# Storage limits
|
||||
"storage_file_limit": 10000, # Maximum number of files in storage
|
||||
"storage_size_limit_gb": 100, # Maximum storage size in GB
|
||||
"hourly_ingest_limit": 500, # Maximum file/text ingests per hour
|
||||
"monthly_ingest_limit": 15000, # Maximum file/text ingests per month
|
||||
# Query limits
|
||||
"hourly_query_limit": 500, # Maximum queries per hour
|
||||
"monthly_query_limit": 50000, # Maximum queries per month
|
||||
# Graph limits
|
||||
"graph_creation_limit": 50, # Maximum number of graphs
|
||||
"hourly_graph_query_limit": 200, # Maximum graph queries per hour
|
||||
"monthly_graph_query_limit": 10000, # Maximum graph queries per month
|
||||
# Cache limits
|
||||
"cache_creation_limit": 20, # Maximum number of caches
|
||||
"hourly_cache_query_limit": 1000, # Maximum cache queries per hour
|
||||
"monthly_cache_query_limit": 50000, # Maximum cache queries per month
|
||||
},
|
||||
AccountTier.SELF_HOSTED: {
|
||||
# Self-hosted has no limits
|
||||
# Application limits
|
||||
"app_limit": float("inf"), # Maximum number of applications
|
||||
# Storage limits
|
||||
"storage_file_limit": float("inf"), # Maximum number of files in storage
|
||||
"storage_size_limit_gb": float("inf"), # Maximum storage size in GB
|
||||
"hourly_ingest_limit": float("inf"), # Maximum file/text ingests per hour
|
||||
"monthly_ingest_limit": float("inf"), # Maximum file/text ingests per month
|
||||
# Query limits
|
||||
"hourly_query_limit": float("inf"), # Maximum queries per hour
|
||||
"monthly_query_limit": float("inf"), # Maximum queries per month
|
||||
# Graph limits
|
||||
"graph_creation_limit": float("inf"), # Maximum number of graphs
|
||||
"hourly_graph_query_limit": float("inf"), # Maximum graph queries per hour
|
||||
"monthly_graph_query_limit": float("inf"), # Maximum graph queries per month
|
||||
# Cache limits
|
||||
"cache_creation_limit": float("inf"), # Maximum number of caches
|
||||
"hourly_cache_query_limit": float("inf"), # Maximum cache queries per hour
|
||||
"monthly_cache_query_limit": float("inf"), # Maximum cache queries per month
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_tier_limits(tier: AccountTier, custom_limits: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get limits for a specific account tier.
|
||||
|
||||
Args:
|
||||
tier: The account tier
|
||||
custom_limits: Optional custom limits for CUSTOM tier
|
||||
|
||||
Returns:
|
||||
Dict of limits for the specified tier
|
||||
"""
|
||||
if tier == AccountTier.CUSTOM and custom_limits:
|
||||
# Merge default custom limits with the provided custom limits
|
||||
return {**TIER_LIMITS[tier], **custom_limits}
|
||||
|
||||
return TIER_LIMITS[tier]
|
55
core/models/user_limits.py
Normal file
55
core/models/user_limits.py
Normal file
@ -0,0 +1,55 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, UTC
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .tiers import AccountTier
|
||||
|
||||
|
||||
class UserUsage(BaseModel):
|
||||
"""Tracks user's actual usage of the system."""
|
||||
|
||||
# Storage usage
|
||||
storage_file_count: int = 0
|
||||
storage_size_bytes: int = 0
|
||||
|
||||
# Query usage - hourly
|
||||
hourly_query_count: int = 0
|
||||
hourly_query_reset: Optional[datetime] = None
|
||||
|
||||
# Query usage - monthly
|
||||
monthly_query_count: int = 0
|
||||
monthly_query_reset: Optional[datetime] = None
|
||||
|
||||
# Ingest usage - hourly
|
||||
hourly_ingest_count: int = 0
|
||||
hourly_ingest_reset: Optional[datetime] = None
|
||||
|
||||
# Ingest usage - monthly
|
||||
monthly_ingest_count: int = 0
|
||||
monthly_ingest_reset: Optional[datetime] = None
|
||||
|
||||
# Graph usage
|
||||
graph_count: int = 0
|
||||
hourly_graph_query_count: int = 0
|
||||
hourly_graph_query_reset: Optional[datetime] = None
|
||||
monthly_graph_query_count: int = 0
|
||||
monthly_graph_query_reset: Optional[datetime] = None
|
||||
|
||||
# Cache usage
|
||||
cache_count: int = 0
|
||||
hourly_cache_query_count: int = 0
|
||||
hourly_cache_query_reset: Optional[datetime] = None
|
||||
monthly_cache_query_count: int = 0
|
||||
monthly_cache_query_reset: Optional[datetime] = None
|
||||
|
||||
|
||||
class UserLimits(BaseModel):
|
||||
"""Stores user tier and usage information."""
|
||||
|
||||
user_id: str
|
||||
tier: AccountTier = AccountTier.FREE
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
usage: UserUsage = Field(default_factory=UserUsage)
|
||||
custom_limits: Optional[Dict[str, Any]] = None
|
||||
app_ids: list[str] = Field(default_factory=list)
|
@ -329,6 +329,15 @@ class DocumentService:
|
||||
logger.error(f"User {auth.entity_id} does not have write permission")
|
||||
raise PermissionError("User does not have write permission")
|
||||
|
||||
# First check ingest limits if in cloud mode
|
||||
from core.config import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
if settings.MODE == "cloud" and auth.user_id:
|
||||
# Check limits before proceeding
|
||||
from core.api import check_and_increment_limits
|
||||
await check_and_increment_limits(auth, "ingest", 1)
|
||||
|
||||
doc = Document(
|
||||
content_type="text/plain",
|
||||
filename=filename,
|
||||
@ -338,6 +347,7 @@ class DocumentService:
|
||||
"readers": [auth.entity_id],
|
||||
"writers": [auth.entity_id],
|
||||
"admins": [auth.entity_id],
|
||||
"user_id": [auth.user_id] if auth.user_id else [], # Add user_id to access control for filtering (as a list)
|
||||
},
|
||||
)
|
||||
logger.info(f"Created text document record with ID {doc.external_id}")
|
||||
@ -404,6 +414,20 @@ class DocumentService:
|
||||
|
||||
# Read file content
|
||||
file_content = await file.read()
|
||||
file_size = len(file_content) # Get file size in bytes for limit checking
|
||||
|
||||
# Check limits before doing any expensive processing
|
||||
from core.config import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
if settings.MODE == "cloud" and auth.user_id:
|
||||
# Check limits before proceeding with parsing
|
||||
from core.api import check_and_increment_limits
|
||||
await check_and_increment_limits(auth, "ingest", 1)
|
||||
await check_and_increment_limits(auth, "storage_file", 1)
|
||||
await check_and_increment_limits(auth, "storage_size", file_size)
|
||||
|
||||
# Now proceed with parsing and processing the file
|
||||
file_type = filetype.guess(file_content)
|
||||
|
||||
# Set default mime type for cases where filetype.guess returns None
|
||||
@ -438,8 +462,7 @@ class DocumentService:
|
||||
if modified_text:
|
||||
text = modified_text
|
||||
logger.info("Updated text with modified content from rules")
|
||||
|
||||
# Create document record
|
||||
|
||||
doc = Document(
|
||||
content_type=mime_type,
|
||||
filename=file.filename,
|
||||
@ -449,6 +472,7 @@ class DocumentService:
|
||||
"readers": [auth.entity_id],
|
||||
"writers": [auth.entity_id],
|
||||
"admins": [auth.entity_id],
|
||||
"user_id": [auth.user_id] if auth.user_id else [], # Add user_id to access control for filtering (as a list)
|
||||
},
|
||||
additional_metadata=additional_metadata,
|
||||
)
|
||||
|
236
core/services/user_service.py
Normal file
236
core/services/user_service.py
Normal file
@ -0,0 +1,236 @@
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, UTC, timedelta
|
||||
import jwt
|
||||
|
||||
from ..models.tiers import AccountTier, get_tier_limits
|
||||
from ..models.auth import AuthContext
|
||||
from ..config import get_settings
|
||||
from ..database.user_limits_db import UserLimitsDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserService:
|
||||
"""Service for managing user limits and usage."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the UserService."""
|
||||
self.settings = get_settings()
|
||||
self.db = UserLimitsDatabase(uri=self.settings.POSTGRES_URI)
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""Initialize database tables."""
|
||||
return await self.db.initialize()
|
||||
|
||||
async def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get user limits information."""
|
||||
return await self.db.get_user_limits(user_id)
|
||||
|
||||
async def create_user(self, user_id: str) -> bool:
|
||||
"""Create a new user with FREE tier."""
|
||||
return await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
|
||||
|
||||
async def update_user_tier(
|
||||
self, user_id: str, tier: str, custom_limits: Optional[Dict[str, Any]] = None
|
||||
) -> bool:
|
||||
"""Update user tier and custom limits."""
|
||||
return await self.db.update_user_tier(user_id, tier, custom_limits)
|
||||
|
||||
async def register_app(self, user_id: str, app_id: str) -> bool:
|
||||
"""
|
||||
Register an app for a user.
|
||||
|
||||
Creates user limits record if it doesn't exist.
|
||||
"""
|
||||
# First check if user limits exist
|
||||
user_limits = await self.db.get_user_limits(user_id)
|
||||
|
||||
# If user limits don't exist, create them first
|
||||
if not user_limits:
|
||||
logger.info(f"Creating user limits for user {user_id}")
|
||||
success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
|
||||
if not success:
|
||||
logger.error(f"Failed to create user limits for user {user_id}")
|
||||
return False
|
||||
|
||||
# Now register the app
|
||||
return await self.db.register_app(user_id, app_id)
|
||||
|
||||
async def check_limit(self, user_id: str, limit_type: str, value: int = 1) -> bool:
|
||||
"""
|
||||
Check if a user's operation is within limits when given value is considered.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
limit_type: Type of limit (query, ingest, graph, cache, etc.)
|
||||
value: Value to check (e.g., file size for storage)
|
||||
|
||||
Returns:
|
||||
True if within limits, False if exceeded
|
||||
"""
|
||||
# Skip limit checking for self-hosted mode
|
||||
if self.settings.MODE == "self_hosted":
|
||||
return True
|
||||
|
||||
# Get user limits
|
||||
user_data = await self.db.get_user_limits(user_id)
|
||||
if not user_data:
|
||||
# Create user limits if they don't exist
|
||||
logger.info(f"User {user_id} not found when checking limits - creating limits record")
|
||||
success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
|
||||
if not success:
|
||||
logger.error(f"Failed to create user limits for user {user_id}")
|
||||
return False
|
||||
# Fetch the newly created limits
|
||||
user_data = await self.db.get_user_limits(user_id)
|
||||
if not user_data:
|
||||
logger.error(f"Failed to retrieve newly created user limits for user {user_id}")
|
||||
return False
|
||||
|
||||
# Get tier limits
|
||||
tier = user_data.get("tier", AccountTier.FREE)
|
||||
tier_limits = get_tier_limits(tier, user_data.get("custom_limits"))
|
||||
|
||||
# Get current usage
|
||||
usage = user_data.get("usage", {})
|
||||
|
||||
# Check specific limit type
|
||||
if limit_type == "query":
|
||||
hourly_limit = tier_limits.get("hourly_query_limit", 0)
|
||||
monthly_limit = tier_limits.get("monthly_query_limit", 0)
|
||||
|
||||
hourly_usage = usage.get("hourly_query_count", 0)
|
||||
monthly_usage = usage.get("monthly_query_count", 0)
|
||||
|
||||
return hourly_usage + value <= hourly_limit and monthly_usage + value <= monthly_limit
|
||||
|
||||
elif limit_type == "ingest":
|
||||
hourly_limit = tier_limits.get("hourly_ingest_limit", 0)
|
||||
monthly_limit = tier_limits.get("monthly_ingest_limit", 0)
|
||||
|
||||
hourly_usage = usage.get("hourly_ingest_count", 0)
|
||||
monthly_usage = usage.get("monthly_ingest_count", 0)
|
||||
|
||||
return hourly_usage + value <= hourly_limit and monthly_usage + value <= monthly_limit
|
||||
|
||||
elif limit_type == "storage_file":
|
||||
file_limit = tier_limits.get("storage_file_limit", 0)
|
||||
file_count = usage.get("storage_file_count", 0)
|
||||
|
||||
return file_count + value <= file_limit
|
||||
|
||||
elif limit_type == "storage_size":
|
||||
size_limit_bytes = tier_limits.get("storage_size_limit_gb", 0) * 1024 * 1024 * 1024
|
||||
size_usage = usage.get("storage_size_bytes", 0)
|
||||
|
||||
return size_usage + value <= size_limit_bytes
|
||||
|
||||
elif limit_type == "graph":
|
||||
graph_limit = tier_limits.get("graph_creation_limit", 0)
|
||||
graph_count = usage.get("graph_count", 0)
|
||||
|
||||
return graph_count + value <= graph_limit
|
||||
|
||||
elif limit_type == "cache":
|
||||
cache_limit = tier_limits.get("cache_creation_limit", 0)
|
||||
cache_count = usage.get("cache_count", 0)
|
||||
|
||||
return cache_count + value <= cache_limit
|
||||
|
||||
return True
|
||||
|
||||
async def record_usage(self, user_id: str, usage_type: str, increment: int = 1) -> bool:
|
||||
"""
|
||||
Record usage for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
usage_type: Type of usage (query, ingest, storage_file, storage_size, etc.)
|
||||
increment: Value to increment by
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
# Skip usage recording for self-hosted mode
|
||||
if self.settings.MODE == "self_hosted":
|
||||
return True
|
||||
|
||||
# Check if user limits exist, create if they don't
|
||||
user_data = await self.db.get_user_limits(user_id)
|
||||
if not user_data:
|
||||
logger.info(f"Creating user limits for user {user_id} during usage recording")
|
||||
success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
|
||||
if not success:
|
||||
logger.error(f"Failed to create user limits for user {user_id}")
|
||||
return False
|
||||
|
||||
return await self.db.update_usage(user_id, usage_type, increment)
|
||||
|
||||
async def generate_cloud_uri(
|
||||
self, user_id: str, app_id: str, name: str, expiry_days: int = 30
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Generate a cloud URI for an app.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
app_id: The app ID
|
||||
name: App name for display purposes
|
||||
expiry_days: Number of days until token expires
|
||||
|
||||
Returns:
|
||||
URI string with embedded token, or None if failed
|
||||
"""
|
||||
# Get user limits to check app limit
|
||||
user_limits = await self.db.get_user_limits(user_id)
|
||||
|
||||
# If user doesn't exist yet, create them
|
||||
if not user_limits:
|
||||
await self.create_user(user_id)
|
||||
user_limits = await self.db.get_user_limits(user_id)
|
||||
if not user_limits:
|
||||
logger.error(f"Failed to create user limits for user {user_id}")
|
||||
return None
|
||||
|
||||
# Get tier limits to enforce app limit
|
||||
tier = user_limits.get("tier", AccountTier.FREE)
|
||||
tier_limits = get_tier_limits(tier, user_limits.get("custom_limits"))
|
||||
app_limit = tier_limits.get("app_limit", 1) # Default to 1 if not specified
|
||||
|
||||
current_apps = user_limits.get("app_ids", [])
|
||||
|
||||
# Skip the limit check if app is already registered
|
||||
if app_id not in current_apps:
|
||||
# Check if user has reached app limit
|
||||
if len(current_apps) >= app_limit:
|
||||
logger.info(f"User {user_id} has reached app limit ({app_limit}) for tier {tier}")
|
||||
return None
|
||||
|
||||
# Register the app
|
||||
success = await self.register_app(user_id, app_id)
|
||||
if not success:
|
||||
logger.info(f"Failed to register app {app_id} for user {user_id}")
|
||||
return None
|
||||
|
||||
# Create token payload
|
||||
payload = {
|
||||
"user_id": user_id,
|
||||
"app_id": app_id,
|
||||
"name": name,
|
||||
"permissions": ["read", "write"],
|
||||
"exp": int((datetime.now(UTC) + timedelta(days=expiry_days)).timestamp()),
|
||||
"type": "developer",
|
||||
"entity_id": user_id,
|
||||
}
|
||||
|
||||
# Generate token
|
||||
token = jwt.encode(
|
||||
payload, self.settings.JWT_SECRET_KEY, algorithm=self.settings.JWT_ALGORITHM
|
||||
)
|
||||
|
||||
# Generate URI with API domain
|
||||
api_domain = getattr(self.settings, "API_DOMAIN", "api.databridge.ai")
|
||||
uri = f"databridge://{name}:{token}@{api_domain}/{app_id}"
|
||||
|
||||
return uri
|
@ -100,6 +100,7 @@ batch_size = 4096
|
||||
|
||||
[databridge]
|
||||
enable_colpali = true
|
||||
mode = "self_hosted" # "cloud" or "self_hosted"
|
||||
|
||||
[graph]
|
||||
provider = "ollama"
|
||||
|
Loading…
x
Reference in New Issue
Block a user