Spaces:
Running
Running
| from datetime import datetime, timezone | |
| from typing import Dict, Optional, Tuple | |
| from utils.logger import logger | |
| from utils.config import config, EnvMode | |
# Subscription tiers keyed by price ID. Each entry carries the plan's name and
# its monthly agent-run allowance, expressed in minutes.
SUBSCRIPTION_TIERS = {
    price_id: {'name': plan_name, 'minutes': monthly_minutes}
    for price_id, plan_name, monthly_minutes in (
        ('price_1RGJ9GG6l1KZGqIroxSqgphC', 'free', 8),
        ('price_1RGJ9LG6l1KZGqIrd9pwzeNW', 'base', 300),
        ('price_1RGJ9JG6l1KZGqIrVUU4ZRv6', 'extra', 2400),
    )
}
async def get_account_subscription(client, account_id: str) -> Optional[Dict]:
    """Fetch the most recently created active subscription for an account.

    Queries ``billing_subscriptions`` in the ``basejump`` schema, keeping only
    rows whose ``account_id`` matches and whose ``status`` is ``'active'``,
    ordered by ``created`` descending and limited to a single row.

    Args:
        client: Database client exposing the ``schema(...).from_(...)`` query
            builder whose ``execute()`` is awaitable.
        account_id: Account whose subscription is looked up.

    Returns:
        The subscription row, or ``None`` when the query yields no rows.
    """
    query = (
        client.schema('basejump')
        .from_('billing_subscriptions')
        .select('*')
        .eq('account_id', account_id)
        .eq('status', 'active')
        .order('created', desc=True)
        .limit(1)
    )
    response = await query.execute()
    rows = response.data
    # A missing or empty result set both mean "no active subscription".
    return rows[0] if rows else None
def _iso_to_epoch_seconds(raw: str) -> float:
    """Convert an ISO-8601 timestamp string into POSIX seconds.

    Normalises the input so ``datetime.fromisoformat`` accepts it on Python
    versions older than 3.11 as well: a trailing ``Z`` becomes ``+00:00`` and
    the fractional-seconds part is padded or truncated to exactly six digits
    (older parsers reject e.g. ``.5`` or ``.12345``).
    """
    text = raw.replace('Z', '+00:00')
    head, dot, tail = text.partition('.')
    if dot:
        # `tail` is the fraction digits followed by an optional UTC offset.
        digit_count = len(tail) - len(tail.lstrip('0123456789'))
        fraction = tail[:digit_count][:6].ljust(6, '0')
        text = f"{head}.{fraction}{tail[digit_count:]}"
    parsed = datetime.fromisoformat(text)
    if parsed.tzinfo is None:
        # NOTE(review): offset-less values are assumed to be UTC, matching the
        # UTC clock used for the month boundary and for running jobs - confirm
        # against the column type of `agent_runs`.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.timestamp()


async def calculate_monthly_usage(client, account_id: str) -> float:
    """Calculate total agent run minutes for the current month for an account.

    Looks up every thread owned by the account, then sums the duration of all
    agent runs on those threads whose ``started_at`` is on or after the first
    day of the current UTC month. A run without ``completed_at`` is treated as
    still running and is counted up to the current time.

    Args:
        client: Database client exposing ``table(...)`` query builders whose
            ``execute()`` is awaitable.
        account_id: Account whose usage is calculated.

    Returns:
        Usage in minutes; ``0.0`` when the account has no threads or no runs
        this month.

    Raises:
        ValueError: If a stored timestamp is not a parseable ISO-8601 string.
    """
    # Start of the current month in UTC.
    now = datetime.now(timezone.utc)
    start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc)

    # Runs are keyed by thread, so resolve the account's threads first.
    threads_result = await (
        client.table('threads')
        .select('thread_id')
        .eq('account_id', account_id)
        .execute()
    )
    if not threads_result.data:
        return 0.0
    thread_ids = [thread['thread_id'] for thread in threads_result.data]

    runs_result = await (
        client.table('agent_runs')
        .select('started_at, completed_at')
        .in_('thread_id', thread_ids)
        .gte('started_at', start_of_month.isoformat())
        .execute()
    )
    if not runs_result.data:
        return 0.0

    now_ts = now.timestamp()
    total_seconds = 0.0
    for run in runs_result.data:
        started = _iso_to_epoch_seconds(run['started_at'])
        completed_at = run['completed_at']
        # For running jobs, use the current time as the end of the run.
        ended = _iso_to_epoch_seconds(completed_at) if completed_at else now_ts
        total_seconds += ended - started

    return total_seconds / 60  # Convert to minutes
async def check_billing_status(client, account_id: str) -> Tuple[bool, str, Optional[Dict]]:
    """Decide whether an account may start an agent run.

    In local mode billing is bypassed entirely. Otherwise the account's active
    subscription (or the free tier when there is none) is resolved to a tier
    in ``SUBSCRIPTION_TIERS`` and the current month's usage is compared with
    that tier's minute allowance.

    Args:
        client: Database client passed through to the subscription and usage
            lookups.
        account_id: Account being checked.

    Returns:
        Tuple[bool, str, Optional[Dict]]: ``(can_run, message, subscription_info)``.
    """
    # Local development never consults the database.
    if config.ENV_MODE == EnvMode.LOCAL:
        logger.info("Running in local development mode - billing checks are disabled")
        local_info = {
            "price_id": "local_dev",
            "plan_name": "Local Development",
            "minutes_limit": "no limit"
        }
        return True, "Local development mode - billing disabled", local_info

    # Accounts without an active subscription fall back to the free tier.
    # (A stricter policy that rejects unsubscribed and free-tier accounts
    # outright is intentionally not applied here.)
    subscription = await get_account_subscription(client, account_id) or {
        'price_id': 'price_1RGJ9GG6l1KZGqIroxSqgphC',  # Free tier
        'plan_name': 'free'
    }

    tier_info = SUBSCRIPTION_TIERS.get(subscription['price_id'])
    if not tier_info:
        return False, "Invalid subscription tier", subscription

    minutes_used = await calculate_monthly_usage(client, account_id)
    minutes_allowed = tier_info['minutes']
    if minutes_used >= minutes_allowed:
        limit_message = (
            f"Monthly limit of {minutes_allowed} minutes reached. "
            "Please upgrade your plan or wait until next month."
        )
        return False, limit_message, subscription

    return True, "OK", subscription
| # Helper function to get account ID from thread | |
async def get_account_id_from_thread(client, thread_id: str) -> Optional[str]:
    """Resolve the account that owns a thread.

    Selects ``account_id`` from the ``threads`` table for the row whose
    ``thread_id`` matches, limited to a single row.

    Args:
        client: Database client exposing ``table(...)`` query builders whose
            ``execute()`` is awaitable.
        thread_id: Identifier of the thread to look up.

    Returns:
        The owning account ID, or ``None`` when no matching thread is found.
    """
    lookup = (
        client.table('threads')
        .select('account_id')
        .eq('thread_id', thread_id)
        .limit(1)
    )
    response = await lookup.execute()
    rows = response.data
    if not rows:
        return None
    return rows[0]['account_id']