mirror of https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00

Add hosted tier limits, cloud uri gen (#59)

This commit is contained in:
parent f0c44cb8ea
commit 7eb5887d2f
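For orientation, here is a minimal client-side sketch of calling the /cloud/generate_uri endpoint added in this commit. The base URL and token value are assumptions for illustration; the request fields follow the GenerateUriRequest model and the response shape follows the handler below.

    import requests

    BASE_URL = "http://localhost:8000"  # hypothetical deployment URL
    TOKEN = "<bearer token whose user_id matches, or which carries admin permission>"

    resp = requests.post(
        f"{BASE_URL}/cloud/generate_uri",
        json={
            "app_id": "my-app-123",   # hypothetical app identifier
            "name": "My App",
            "user_id": "user-42",     # must match the token's user_id unless admin
            "expiry_days": 30,
        },
        headers={"Authorization": f"Bearer {TOKEN}"},
    )
    resp.raise_for_status()
    # e.g. {"uri": "databridge://my_app:<jwt>@api.databridge.ai/my-app-123", "app_id": "my-app-123"}
    print(resp.json())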
177  core/api.py

@@ -11,7 +11,8 @@ import logging
 from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
 from core.completion.openai_completion import OpenAICompletionModel
 from core.embedding.ollama_embedding_model import OllamaEmbeddingModel
-from core.models.request import RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, BatchIngestResponse
+from core.limits_utils import check_and_increment_limits
+from core.models.request import GenerateUriRequest, RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, BatchIngestResponse
 from core.models.completion import ChunkSource, CompletionResponse
 from core.models.documents import Document, DocumentResult, ChunkResult
 from core.models.graph import Graph
@@ -236,6 +237,7 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext:
             entity_type=EntityType(settings.dev_entity_type),
             entity_id=settings.dev_entity_id,
             permissions=set(settings.dev_permissions),
+            user_id=settings.dev_entity_id,  # In dev mode, entity_id is also the user_id
         )

     # Normal token verification flow
@@ -255,11 +257,17 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext:
         if datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC):
             raise HTTPException(status_code=401, detail="Token expired")

+        # Support both "type" and "entity_type" fields for compatibility
+        entity_type_field = payload.get("type") or payload.get("entity_type")
+        if not entity_type_field:
+            raise HTTPException(status_code=401, detail="Missing entity type in token")
+
         return AuthContext(
-            entity_type=EntityType(payload["type"]),
+            entity_type=EntityType(entity_type_field),
            entity_id=payload["entity_id"],
            app_id=payload.get("app_id"),
            permissions=set(payload.get("permissions", ["read"])),
+            user_id=payload.get("user_id", payload["entity_id"]),  # Use user_id if available, fallback to entity_id
        )
     except jwt.InvalidTokenError as e:
         raise HTTPException(status_code=401, detail=str(e))
@@ -569,6 +577,11 @@ async def query_completion(
     to enhance retrieval by finding relevant entities and their connected documents.
     """
     try:
+        # Check query limits if in cloud mode
+        if settings.MODE == "cloud" and auth.user_id:
+            # Check limits before proceeding
+            await check_and_increment_limits(auth, "query", 1)
+
         async with telemetry.track_operation(
             operation_type="query",
             user_id=auth.entity_id,
@@ -689,6 +702,7 @@ async def update_document_text(
     except PermissionError as e:
         raise HTTPException(status_code=403, detail=str(e))

+
 @app.post("/documents/{document_id}/update_file", response_model=Document)
 async def update_document_file(
     document_id: str,
@@ -856,6 +870,11 @@ async def create_cache(
 ) -> Dict[str, Any]:
     """Create a new cache with specified configuration."""
     try:
+        # Check cache creation limits if in cloud mode
+        if settings.MODE == "cloud" and auth.user_id:
+            # Check limits before proceeding
+            await check_and_increment_limits(auth, "cache", 1)
+
         async with telemetry.track_operation(
             operation_type="create_cache",
             user_id=auth.entity_id,
@@ -955,6 +974,11 @@ async def query_cache(
 ) -> CompletionResponse:
     """Query the cache with a prompt."""
     try:
+        # Check cache query limits if in cloud mode
+        if settings.MODE == "cloud" and auth.user_id:
+            # Check limits before proceeding
+            await check_and_increment_limits(auth, "cache_query", 1)
+
         async with telemetry.track_operation(
             operation_type="query_cache",
             user_id=auth.entity_id,
@@ -994,6 +1018,11 @@ async def create_graph(
         Graph: The created graph object
     """
     try:
+        # Check graph creation limits if in cloud mode
+        if settings.MODE == "cloud" and auth.user_id:
+            # Check limits before proceeding
+            await check_and_increment_limits(auth, "graph", 1)
+
         async with telemetry.track_operation(
             operation_type="create_graph",
             user_id=auth.entity_id,
@@ -1109,3 +1138,147 @@ async def generate_local_uri(
     except Exception as e:
         logger.error(f"Error generating local URI: {e}")
         raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/cloud/generate_uri", include_in_schema=True)
+async def generate_cloud_uri(
+    request: GenerateUriRequest,
+    authorization: str = Header(None),
+) -> Dict[str, str]:
+    """Generate a URI for cloud hosted applications."""
+    try:
+        app_id = request.app_id
+        name = request.name
+        user_id = request.user_id
+        expiry_days = request.expiry_days
+
+        logger.info(f"Generating cloud URI for app_id={app_id}, name={name}, user_id={user_id}")
+
+        # Verify authorization header before proceeding
+        if not authorization:
+            logger.warning("Missing authorization header")
+            raise HTTPException(
+                status_code=401,
+                detail="Missing authorization header",
+                headers={"WWW-Authenticate": "Bearer"},
+            )
+
+        # Verify the token is valid
+        if not authorization.startswith("Bearer "):
+            raise HTTPException(status_code=401, detail="Invalid authorization header")
+
+        token = authorization[7:]  # Remove "Bearer "
+
+        try:
+            # Decode the token to ensure it's valid
+            payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
+
+            # Only allow users to create apps for themselves (or admin)
+            token_user_id = payload.get("user_id")
+            logger.debug(f"Token user ID: {token_user_id}")
+            logger.debug(f"User ID: {user_id}")
+            if not (token_user_id == user_id or "admin" in payload.get("permissions", [])):
+                raise HTTPException(
+                    status_code=403,
+                    detail="You can only create apps for your own account unless you have admin permissions"
+                )
+        except jwt.InvalidTokenError as e:
+            raise HTTPException(status_code=401, detail=str(e))
+
+        # Import UserService here to avoid circular imports
+        from core.services.user_service import UserService
+        user_service = UserService()
+
+        # Initialize user service if needed
+        await user_service.initialize()
+
+        # Clean name
+        name = name.replace(" ", "_").lower()
+
+        # Check if the user is within app limit and generate URI
+        uri = await user_service.generate_cloud_uri(user_id, app_id, name, expiry_days)
+
+        if not uri:
+            logger.info("Application limit reached for this account tier with user_id: %s", user_id)
+            raise HTTPException(
+                status_code=403,
+                detail="Application limit reached for this account tier"
+            )
+
+        return {"uri": uri, "app_id": app_id}
+    except HTTPException:
+        # Re-raise HTTP exceptions
+        raise
+    except Exception as e:
+        logger.error(f"Error generating cloud URI: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/user/upgrade", include_in_schema=True)
+async def upgrade_user_tier(
+    user_id: str,
+    tier: str,
+    custom_limits: Optional[Dict[str, Any]] = None,
+    authorization: str = Header(None),
+) -> Dict[str, Any]:
+    """Upgrade a user to a higher tier."""
+    try:
+        # Verify admin authorization
+        if not authorization:
+            raise HTTPException(
+                status_code=401,
+                detail="Missing authorization header",
+                headers={"WWW-Authenticate": "Bearer"},
+            )
+
+        if not authorization.startswith("Bearer "):
+            raise HTTPException(status_code=401, detail="Invalid authorization header")
+
+        token = authorization[7:]  # Remove "Bearer "
+
+        try:
+            # Decode token
+            payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
+
+            # Only allow admins to upgrade users
+            if "admin" not in payload.get("permissions", []):
+                raise HTTPException(
+                    status_code=403,
+                    detail="Admin permission required"
+                )
+        except jwt.InvalidTokenError as e:
+            raise HTTPException(status_code=401, detail=str(e))
+
+        # Validate tier
+        from core.models.tiers import AccountTier
+        try:
+            account_tier = AccountTier(tier)
+        except ValueError:
+            raise HTTPException(status_code=400, detail=f"Invalid tier: {tier}")
+
+        # Upgrade user
+        from core.services.user_service import UserService
+        user_service = UserService()
+
+        # Initialize user service
+        await user_service.initialize()
+
+        # Update user tier
+        success = await user_service.update_user_tier(user_id, tier, custom_limits)
+
+        if not success:
+            raise HTTPException(
+                status_code=404,
+                detail="User not found or upgrade failed"
+            )
+
+        return {
+            "user_id": user_id,
+            "tier": tier,
+            "message": f"User upgraded to {tier} tier"
+        }
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error upgrading user tier: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
@@ -91,6 +91,9 @@ class Settings(BaseSettings):
     # Colpali configuration
     ENABLE_COLPALI: bool

+    # Mode configuration
+    MODE: Literal["cloud", "self_hosted"] = "cloud"
+
     # Telemetry configuration
     TELEMETRY_ENABLED: bool = True
     HONEYCOMB_ENABLED: bool = True
@@ -291,6 +294,7 @@ def get_settings() -> Settings:
     # load databridge config
     databridge_config = {
         "ENABLE_COLPALI": config["databridge"]["enable_colpali"],
+        "MODE": config["databridge"].get("mode", "cloud"),  # Default to "cloud" mode
     }

     # load graph config
@@ -494,6 +494,15 @@ class PostgresDatabase(BaseDatabase):
         if auth.entity_type == "DEVELOPER" and auth.app_id:
             # Add app-specific access for developers
             filters.append(f"access_control->'app_access' ? '{auth.app_id}'")

+        # Add user_id filter in cloud mode
+        if auth.user_id:
+            from core.config import get_settings
+            settings = get_settings()
+
+            if settings.MODE == "cloud":
+                # Filter by user_id in access_control
+                filters.append(f"access_control->>'user_id' = '{auth.user_id}'")
+
         return " OR ".join(filters)

359  core/database/user_limits_db.py  (new file)

@@ -0,0 +1,359 @@
import json
from typing import List, Optional, Dict, Any
from datetime import datetime, UTC, timedelta
import logging
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import Column, String, Index, select, text
from sqlalchemy.dialects.postgresql import JSONB

logger = logging.getLogger(__name__)
Base = declarative_base()


class UserLimitsModel(Base):
    """SQLAlchemy model for user limits data."""

    __tablename__ = "user_limits"

    user_id = Column(String, primary_key=True)
    tier = Column(String, nullable=False)  # FREE, PRO, CUSTOM, SELF_HOSTED
    custom_limits = Column(JSONB, nullable=True)
    usage = Column(JSONB, default=dict)  # Holds all usage counters
    app_ids = Column(JSONB, default=list)  # List of app IDs registered by this user
    created_at = Column(String)  # ISO format string
    updated_at = Column(String)  # ISO format string

    # Create indexes
    __table_args__ = (Index("idx_user_tier", "tier"),)


class UserLimitsDatabase:
    """Database operations for user limits."""

    def __init__(self, uri: str):
        """Initialize database connection."""
        self.engine = create_async_engine(uri)
        self.async_session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
        self._initialized = False

    async def initialize(self) -> bool:
        """Initialize database tables and indexes."""
        if self._initialized:
            return True

        try:
            logger.info("Initializing user limits database tables...")
            # Create tables if they don't exist
            async with self.engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)

            self._initialized = True
            logger.info("User limits database tables initialized successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to initialize user limits database: {e}")
            return False

    async def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
        """
        Get user limits for a user.

        Args:
            user_id: The user ID to get limits for

        Returns:
            Dict with user limits if found, None otherwise
        """
        async with self.async_session() as session:
            result = await session.execute(
                select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
            )
            user_limits = result.scalars().first()

            if not user_limits:
                return None

            return {
                "user_id": user_limits.user_id,
                "tier": user_limits.tier,
                "custom_limits": user_limits.custom_limits,
                "usage": user_limits.usage,
                "app_ids": user_limits.app_ids,
                "created_at": user_limits.created_at,
                "updated_at": user_limits.updated_at,
            }

    async def create_user_limits(self, user_id: str, tier: str = "free") -> bool:
        """
        Create user limits record.

        Args:
            user_id: The user ID
            tier: Initial tier (defaults to "free")

        Returns:
            True if successful, False otherwise
        """
        try:
            now = datetime.now(UTC).isoformat()

            async with self.async_session() as session:
                # Check if already exists
                result = await session.execute(
                    select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
                )
                if result.scalars().first():
                    return True  # Already exists

                # Create new record with properly initialized JSONB columns
                # Create JSON strings and parse them for consistency
                usage_json = json.dumps({
                    "storage_file_count": 0,
                    "storage_size_bytes": 0,
                    "hourly_query_count": 0,
                    "hourly_query_reset": now,
                    "monthly_query_count": 0,
                    "monthly_query_reset": now,
                    "hourly_ingest_count": 0,
                    "hourly_ingest_reset": now,
                    "monthly_ingest_count": 0,
                    "monthly_ingest_reset": now,
                    "graph_count": 0,
                    "cache_count": 0,
                })
                app_ids_json = json.dumps([])  # Empty array but as JSON string

                # Create the model with the JSON parsed
                user_limits = UserLimitsModel(
                    user_id=user_id,
                    tier=tier,
                    usage=json.loads(usage_json),
                    app_ids=json.loads(app_ids_json),
                    created_at=now,
                    updated_at=now,
                )

                session.add(user_limits)
                await session.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to create user limits: {e}")
            return False

    async def update_user_tier(
        self, user_id: str, tier: str, custom_limits: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update user tier and custom limits.

        Args:
            user_id: The user ID
            tier: New tier
            custom_limits: Optional custom limits for CUSTOM tier

        Returns:
            True if successful, False otherwise
        """
        try:
            now = datetime.now(UTC).isoformat()

            async with self.async_session() as session:
                result = await session.execute(
                    select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
                )
                user_limits = result.scalars().first()

                if not user_limits:
                    return False

                user_limits.tier = tier
                user_limits.custom_limits = custom_limits
                user_limits.updated_at = now

                await session.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to update user tier: {e}")
            return False

    async def register_app(self, user_id: str, app_id: str) -> bool:
        """
        Register an app for a user.

        Args:
            user_id: The user ID
            app_id: The app ID to register

        Returns:
            True if successful, False otherwise
        """
        try:
            now = datetime.now(UTC).isoformat()

            async with self.async_session() as session:
                # First check if user exists
                result = await session.execute(
                    select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
                )
                user_limits = result.scalars().first()

                if not user_limits:
                    logger.error(f"User {user_id} not found in register_app")
                    return False

                # Use raw SQL with jsonb_array_append to update the app_ids array
                # This is the most reliable way to append to a JSONB array in PostgreSQL
                query = text(
                    """
                    UPDATE user_limits
                    SET
                        app_ids = CASE
                            WHEN NOT (app_ids ? :app_id)   -- Check if app_id is not in the array
                            THEN app_ids || :app_id_json   -- Append it if not present
                            ELSE app_ids                   -- Keep it unchanged if already present
                        END,
                        updated_at = :now
                    WHERE user_id = :user_id
                    RETURNING app_ids;
                    """
                )

                # Execute the query
                result = await session.execute(
                    query,
                    {
                        "app_id": app_id,  # For the check
                        "app_id_json": f'["{app_id}"]',  # JSON array format for appending
                        "now": now,
                        "user_id": user_id
                    }
                )

                # Log the result for debugging
                updated_app_ids = result.scalar()
                logger.info(f"Updated app_ids for user {user_id}: {updated_app_ids}")

                await session.commit()
                return True

        except Exception as e:
            logger.error(f"Failed to register app: {e}")
            return False

    async def update_usage(self, user_id: str, usage_type: str, increment: int = 1) -> bool:
        """
        Update usage counter for a user.

        Args:
            user_id: The user ID
            usage_type: Type of usage to update
            increment: Value to increment by

        Returns:
            True if successful, False otherwise
        """
        try:
            now = datetime.now(UTC)
            now_iso = now.isoformat()

            async with self.async_session() as session:
                result = await session.execute(
                    select(UserLimitsModel).where(UserLimitsModel.user_id == user_id)
                )
                user_limits = result.scalars().first()

                if not user_limits:
                    return False

                # Create a new dictionary to force SQLAlchemy to detect the change
                usage = dict(user_limits.usage) if user_limits.usage else {}

                # Handle different usage types
                if usage_type == "query":
                    # Check hourly reset
                    hourly_reset_str = usage.get("hourly_query_reset", "")
                    if hourly_reset_str:
                        hourly_reset = datetime.fromisoformat(hourly_reset_str)
                        if now > hourly_reset + timedelta(hours=1):
                            usage["hourly_query_count"] = increment
                            usage["hourly_query_reset"] = now_iso
                        else:
                            usage["hourly_query_count"] = usage.get("hourly_query_count", 0) + increment
                    else:
                        usage["hourly_query_count"] = increment
                        usage["hourly_query_reset"] = now_iso

                    # Check monthly reset
                    monthly_reset_str = usage.get("monthly_query_reset", "")
                    if monthly_reset_str:
                        monthly_reset = datetime.fromisoformat(monthly_reset_str)
                        if now > monthly_reset + timedelta(days=30):
                            usage["monthly_query_count"] = increment
                            usage["monthly_query_reset"] = now_iso
                        else:
                            usage["monthly_query_count"] = usage.get("monthly_query_count", 0) + increment
                    else:
                        usage["monthly_query_count"] = increment
                        usage["monthly_query_reset"] = now_iso

                elif usage_type == "ingest":
                    # Similar pattern for ingest
                    hourly_reset_str = usage.get("hourly_ingest_reset", "")
                    if hourly_reset_str:
                        hourly_reset = datetime.fromisoformat(hourly_reset_str)
                        if now > hourly_reset + timedelta(hours=1):
                            usage["hourly_ingest_count"] = increment
                            usage["hourly_ingest_reset"] = now_iso
                        else:
                            usage["hourly_ingest_count"] = usage.get("hourly_ingest_count", 0) + increment
                    else:
                        usage["hourly_ingest_count"] = increment
                        usage["hourly_ingest_reset"] = now_iso

                    monthly_reset_str = usage.get("monthly_ingest_reset", "")
                    if monthly_reset_str:
                        monthly_reset = datetime.fromisoformat(monthly_reset_str)
                        if now > monthly_reset + timedelta(days=30):
                            usage["monthly_ingest_count"] = increment
                            usage["monthly_ingest_reset"] = now_iso
                        else:
                            usage["monthly_ingest_count"] = usage.get("monthly_ingest_count", 0) + increment
                    else:
                        usage["monthly_ingest_count"] = increment
                        usage["monthly_ingest_reset"] = now_iso

                elif usage_type == "storage_file":
                    usage["storage_file_count"] = usage.get("storage_file_count", 0) + increment

                elif usage_type == "storage_size":
                    usage["storage_size_bytes"] = usage.get("storage_size_bytes", 0) + increment

                elif usage_type == "graph":
                    usage["graph_count"] = usage.get("graph_count", 0) + increment

                elif usage_type == "cache":
                    usage["cache_count"] = usage.get("cache_count", 0) + increment

                # Force SQLAlchemy to recognize the change by assigning a new dict
                user_limits.usage = usage
                user_limits.updated_at = now_iso

                # Explicitly mark as modified
                session.add(user_limits)

                # Log the updated usage for debugging
                logger.info(f"Updated usage for user {user_id}, type: {usage_type}, value: {increment}")
                logger.info(f"New usage values: {usage}")
                logger.info(f"About to commit: user_id={user_id}, usage={user_limits.usage}")

                # Commit and flush to ensure changes are written
                await session.commit()

                return True

        except Exception as e:
            logger.error(f"Failed to update usage: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return False
71  core/limits_utils.py  (new file)

@@ -0,0 +1,71 @@
import logging

# Initialize logger
logger = logging.getLogger(__name__)


async def check_and_increment_limits(auth, limit_type: str, value: int = 1) -> None:
    """
    Check if the user is within limits for an operation and increment usage.

    Args:
        auth: Authentication context with user_id
        limit_type: Type of limit to check (query, ingest, storage_file, storage_size, graph, cache)
        value: Value to check against limit (e.g., file size for storage_size)

    Raises:
        HTTPException: If the user exceeds limits
    """
    from fastapi import HTTPException
    from core.config import get_settings

    settings = get_settings()

    # Skip limit checking in self-hosted mode
    if settings.MODE == "self_hosted":
        return

    # Check if user_id is available
    if not auth.user_id:
        logger.warning("User ID not available in auth context, skipping limit check")
        return

    # Initialize user service
    from core.services.user_service import UserService

    user_service = UserService()
    await user_service.initialize()

    # Check if user is within limits
    within_limits = await user_service.check_limit(auth.user_id, limit_type, value)

    if not within_limits:
        # Get tier information for better error message
        user_data = await user_service.get_user_limits(auth.user_id)
        tier = user_data.get("tier", "unknown") if user_data else "unknown"

        # Map limit types to appropriate error messages
        limit_type_messages = {
            "query": f"Query limit exceeded for your {tier} tier. Please upgrade or try again later.",
            "ingest": f"Ingest limit exceeded for your {tier} tier. Please upgrade or try again later.",
            "storage_file": f"Storage file count limit exceeded for your {tier} tier. Please delete some files or upgrade.",
            "storage_size": f"Storage size limit exceeded for your {tier} tier. Please delete some files or upgrade.",
            "graph": f"Graph creation limit exceeded for your {tier} tier. Please upgrade to create more graphs.",
            "cache": f"Cache creation limit exceeded for your {tier} tier. Please upgrade to create more caches.",
            "cache_query": f"Cache query limit exceeded for your {tier} tier. Please upgrade or try again later.",
        }

        # Get message for the limit type or use default message
        detail = limit_type_messages.get(
            limit_type, f"Limit exceeded for your {tier} tier. Please upgrade or contact support."
        )

        # Raise the exception with appropriate message
        raise HTTPException(status_code=429, detail=detail)

    # Record usage asynchronously
    try:
        await user_service.record_usage(auth.user_id, limit_type, value)
    except Exception as e:
        # Just log if recording usage fails, don't fail the operation
        logger.error(f"Failed to record usage: {e}")
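For context, a minimal sketch of the calling pattern this helper expects; the wrapper function below is illustrative only, while the real call sites are the core/api.py and document_service.py hunks in this commit.

    from core.config import get_settings
    from core.limits_utils import check_and_increment_limits


    async def guarded_operation(auth):
        """Illustrative wrapper showing the gating pattern used by the changed endpoints."""
        settings = get_settings()
        # Enforce limits only in cloud mode and only when the token carries a user_id.
        if settings.MODE == "cloud" and auth.user_id:
            # Raises fastapi.HTTPException(status_code=429, ...) when the tier limit is
            # exceeded; otherwise records the usage and returns None.
            await check_and_increment_limits(auth, "query", 1)
        # ... proceed with the actual work ...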
@@ -16,3 +16,4 @@ class AuthContext(BaseModel):
     app_id: Optional[str] = None  # uuid, only for developers
     # TODO: remove permissions, not required here.
     permissions: Set[str] = {"read"}
+    user_id: Optional[str] = None  # ID of the user who owns the app/entity
@@ -57,3 +57,11 @@ class BatchIngestResponse(BaseModel):
     """Response model for batch ingestion"""
     documents: List[Document]
     errors: List[Dict[str, str]]
+
+
+class GenerateUriRequest(BaseModel):
+    """Request model for generating a cloud URI"""
+    app_id: str = Field(..., description="ID of the application")
+    name: str = Field(..., description="Name of the application")
+    user_id: str = Field(..., description="ID of the user who owns the app")
+    expiry_days: int = Field(default=30, description="Number of days until the token expires")
117  core/models/tiers.py  (new file)

@@ -0,0 +1,117 @@
from enum import Enum
from typing import Dict, Any


class AccountTier(str, Enum):
    """Available account tiers."""

    FREE = "free"
    PRO = "pro"
    CUSTOM = "custom"
    SELF_HOSTED = "self_hosted"


# Tier limits definition - organized by API endpoint usage
TIER_LIMITS = {
    AccountTier.FREE: {
        # Application limits
        "app_limit": 1,  # Maximum number of applications
        # Storage limits
        "storage_file_limit": 100,  # Maximum number of files in storage
        "storage_size_limit_gb": 1,  # Maximum storage size in GB
        "hourly_ingest_limit": 30,  # Maximum file/text ingests per hour
        "monthly_ingest_limit": 100,  # Maximum file/text ingests per month
        # Query limits
        "hourly_query_limit": 30,  # Maximum queries per hour
        "monthly_query_limit": 1000,  # Maximum queries per month
        # Graph limits
        "graph_creation_limit": 3,  # Maximum number of graphs
        "hourly_graph_query_limit": 20,  # Maximum graph queries per hour
        "monthly_graph_query_limit": 200,  # Maximum graph queries per month
        # Cache limits
        "cache_creation_limit": 0,  # Maximum number of caches
        "hourly_cache_query_limit": 0,  # Maximum cache queries per hour
        "monthly_cache_query_limit": 0,  # Maximum cache queries per month
    },
    AccountTier.PRO: {
        # Application limits
        "app_limit": 5,  # Maximum number of applications
        # Storage limits
        "storage_file_limit": 1000,  # Maximum number of files in storage
        "storage_size_limit_gb": 10,  # Maximum storage size in GB
        "hourly_ingest_limit": 100,  # Maximum file/text ingests per hour
        "monthly_ingest_limit": 3000,  # Maximum file/text ingests per month
        # Query limits
        "hourly_query_limit": 100,  # Maximum queries per hour
        "monthly_query_limit": 10000,  # Maximum queries per month
        # Graph limits
        "graph_creation_limit": 10,  # Maximum number of graphs
        "hourly_graph_query_limit": 50,  # Maximum graph queries per hour
        "monthly_graph_query_limit": 1000,  # Maximum graph queries per month
        # Cache limits
        "cache_creation_limit": 5,  # Maximum number of caches
        "hourly_cache_query_limit": 200,  # Maximum cache queries per hour
        "monthly_cache_query_limit": 5000,  # Maximum cache queries per month
    },
    AccountTier.CUSTOM: {
        # Custom tier limits are set on a per-account basis
        # These are default values that will be overridden
        # Application limits
        "app_limit": 10,  # Maximum number of applications
        # Storage limits
        "storage_file_limit": 10000,  # Maximum number of files in storage
        "storage_size_limit_gb": 100,  # Maximum storage size in GB
        "hourly_ingest_limit": 500,  # Maximum file/text ingests per hour
        "monthly_ingest_limit": 15000,  # Maximum file/text ingests per month
        # Query limits
        "hourly_query_limit": 500,  # Maximum queries per hour
        "monthly_query_limit": 50000,  # Maximum queries per month
        # Graph limits
        "graph_creation_limit": 50,  # Maximum number of graphs
        "hourly_graph_query_limit": 200,  # Maximum graph queries per hour
        "monthly_graph_query_limit": 10000,  # Maximum graph queries per month
        # Cache limits
        "cache_creation_limit": 20,  # Maximum number of caches
        "hourly_cache_query_limit": 1000,  # Maximum cache queries per hour
        "monthly_cache_query_limit": 50000,  # Maximum cache queries per month
    },
    AccountTier.SELF_HOSTED: {
        # Self-hosted has no limits
        # Application limits
        "app_limit": float("inf"),  # Maximum number of applications
        # Storage limits
        "storage_file_limit": float("inf"),  # Maximum number of files in storage
        "storage_size_limit_gb": float("inf"),  # Maximum storage size in GB
        "hourly_ingest_limit": float("inf"),  # Maximum file/text ingests per hour
        "monthly_ingest_limit": float("inf"),  # Maximum file/text ingests per month
        # Query limits
        "hourly_query_limit": float("inf"),  # Maximum queries per hour
        "monthly_query_limit": float("inf"),  # Maximum queries per month
        # Graph limits
        "graph_creation_limit": float("inf"),  # Maximum number of graphs
        "hourly_graph_query_limit": float("inf"),  # Maximum graph queries per hour
        "monthly_graph_query_limit": float("inf"),  # Maximum graph queries per month
        # Cache limits
        "cache_creation_limit": float("inf"),  # Maximum number of caches
        "hourly_cache_query_limit": float("inf"),  # Maximum cache queries per hour
        "monthly_cache_query_limit": float("inf"),  # Maximum cache queries per month
    },
}


def get_tier_limits(tier: AccountTier, custom_limits: Dict[str, Any] = None) -> Dict[str, Any]:
    """
    Get limits for a specific account tier.

    Args:
        tier: The account tier
        custom_limits: Optional custom limits for CUSTOM tier

    Returns:
        Dict of limits for the specified tier
    """
    if tier == AccountTier.CUSTOM and custom_limits:
        # Merge default custom limits with the provided custom limits
        return {**TIER_LIMITS[tier], **custom_limits}

    return TIER_LIMITS[tier]
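As a quick illustration of how get_tier_limits merges per-account overrides for the CUSTOM tier (the override values below are hypothetical, not defaults shipped by this commit):

    from core.models.tiers import AccountTier, get_tier_limits

    # Hypothetical per-account overrides for a CUSTOM-tier user.
    overrides = {"app_limit": 25, "monthly_query_limit": 250_000}

    limits = get_tier_limits(AccountTier.CUSTOM, overrides)
    print(limits["app_limit"])             # 25     (override wins over the CUSTOM default of 10)
    print(limits["monthly_query_limit"])   # 250000
    print(limits["graph_creation_limit"])  # 50     (falls back to the CUSTOM default)

    # Non-CUSTOM tiers ignore custom_limits entirely:
    print(get_tier_limits(AccountTier.FREE)["app_limit"])  # 1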
55  core/models/user_limits.py  (new file)

@@ -0,0 +1,55 @@
from typing import Dict, Any, Optional
from datetime import datetime, UTC
from pydantic import BaseModel, Field

from .tiers import AccountTier


class UserUsage(BaseModel):
    """Tracks user's actual usage of the system."""

    # Storage usage
    storage_file_count: int = 0
    storage_size_bytes: int = 0

    # Query usage - hourly
    hourly_query_count: int = 0
    hourly_query_reset: Optional[datetime] = None

    # Query usage - monthly
    monthly_query_count: int = 0
    monthly_query_reset: Optional[datetime] = None

    # Ingest usage - hourly
    hourly_ingest_count: int = 0
    hourly_ingest_reset: Optional[datetime] = None

    # Ingest usage - monthly
    monthly_ingest_count: int = 0
    monthly_ingest_reset: Optional[datetime] = None

    # Graph usage
    graph_count: int = 0
    hourly_graph_query_count: int = 0
    hourly_graph_query_reset: Optional[datetime] = None
    monthly_graph_query_count: int = 0
    monthly_graph_query_reset: Optional[datetime] = None

    # Cache usage
    cache_count: int = 0
    hourly_cache_query_count: int = 0
    hourly_cache_query_reset: Optional[datetime] = None
    monthly_cache_query_count: int = 0
    monthly_cache_query_reset: Optional[datetime] = None


class UserLimits(BaseModel):
    """Stores user tier and usage information."""

    user_id: str
    tier: AccountTier = AccountTier.FREE
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    usage: UserUsage = Field(default_factory=UserUsage)
    custom_limits: Optional[Dict[str, Any]] = None
    app_ids: list[str] = Field(default_factory=list)
@@ -329,6 +329,15 @@ class DocumentService:
             logger.error(f"User {auth.entity_id} does not have write permission")
             raise PermissionError("User does not have write permission")

+        # First check ingest limits if in cloud mode
+        from core.config import get_settings
+        settings = get_settings()
+
+        if settings.MODE == "cloud" and auth.user_id:
+            # Check limits before proceeding
+            from core.api import check_and_increment_limits
+            await check_and_increment_limits(auth, "ingest", 1)
+
         doc = Document(
             content_type="text/plain",
             filename=filename,
@@ -338,6 +347,7 @@ class DocumentService:
                 "readers": [auth.entity_id],
                 "writers": [auth.entity_id],
                 "admins": [auth.entity_id],
+                "user_id": [auth.user_id] if auth.user_id else [],  # Add user_id to access control for filtering (as a list)
             },
         )
         logger.info(f"Created text document record with ID {doc.external_id}")
@@ -404,6 +414,20 @@ class DocumentService:

         # Read file content
         file_content = await file.read()
+        file_size = len(file_content)  # Get file size in bytes for limit checking
+
+        # Check limits before doing any expensive processing
+        from core.config import get_settings
+        settings = get_settings()
+
+        if settings.MODE == "cloud" and auth.user_id:
+            # Check limits before proceeding with parsing
+            from core.api import check_and_increment_limits
+            await check_and_increment_limits(auth, "ingest", 1)
+            await check_and_increment_limits(auth, "storage_file", 1)
+            await check_and_increment_limits(auth, "storage_size", file_size)
+
+        # Now proceed with parsing and processing the file
         file_type = filetype.guess(file_content)

         # Set default mime type for cases where filetype.guess returns None
@@ -438,8 +462,7 @@ class DocumentService:
         if modified_text:
             text = modified_text
             logger.info("Updated text with modified content from rules")

-        # Create document record
         doc = Document(
             content_type=mime_type,
             filename=file.filename,
@@ -449,6 +472,7 @@ class DocumentService:
                 "readers": [auth.entity_id],
                 "writers": [auth.entity_id],
                 "admins": [auth.entity_id],
+                "user_id": [auth.user_id] if auth.user_id else [],  # Add user_id to access control for filtering (as a list)
             },
             additional_metadata=additional_metadata,
         )
236  core/services/user_service.py  (new file)

@@ -0,0 +1,236 @@
import logging
from typing import Dict, Any, Optional
from datetime import datetime, UTC, timedelta
import jwt

from ..models.tiers import AccountTier, get_tier_limits
from ..models.auth import AuthContext
from ..config import get_settings
from ..database.user_limits_db import UserLimitsDatabase

logger = logging.getLogger(__name__)


class UserService:
    """Service for managing user limits and usage."""

    def __init__(self):
        """Initialize the UserService."""
        self.settings = get_settings()
        self.db = UserLimitsDatabase(uri=self.settings.POSTGRES_URI)

    async def initialize(self) -> bool:
        """Initialize database tables."""
        return await self.db.initialize()

    async def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Get user limits information."""
        return await self.db.get_user_limits(user_id)

    async def create_user(self, user_id: str) -> bool:
        """Create a new user with FREE tier."""
        return await self.db.create_user_limits(user_id, tier=AccountTier.FREE)

    async def update_user_tier(
        self, user_id: str, tier: str, custom_limits: Optional[Dict[str, Any]] = None
    ) -> bool:
        """Update user tier and custom limits."""
        return await self.db.update_user_tier(user_id, tier, custom_limits)

    async def register_app(self, user_id: str, app_id: str) -> bool:
        """
        Register an app for a user.

        Creates user limits record if it doesn't exist.
        """
        # First check if user limits exist
        user_limits = await self.db.get_user_limits(user_id)

        # If user limits don't exist, create them first
        if not user_limits:
            logger.info(f"Creating user limits for user {user_id}")
            success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
            if not success:
                logger.error(f"Failed to create user limits for user {user_id}")
                return False

        # Now register the app
        return await self.db.register_app(user_id, app_id)

    async def check_limit(self, user_id: str, limit_type: str, value: int = 1) -> bool:
        """
        Check if a user's operation is within limits when given value is considered.

        Args:
            user_id: The user ID to check
            limit_type: Type of limit (query, ingest, graph, cache, etc.)
            value: Value to check (e.g., file size for storage)

        Returns:
            True if within limits, False if exceeded
        """
        # Skip limit checking for self-hosted mode
        if self.settings.MODE == "self_hosted":
            return True

        # Get user limits
        user_data = await self.db.get_user_limits(user_id)
        if not user_data:
            # Create user limits if they don't exist
            logger.info(f"User {user_id} not found when checking limits - creating limits record")
            success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
            if not success:
                logger.error(f"Failed to create user limits for user {user_id}")
                return False
            # Fetch the newly created limits
            user_data = await self.db.get_user_limits(user_id)
            if not user_data:
                logger.error(f"Failed to retrieve newly created user limits for user {user_id}")
                return False

        # Get tier limits
        tier = user_data.get("tier", AccountTier.FREE)
        tier_limits = get_tier_limits(tier, user_data.get("custom_limits"))

        # Get current usage
        usage = user_data.get("usage", {})

        # Check specific limit type
        if limit_type == "query":
            hourly_limit = tier_limits.get("hourly_query_limit", 0)
            monthly_limit = tier_limits.get("monthly_query_limit", 0)

            hourly_usage = usage.get("hourly_query_count", 0)
            monthly_usage = usage.get("monthly_query_count", 0)

            return hourly_usage + value <= hourly_limit and monthly_usage + value <= monthly_limit

        elif limit_type == "ingest":
            hourly_limit = tier_limits.get("hourly_ingest_limit", 0)
            monthly_limit = tier_limits.get("monthly_ingest_limit", 0)

            hourly_usage = usage.get("hourly_ingest_count", 0)
            monthly_usage = usage.get("monthly_ingest_count", 0)

            return hourly_usage + value <= hourly_limit and monthly_usage + value <= monthly_limit

        elif limit_type == "storage_file":
            file_limit = tier_limits.get("storage_file_limit", 0)
            file_count = usage.get("storage_file_count", 0)

            return file_count + value <= file_limit

        elif limit_type == "storage_size":
            size_limit_bytes = tier_limits.get("storage_size_limit_gb", 0) * 1024 * 1024 * 1024
            size_usage = usage.get("storage_size_bytes", 0)

            return size_usage + value <= size_limit_bytes

        elif limit_type == "graph":
            graph_limit = tier_limits.get("graph_creation_limit", 0)
            graph_count = usage.get("graph_count", 0)

            return graph_count + value <= graph_limit

        elif limit_type == "cache":
            cache_limit = tier_limits.get("cache_creation_limit", 0)
            cache_count = usage.get("cache_count", 0)

            return cache_count + value <= cache_limit

        return True

    async def record_usage(self, user_id: str, usage_type: str, increment: int = 1) -> bool:
        """
        Record usage for a user.

        Args:
            user_id: The user ID
            usage_type: Type of usage (query, ingest, storage_file, storage_size, etc.)
            increment: Value to increment by

        Returns:
            True if successful, False otherwise
        """
        # Skip usage recording for self-hosted mode
        if self.settings.MODE == "self_hosted":
            return True

        # Check if user limits exist, create if they don't
        user_data = await self.db.get_user_limits(user_id)
        if not user_data:
            logger.info(f"Creating user limits for user {user_id} during usage recording")
            success = await self.db.create_user_limits(user_id, tier=AccountTier.FREE)
            if not success:
                logger.error(f"Failed to create user limits for user {user_id}")
                return False

        return await self.db.update_usage(user_id, usage_type, increment)

    async def generate_cloud_uri(
        self, user_id: str, app_id: str, name: str, expiry_days: int = 30
    ) -> Optional[str]:
        """
        Generate a cloud URI for an app.

        Args:
            user_id: The user ID
            app_id: The app ID
            name: App name for display purposes
            expiry_days: Number of days until token expires

        Returns:
            URI string with embedded token, or None if failed
        """
        # Get user limits to check app limit
        user_limits = await self.db.get_user_limits(user_id)

        # If user doesn't exist yet, create them
        if not user_limits:
            await self.create_user(user_id)
            user_limits = await self.db.get_user_limits(user_id)
            if not user_limits:
                logger.error(f"Failed to create user limits for user {user_id}")
                return None

        # Get tier limits to enforce app limit
        tier = user_limits.get("tier", AccountTier.FREE)
        tier_limits = get_tier_limits(tier, user_limits.get("custom_limits"))
        app_limit = tier_limits.get("app_limit", 1)  # Default to 1 if not specified

        current_apps = user_limits.get("app_ids", [])

        # Skip the limit check if app is already registered
        if app_id not in current_apps:
            # Check if user has reached app limit
            if len(current_apps) >= app_limit:
                logger.info(f"User {user_id} has reached app limit ({app_limit}) for tier {tier}")
                return None

            # Register the app
            success = await self.register_app(user_id, app_id)
            if not success:
                logger.info(f"Failed to register app {app_id} for user {user_id}")
                return None

        # Create token payload
        payload = {
            "user_id": user_id,
            "app_id": app_id,
            "name": name,
            "permissions": ["read", "write"],
            "exp": int((datetime.now(UTC) + timedelta(days=expiry_days)).timestamp()),
            "type": "developer",
            "entity_id": user_id,
        }

        # Generate token
        token = jwt.encode(
            payload, self.settings.JWT_SECRET_KEY, algorithm=self.settings.JWT_ALGORITHM
        )

        # Generate URI with API domain
        api_domain = getattr(self.settings, "API_DOMAIN", "api.databridge.ai")
        uri = f"databridge://{name}:{token}@{api_domain}/{app_id}"

        return uri
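The returned connection string packs the app name, the signed JWT, the API domain, and the app_id into a single databridge:// URI. A small sketch of how a client might split it back apart (the URI values below are made up):

    from urllib.parse import urlparse

    # Hypothetical URI of the shape produced by generate_cloud_uri above.
    uri = "databridge://my_app:eyJhbGciOi...@api.databridge.ai/my-app-123"

    parsed = urlparse(uri)
    name, token = parsed.username, parsed.password  # app name and signed JWT
    host = parsed.hostname                          # API domain, e.g. api.databridge.ai
    app_id = parsed.path.lstrip("/")                # app identifier

    print(name, host, app_id)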
@@ -100,6 +100,7 @@ batch_size = 4096

 [databridge]
 enable_colpali = true
+mode = "self_hosted"  # "cloud" or "self_hosted"

 [graph]
 provider = "ollama"