Spaces:
Running
Running
| from fastapi import HTTPException, Request, Depends | |
| from typing import Optional, List, Dict, Any | |
| import jwt | |
| from jwt.exceptions import PyJWTError | |
| from utils.logger import logger | |
# Reads the caller's user ID out of the Supabase JWT sent as a bearer token.
async def get_current_user_id(request: Request) -> str:
    """
    Return the user ID carried by the JWT in the Authorization header.

    Intended for use as a FastAPI route dependency: routes that declare it
    receive the caller's user ID, and requests without a usable bearer token
    are rejected with a 401.

    The token is decoded with ``verify_signature=False``; its signature is
    NOT checked here. This function only establishes that the header holds a
    decodable JWT with a non-empty 'sub' claim.

    NOTE(review): the original comment states that real validation is
    delegated to Supabase's RLS. Confirm that every consumer of this user ID
    actually goes through RLS, because a token with a forged 'sub' is
    accepted by this function.

    Args:
        request: The FastAPI request object

    Returns:
        str: The value of the token's 'sub' claim

    Raises:
        HTTPException: 401 if the Authorization header is missing or is not a
            'Bearer ' header, if the token cannot be decoded, or if the
            decoded payload has no (or an empty) 'sub' claim
    """
    bearer = request.headers.get('Authorization')
    if not (bearer and bearer.startswith('Bearer ')):
        raise HTTPException(
            status_code=401,
            detail="No valid authentication credentials found",
            headers={"WWW-Authenticate": "Bearer"}
        )

    # Everything between the first and second space is treated as the token.
    raw_token = bearer.split(' ')[1]

    try:
        # Decode only; the signature is deliberately not verified here.
        claims = jwt.decode(raw_token, options={"verify_signature": False})
    except PyJWTError:
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"}
        )

    # Supabase stores the user ID in the 'sub' claim
    subject = claims.get('sub')
    if subject:
        return subject

    raise HTTPException(
        status_code=401,
        detail="Invalid token payload",
        headers={"WWW-Authenticate": "Bearer"}
    )
async def get_user_id_from_stream_auth(
    request: Request,
    token: Optional[str] = None
) -> str:
    """
    Return the user ID from a query-parameter token or the Authorization header.

    Designed for streaming endpoints that must accept both header-based and
    query-parameter-based authentication (EventSource cannot set headers).
    The query-parameter token is tried first; the Authorization header is
    consulted only if that token is absent, cannot be decoded, or carries no
    'sub' claim.

    Tokens are decoded with ``verify_signature=False``; signatures are NOT
    checked here.

    NOTE(review): a token with a forged 'sub' is accepted by this function.
    Confirm that downstream access checks do not rely on this ID being
    authenticated.

    Args:
        request: The FastAPI request object
        token: Optional token from query parameters

    Returns:
        str: The 'sub' claim of the first candidate token that yields one

    Raises:
        HTTPException: 401 if neither source yields a decodable token with a
            non-empty 'sub' claim
    """
    def _candidate_tokens():
        # Lazily produce tokens in priority order, so the header is not read
        # unless the query-parameter token has already been exhausted.
        if token:
            yield token
        authorization = request.headers.get('Authorization')
        if authorization and authorization.startswith('Bearer '):
            yield authorization.split(' ')[1]

    for candidate in _candidate_tokens():
        try:
            claims = jwt.decode(candidate, options={"verify_signature": False})
            subject = claims.get('sub')
        except Exception:
            # Best-effort: an unusable candidate just falls through to the
            # next source (and ultimately to the 401 below).
            continue
        if subject:
            return subject

    # Neither source produced a user ID
    raise HTTPException(
        status_code=401,
        detail="No valid authentication credentials found",
        headers={"WWW-Authenticate": "Bearer"}
    )
async def verify_thread_access(client, thread_id: str, user_id: str):
    """
    Check whether a user may access a thread.

    Access is granted when either of the following holds:
      1. the thread's project has a truthy 'is_public' flag (in which case
         ``user_id`` is not consulted at all), or
      2. a row exists in basejump.account_user linking ``user_id`` to the
         thread's account.

    Args:
        client: The Supabase client
        thread_id: The thread ID to check access for
        user_id: The user ID to check permissions for

    Returns:
        bool: True if the user has access (the function never returns False;
            lack of access is signalled by an exception)

    Raises:
        HTTPException: 404 if no thread with ``thread_id`` exists; 403 if the
            thread exists but neither access rule is satisfied
    """
    # Load the thread row, which carries the project and account references
    thread_query = client.table('threads').select('*,project_id').eq('thread_id', thread_id)
    thread_rows = (await thread_query.execute()).data
    if not thread_rows:
        raise HTTPException(status_code=404, detail="Thread not found")
    thread = thread_rows[0]

    # Rule 1: threads of public projects are open to everyone
    project_id = thread.get('project_id')
    if project_id:
        project_query = client.table('projects').select('is_public').eq('project_id', project_id)
        project_rows = (await project_query.execute()).data
        if project_rows and project_rows[0].get('is_public'):
            return True

    # Rule 2: explicit account membership.
    # When using service role, membership has to be checked manually instead
    # of relying on current_user_account_role.
    account_id = thread.get('account_id')
    if account_id:
        membership_query = (
            client.schema('basejump')
            .from_('account_user')
            .select('account_role')
            .eq('user_id', user_id)
            .eq('account_id', account_id)
        )
        membership_rows = (await membership_query.execute()).data
        if membership_rows:
            return True

    raise HTTPException(status_code=403, detail="Not authorized to access this thread")
async def get_optional_user_id(request: Request) -> Optional[str]:
    """
    Return the user ID from the Authorization header's JWT, if there is one.

    Unlike a mandatory-auth dependency, this never rejects the request: it is
    meant for endpoints that serve both authenticated and unauthenticated
    callers (like public projects).

    The token is decoded with ``verify_signature=False``; its signature is
    NOT checked here.

    NOTE(review): a token with a forged 'sub' is accepted by this function.
    Confirm that callers do not treat the returned ID as authenticated.

    Args:
        request: The FastAPI request object

    Returns:
        Optional[str]: The token's 'sub' claim, or None when the header is
            missing, is not a 'Bearer ' header, the token cannot be decoded,
            or the payload has no 'sub' claim
    """
    header_value = request.headers.get('Authorization')
    if header_value and header_value.startswith('Bearer '):
        bearer_token = header_value.split(' ')[1]
        try:
            # Decode only; the signature is deliberately not verified here.
            claims = jwt.decode(bearer_token, options={"verify_signature": False})
        except PyJWTError:
            return None
        # Supabase stores the user ID in the 'sub' claim
        return claims.get('sub')
    return None