Spaces:
Running
Running
| """ | |
| Multi-Modal Knowledge Distillation Web Application | |
| A FastAPI-based web application for creating new AI models through knowledge distillation | |
| from multiple pre-trained models across different modalities. | |
| """ | |
| import os | |
| import asyncio | |
| import logging | |
| import uuid | |
| from typing import List, Dict, Any, Optional, Union | |
| from pathlib import Path | |
| import json | |
| import shutil | |
| from datetime import datetime | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks, WebSocket, WebSocketDisconnect, Request | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.responses import HTMLResponse, FileResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| import uvicorn | |
| from src.model_loader import ModelLoader | |
| from src.distillation import KnowledgeDistillationTrainer | |
| from src.utils import setup_logging, validate_file, cleanup_temp_files, get_system_info | |
| # Import new core components | |
| from src.core.memory_manager import AdvancedMemoryManager | |
| from src.core.chunk_loader import AdvancedChunkLoader | |
| from src.core.cpu_optimizer import CPUOptimizer | |
| from src.core.token_manager import TokenManager | |
| # Import medical components | |
| from src.medical.medical_datasets import MedicalDatasetManager | |
| from src.medical.dicom_handler import DicomHandler | |
| from src.medical.medical_preprocessing import MedicalPreprocessor | |
| # Import database components | |
| from database.database import DatabaseManager | |
| from src.database_manager import DatabaseManager as PlatformDatabaseManager | |
| from src.models_manager import ModelsManager | |
# Setup logging with error handling: prefer the project's own configuration and
# fall back to a plain INFO-level basicConfig if it raises.
try:
    setup_logging()
except Exception as logging_setup_error:
    logging.basicConfig(level=logging.INFO)
    logging.getLogger(__name__).warning(f"Failed to setup advanced logging: {logging_setup_error}")
logger = logging.getLogger(__name__)
# Custom JSON encoder for handling Path objects and other non-serializable types
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder for paths, datetimes, sets, tensors, arrays and plain objects.

    Checks run from most to least specific. The generic ``__dict__`` fallback
    must come last: tensor-like objects may also expose ``__dict__`` and would
    otherwise be encoded as (usually empty) objects instead of their values.
    """

    def default(self, obj):
        if isinstance(obj, Path):
            return str(obj)
        if isinstance(obj, datetime):
            return obj.isoformat()
        if isinstance(obj, (set, frozenset)):
            return list(obj)
        # PyTorch-style tensors: require the full detach/cpu/numpy trio so that
        # unrelated objects with a detach() method (e.g. io streams) are never
        # mutated by the encoder.
        if all(hasattr(obj, attr) for attr in ("detach", "cpu", "numpy")):
            return obj.detach().cpu().numpy().tolist()
        # NumPy arrays and NumPy scalars.
        if hasattr(obj, "tolist"):
            return obj.tolist()
        if hasattr(obj, "__dict__"):
            return obj.__dict__
        return super().default(obj)
def safe_json_serialize(data):
    """Return a JSON-compatible copy of *data*.

    The value is round-tripped through ``CustomJSONEncoder``. If that fails and
    *data* is a dict, each value is encoded on its own and any value that still
    fails is replaced by its ``str()``; keys JSON cannot represent are
    converted with ``str()``. Any other value falls back to ``str(data)``.
    The result can always be passed to plain ``json.dumps``.
    """
    try:
        return json.loads(json.dumps(data, cls=CustomJSONEncoder))
    except Exception as e:
        logger.warning(f"Failed to serialize data: {e}")

    if not isinstance(data, dict):
        return str(data)

    safe_data = {}
    for key, value in data.items():
        # JSON object keys may only be str/int/float/bool/None.
        if key is None or isinstance(key, (str, int, float, bool)):
            safe_key = key
        else:
            safe_key = str(key)
        try:
            # Store the encoded form (e.g. Path -> str), not the original object,
            # so callers using the plain json module can serialize the result.
            safe_data[safe_key] = json.loads(json.dumps(value, cls=CustomJSONEncoder))
        except Exception:
            safe_data[safe_key] = str(value)
    return safe_data
def cleanup_training_session(session_id: str):
    """Forget a training session and delete the model files it produced.

    Removes the session from ``training_sessions`` and drops any WebSocket
    registered for it in ``active_connections``. ``model_path`` may be a
    directory or a single file. run_training stores the weights *file* inside
    ``models/distilled_model_<session_id>/``. In that case the whole
    per-session directory is removed so saved metadata goes with it. Errors are
    logged, never raised.
    """
    try:
        session = training_sessions.get(session_id)
        if session is not None:
            model_path = session.get("model_path")
            if model_path:
                target = Path(model_path)
                if target.is_file() and target.parent.name == f"distilled_model_{session_id}":
                    target = target.parent
                if target.exists():
                    try:
                        # rmtree only accepts directories; plain files need unlink.
                        if target.is_dir():
                            shutil.rmtree(target)
                        else:
                            target.unlink()
                        logger.info(f"Cleaned up model files for session {session_id}")
                    except Exception as e:
                        logger.warning(f"Failed to clean up model files: {e}")
            # Remove from active sessions
            del training_sessions[session_id]
            # Remove WebSocket connection if exists
            active_connections.pop(session_id, None)
            logger.info(f"Cleaned up training session: {session_id}")
    except Exception as e:
        logger.error(f"Error cleaning up session {session_id}: {e}")
def cleanup_old_sessions(max_age_seconds: float = 3600.0):
    """Clean up completed, failed or cancelled sessions that ended long ago.

    Args:
        max_age_seconds: minimum time since ``end_time`` before a finished
            session is removed (default one hour).

    ``end_time`` is recorded by run_training from the asyncio event loop clock
    (``loop.time()``). That clock is monotonic and unrelated to wall-clock epoch
    time, so "now" must come from the same source. CPython's default event loop
    returns ``time.monotonic()`` there. Errors are logged, never raised.
    """
    import time

    try:
        now = time.monotonic()
        sessions_to_remove = [
            session_id
            for session_id, session in training_sessions.items()
            if session.get("status", "unknown") in ("completed", "failed", "cancelled")
            and session.get("end_time")
            and now - session["end_time"] > max_age_seconds
        ]
        # Removal happens after the scan so the dict is not mutated mid-iteration.
        for session_id in sessions_to_remove:
            cleanup_training_session(session_id)
            logger.info(f"Auto-cleaned old session: {session_id}")
    except Exception as e:
        logger.error(f"Error during automatic cleanup: {e}")
# Initialize FastAPI app
# Interactive API docs are served at /docs (Swagger UI) and /redoc (ReDoc).
app = FastAPI(
    title="Multi-Modal Knowledge Distillation",
    description="Create new AI models through knowledge distillation from multiple pre-trained models",
    version="2.1.0",
    docs_url="/docs",
    redoc_url="/redoc"
)
# Add CORS middleware
# NOTE(review): every origin, method and header is allowed together with
# allow_credentials=True. Starlette's CORS documentation states wildcards cannot
# be used for credentialed requests; confirm this permissive policy is intended
# before exposing the app publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount static files and templates.
# Both directories are relative to the process's working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
# Global variables for tracking training sessions.
# Held in process memory only, so sessions do not survive a restart.
# session_id -> session state dict (status, progress, config, model_path, ...)
training_sessions: Dict[str, Dict[str, Any]] = {}
# session_id -> the single WebSocket that receives updates for that session
active_connections: Dict[str, WebSocket] = {}
# Startup event to clean old sessions
# NOTE(review): another coroutine named startup_event is defined later in this
# module; unless this one is registered with the app (e.g. by a decorator not
# visible here), the later definition shadows it.
async def startup_event():
    """Log startup, prune finished sessions and record system information.

    Any exception is logged instead of propagated, so startup never aborts here.
    """
    try:
        logger.info("Starting Multi-Modal Knowledge Distillation Platform")
        cleanup_old_sessions()  # drop sessions left over from previous runs
        logger.info("Initializing core components...")
        logger.info(f"System Info: {get_system_info()}")
        logger.info("Application startup completed successfully")
    except Exception as startup_error:
        logger.error(f"Error during startup: {startup_error}")
# Shutdown event to clean up resources
# NOTE(review): another coroutine named shutdown_event is defined later in this
# module; unless this one is registered with the app, it is shadowed.
async def shutdown_event():
    """Release every tracked training session and purge temporary files.

    cleanup_training_session also deletes each session's saved model files.
    Errors are logged, never raised.
    """
    try:
        logger.info("Shutting down application...")
        # Snapshot the ids first: cleanup removes entries from training_sessions.
        for active_id in [*training_sessions]:
            cleanup_training_session(active_id)
        cleanup_temp_files()
        logger.info("Application shutdown completed")
    except Exception as shutdown_error:
        logger.error(f"Error during shutdown: {shutdown_error}")
# Pydantic models for API
class TrainingConfig(BaseModel):
    """Request body for starting a knowledge-distillation training session.

    Each ``teacher_models`` entry is either a model path / Hugging Face repo id
    or a dict with a ``path`` key plus optional per-model ``token`` and
    ``trust_remote_code`` overrides.
    """

    session_id: str = Field(..., description="Unique session identifier")
    teacher_models: List[Union[str, Dict[str, Any]]] = Field(..., description="List of teacher model paths/URLs or model configs")
    student_config: Dict[str, Any] = Field(default_factory=dict, description="Student model configuration")
    training_params: Dict[str, Any] = Field(default_factory=dict, description="Training parameters")
    distillation_strategy: str = Field(default="ensemble", description="Distillation strategy")
    hf_token: Optional[str] = Field(default=None, description="Hugging Face token")
    trust_remote_code: bool = Field(default=False, description="Trust remote code execution")
    existing_student_model: Optional[str] = Field(default=None, description="Path to existing trained student model for retraining")
    incremental_training: bool = Field(default=False, description="Whether this is incremental training")
    # The training code reads this via getattr(config, 'student_source', 'local').
    # pydantic ignores undeclared fields by default, so without this declaration
    # a client-supplied value would be dropped.
    student_source: str = Field(default="local", description="Where existing_student_model is loaded from: 'local', 'huggingface' or 'space'")
class TrainingStatus(BaseModel):
    """Progress snapshot of one training session, as returned to API clients."""

    session_id: str
    status: str
    progress: float  # fraction of work done; the training code goes up to 1.0
    current_step: int
    total_steps: int
    loss: Optional[float] = Field(default=None)
    eta: Optional[str] = Field(default=None)  # human-readable, e.g. "3m 12s"
    message: str = Field(default="")
class ModelInfo(BaseModel):
    """Descriptive metadata for a single model."""

    name: str
    size: int  # presumably bytes — TODO confirm with producers of this model
    format: str
    modality: str
    architecture: Optional[str] = Field(default=None)
class DatabaseInfo(BaseModel):
    """Catalog entry describing a dataset.

    The ``*_ar`` fields presumably hold Arabic translations of ``name`` and
    ``description`` (TODO confirm).
    """

    name: str
    name_ar: Optional[str] = Field(default="")
    dataset_id: str
    category: str = Field(default="general")
    description: str = Field(default="")
    description_ar: Optional[str] = Field(default="")
    size: Optional[str] = Field(default="Unknown")  # free text, not a number
    language: Optional[str] = Field(default="Unknown")
    modality: str = Field(default="text")
    license: Optional[str] = Field(default="Unknown")
class DatabaseSearchRequest(BaseModel):
    """Full-text dataset search, optionally restricted to one category."""

    query: str
    limit: int = Field(default=20)  # maximum number of results requested
    category: Optional[str] = Field(default=None)
class DatabaseSelectionRequest(BaseModel):
    """The dataset ids a user has selected."""

    database_ids: List[str] = Field(...)
class ModelSearchRequest(BaseModel):
    """Model search query, optionally filtered by model type."""

    query: str
    limit: int = Field(default=20)  # maximum number of results requested
    model_type: Optional[str] = Field(default=None)
class ModelSelectionRequest(BaseModel):
    """Teacher models and optional student model chosen by the user."""

    # pydantic copies mutable defaults per instance, so the list is not shared.
    teacher_models: List[str] = Field(default=[])
    student_model: Optional[str] = Field(default=None)
# Initialize components
# Module-level singletons shared by every request handler. They are built once
# at import time, so any constructor side effects happen once per process.
model_loader = ModelLoader()
distillation_trainer = KnowledgeDistillationTrainer()
# Initialize new advanced components
# memory_manager is shared by the chunk loader, CPU optimizer and medical
# dataset manager below.
memory_manager = AdvancedMemoryManager(max_memory_gb=14.0) # 14GB for 16GB systems
chunk_loader = AdvancedChunkLoader(memory_manager)
cpu_optimizer = CPUOptimizer(memory_manager)
token_manager = TokenManager()
# Initialize database manager
# (src.database_manager.DatabaseManager, imported as PlatformDatabaseManager)
platform_db_manager = PlatformDatabaseManager()
# Initialize models manager
models_manager = ModelsManager()
# database.database.DatabaseManager — a different class from the platform
# manager above despite the shared name.
database_manager = DatabaseManager()
# Initialize medical components
medical_dataset_manager = MedicalDatasetManager(memory_manager)
dicom_handler = DicomHandler(memory_limit_mb=1000.0)
medical_preprocessor = MedicalPreprocessor()
async def startup_event():
    """Initialize application on startup.

    Ensures the working directories uploads/, models/, temp/ and logs/ exist
    and logs system information. Failures are logged and tolerated. No
    fallback location is substituted for a directory that cannot be created,
    so endpoints that need it will fail later.
    """
    logger.info("Starting Multi-Modal Knowledge Distillation application")
    # Create necessary directories with error handling
    for directory in ["uploads", "models", "temp", "logs"]:
        try:
            Path(directory).mkdir(exist_ok=True)
            logger.info(f"Created/verified directory: {directory}")
        except PermissionError:
            logger.warning(f"Cannot create directory {directory}: permission denied")
        except Exception as e:
            logger.warning(f"Error creating directory {directory}: {e}")
    # Log system information
    try:
        system_info = get_system_info()
        logger.info(f"System info: {system_info}")
    except Exception as e:
        logger.warning(f"Could not get system info: {e}")
async def shutdown_event():
    """Cleanup on application shutdown: remove temporary files.

    Errors are logged instead of raised, matching the other shutdown handler in
    this module, so a cleanup failure does not surface as a shutdown error.
    """
    logger.info("Shutting down application")
    try:
        cleanup_temp_files()
    except Exception as e:
        logger.error(f"Error during shutdown: {e}")
async def read_root(request: Request = None):
    """Serve the main web interface (templates/index.html).

    FastAPI injects the current ``Request`` because of the annotation. The
    annotation must stay ``Request`` rather than ``Optional[Request]`` for
    that injection to happen. Jinja2Templates expects the real request in the
    template context (e.g. for ``url_for``). The ``None`` default keeps direct
    Python calls working; they fall back to an empty placeholder.
    """
    context_request = request if request is not None else {}
    return templates.TemplateResponse("index.html", {"request": context_request})
async def health_check():
    """Health check endpoint for Docker and monitoring.

    Reports memory usage, token availability, enabled features and system
    info. Never raises: any failure is returned as ``status: "unhealthy"``.
    The version comes from the FastAPI app so it always matches the API docs.
    """
    try:
        # Get system information
        memory_info = memory_manager.get_memory_info()
        # Check if default token is available
        default_token = token_manager.get_token()
        return {
            "status": "healthy",
            "version": app.version,
            "timestamp": datetime.now().isoformat(),
            "memory": {
                "usage_percent": memory_info.get("process_memory_percent", 0),
                "available_gb": memory_info.get("system_memory_available_gb", 0),
                "status": memory_manager.check_memory_status()
            },
            "tokens": {
                "default_available": bool(default_token),
                "total_tokens": len(token_manager.list_tokens())
            },
            "features": {
                "memory_management": True,
                "chunk_loading": True,
                "cpu_optimization": True,
                "medical_datasets": True,
                "token_management": True
            },
            "system_info": get_system_info()
        }
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "error": str(e),
            "timestamp": datetime.now().isoformat(),
            "version": app.version
        }
async def test_token():
    """Test if HF token is working.

    The token is read from HF_TOKEN, HUGGINGFACE_TOKEN or
    HUGGINGFACE_HUB_TOKEN (first non-empty wins). It is then used to fetch the
    config of the gated ``google/gemma-2b`` repository.

    NOTE(review): a valid token whose account was not granted access to that
    gated repo will likely be reported as invalid too.
    """
    hf_token = (
        os.getenv('HF_TOKEN') or
        os.getenv('HUGGINGFACE_TOKEN') or
        os.getenv('HUGGINGFACE_HUB_TOKEN')
    )
    if not hf_token:
        return {
            "token_available": False,
            "message": "No HF token found in environment variables"
        }
    try:
        from transformers import AutoConfig
        # from_pretrained performs blocking network I/O; run it in the default
        # executor so the event loop keeps serving other requests meanwhile.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None, lambda: AutoConfig.from_pretrained("google/gemma-2b", token=hf_token)
        )
        return {
            "token_available": True,
            "token_valid": True,
            "message": "Token is working correctly"
        }
    except Exception as e:
        return {
            "token_available": True,
            "token_valid": False,
            "message": f"Token validation failed: {str(e)}"
        }
def _supported_kwargs(func, **candidates):
    """Return the subset of *candidates* that *func* accepts as keyword arguments.

    Optional loader settings are forwarded only to callables whose signature
    declares them (or takes ``**kwargs``), so older signatures keep working.
    """
    import inspect

    try:
        parameters = inspect.signature(func).parameters.values()
    except (TypeError, ValueError):
        # Signature cannot be introspected: forward nothing extra.
        return {}
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters):
        return dict(candidates)
    accepted = {
        p.name
        for p in parameters
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return {name: value for name, value in candidates.items() if name in accepted}


async def test_model_loading(request: Dict[str, Any]):
    """Test loading a specific model.

    ``request`` carries ``model_path`` (required) and, optionally,
    ``trust_remote_code``, ``access_type`` (default ``'read'``) and ``token``.
    A missing token, or ``'auto'``, is resolved through token_manager for the
    access type. If that yields nothing, it falls back to the HF_TOKEN /
    HUGGINGFACE_TOKEN / HUGGINGFACE_HUB_TOKEN environment variables. The
    resolved token and trust flag are passed to the model loader when its
    ``get_model_info`` signature accepts them.

    Returns:
        ``{"success": True, "model_info", "message"}`` on success, otherwise
        ``{"success": False, "error", "suggestions"}``. The suggestions are
        Arabic hints chosen from the error text.
    """
    try:
        model_path = request.get('model_path')
        trust_remote_code = request.get('trust_remote_code', False)
        if not model_path:
            return {"success": False, "error": "model_path is required"}
        # Get appropriate token based on access type
        access_type = request.get('access_type', 'read')
        hf_token = request.get('token')
        if not hf_token or hf_token == 'auto':
            hf_token = token_manager.get_token_for_task(access_type)
            if hf_token:
                logger.info(f"Using {access_type} token for model testing")
            else:
                logger.warning(f"No suitable token found for {access_type} access")
                # Fallback to environment variables
                hf_token = (
                    os.getenv('HF_TOKEN') or
                    os.getenv('HUGGINGFACE_TOKEN') or
                    os.getenv('HUGGINGFACE_HUB_TOKEN')
                )
        loader_kwargs = _supported_kwargs(
            model_loader.get_model_info,
            token=hf_token,
            trust_remote_code=trust_remote_code,
        )
        model_info = await model_loader.get_model_info(model_path, **loader_kwargs)
        return {
            "success": True,
            "model_info": model_info,
            "message": f"Model {model_path} can be loaded"
        }
    except Exception as e:
        error_msg = str(e)
        logger.warning(f"Model loading test failed: {error_msg}")
        lowered = error_msg.lower()
        suggestions = []
        if 'trust_remote_code' in lowered:
            suggestions.append("فعّل 'Trust Remote Code' للنماذج التي تتطلب كود مخصص")
        elif 'gated' in lowered:
            suggestions.append("النموذج يتطلب إذن وصول خاص - استخدم رمز مخصص")
        elif 'siglip' in lowered:
            suggestions.append("جرب تفعيل 'Trust Remote Code' لنماذج SigLIP")
        elif '401' in error_msg or 'authentication' in lowered:
            suggestions.append("تحقق من رمز Hugging Face الخاص بك")
            suggestions.append("تأكد من أن الرمز له صلاحية الوصول لهذا النموذج")
        elif '404' in error_msg or 'not found' in lowered:
            suggestions.append("تحقق من اسم مستودع النموذج")
            suggestions.append("تأكد من وجود النموذج على Hugging Face")
        return {
            "success": False,
            "error": error_msg,
            "suggestions": suggestions
        }
async def upload_model(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
    model_names: List[str] = Form(...)
):
    """Upload model files.

    Each file is paired with the name at the same position in ``model_names``
    and validated with validate_file. It is stored under ``uploads/`` with a
    random UUID filename that keeps only the original extension, so
    client-supplied names never influence the on-disk path. Files are copied in
    1 MiB chunks so large checkpoints are never held in memory whole. A cleanup
    of temporary files older than 24 hours is scheduled afterwards.

    Returns:
        dict with ``success``, ``message`` and ``models``. Each entry in
        ``models`` has id, name, original filename, stored path, size in bytes
        and the loader's model info.

    Raises:
        HTTPException: 400 if the numbers of files and names differ or a file
            fails validation; 500 for any other error.
    """
    if len(files) != len(model_names):
        # zip() below would silently ignore the unmatched extras.
        raise HTTPException(
            status_code=400,
            detail=f"Got {len(files)} file(s) but {len(model_names)} model name(s)",
        )
    chunk_size = 1024 * 1024
    try:
        uploaded_models = []
        for file, name in zip(files, model_names):
            # Validate file
            validation_result = validate_file(file)
            if not validation_result["valid"]:
                raise HTTPException(status_code=400, detail=validation_result["error"])
            # Generate unique filename (filename may be missing on some clients)
            file_id = str(uuid.uuid4())
            file_extension = Path(file.filename or "").suffix
            file_path = Path("uploads") / f"{file_id}{file_extension}"
            # Save file in chunks
            size = 0
            with open(file_path, "wb") as buffer:
                while True:
                    chunk = await file.read(chunk_size)
                    if not chunk:
                        break
                    buffer.write(chunk)
                    size += len(chunk)
            # Get model info
            model_info = await model_loader.get_model_info(str(file_path))
            uploaded_models.append({
                "id": file_id,
                "name": name,
                "filename": file.filename,
                "path": str(file_path),
                "size": size,
                "info": model_info
            })
            logger.info(f"Uploaded model: {name} ({file.filename})")
        # Schedule cleanup of old files
        background_tasks.add_task(cleanup_temp_files, max_age_hours=24)
        return {
            "success": True,
            "models": uploaded_models,
            "message": f"Successfully uploaded {len(uploaded_models)} model(s)"
        }
    except HTTPException:
        # Keep client errors (e.g. failed validation) as 4xx responses.
        raise
    except Exception as e:
        logger.error(f"Error uploading models: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def start_training(
    background_tasks: BackgroundTasks,
    config: TrainingConfig
):
    """Start knowledge distillation training.

    Validates the request, registers the session in ``training_sessions`` and
    schedules run_training as a background task. An existing session with the
    same id is reset if it has finished (completed/failed/cancelled) or has an
    unrecognized status. It is rejected while run_training is still working on
    it.

    Returns:
        dict with ``success``, ``session_id`` and ``message``.

    Raises:
        HTTPException: 400 for an invalid or still-active session id; 500 for
            unexpected errors.
    """
    # Every status run_training can be in while it is still working.
    active_statuses = (
        "initializing", "running", "loading_student", "loading_models",
        "initializing_student", "training", "saving",
    )
    try:
        session_id = config.session_id
        # session_id becomes part of a directory name under models/ in
        # run_training; path separators would let it write elsewhere.
        if not session_id or any(sep in session_id for sep in ("/", "\\", "\x00")):
            raise HTTPException(status_code=400, detail="Invalid session_id")

        # Handle existing sessions
        existing_session = training_sessions.get(session_id)
        if existing_session is not None:
            existing_status = existing_session.get("status", "unknown")
            if existing_status in ["failed", "completed", "cancelled"]:
                logger.info(f"Restarting session {session_id} (previous status: {existing_status})")
                cleanup_training_session(session_id)
            elif existing_status in active_statuses:
                raise HTTPException(
                    status_code=400,
                    detail=f"Training session already running (status: {existing_status})"
                )
            else:
                # Unknown status, clean up and restart
                logger.warning(f"Unknown session status {existing_status}, cleaning up")
                cleanup_training_session(session_id)

        # Set HF token from environment if available
        hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_TOKEN')
        if hf_token:
            os.environ['HF_TOKEN'] = hf_token
            logger.info("Using Hugging Face token from environment")

        # Check for large models and warn
        large_models = []
        for model_info in config.teacher_models:
            model_path = model_info if isinstance(model_info, str) else model_info.get('path', '')
            if any(size_indicator in model_path.lower() for size_indicator in ['27b', '70b', '13b']):
                large_models.append(model_path)

        # The stored config is pushed to WebSocket clients with the session
        # state, so secrets are redacted here. run_training receives the
        # original config object, not this copy.
        safe_config = safe_json_serialize(config.dict())
        if isinstance(safe_config, dict):
            if safe_config.get("hf_token"):
                safe_config["hf_token"] = "***"
            for teacher in safe_config.get("teacher_models") or []:
                if isinstance(teacher, dict) and teacher.get("token"):
                    teacher["token"] = "***"

        training_sessions[session_id] = {
            "status": "initializing",
            "progress": 0.0,
            "current_step": 0,
            "total_steps": config.training_params.get("max_steps", 1000),
            "config": safe_config,
            "start_time": None,
            "end_time": None,
            "model_path": None,
            "logs": [],
            "large_models": large_models,
            "message": "Initializing training session..." + (
                f" (Large models detected: {', '.join(large_models)})" if large_models else ""
            )
        }
        # Start training in background
        background_tasks.add_task(run_training, session_id, config)
        logger.info(f"Started training session: {session_id}")
        return {
            "success": True,
            "session_id": session_id,
            "message": "Training started successfully"
        }
    except HTTPException:
        # Preserve 4xx responses instead of re-wrapping them as 500.
        raise
    except Exception as e:
        logger.error(f"Error starting training: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
def _teacher_path(model_info: Union[str, Dict[str, Any]]) -> str:
    """Return the model path / repo id of a teacher entry (plain string or config dict)."""
    if isinstance(model_info, str):
        return model_info
    return model_info.get('path', '')


def _model_load_suggestions(error_str: str) -> List[str]:
    """Return user-facing hints for a lower-cased teacher-loading error message."""
    if 'trust_remote_code' in error_str:
        return ["Try enabling 'Trust Remote Code' option"]
    if 'gated' in error_str or 'access' in error_str:
        return ["This model requires access permission and a valid HF token"]
    if 'siglip' in error_str or 'unknown' in error_str:
        return ["This model may require special loading. Try enabling 'Trust Remote Code'"]
    if 'connection' in error_str or 'network' in error_str:
        return ["Check your internet connection"]
    if 'ti2v' in error_str:
        return ["This ti2v model requires trust_remote_code=True"]
    return []


async def run_training(session_id: str, config: TrainingConfig):
    """Run knowledge distillation training in background.

    Steps:
    1. For incremental training, load the existing student to recover its
       original teachers.
    2. Load every teacher model.
    3. Create the student.
    4. Train while reporting progress.
    5. Save the weights to ``models/distilled_model_<session_id>/pytorch_model.safetensors``.

    Progress bands: student 0.05, teachers 0.1-0.2, student init 0.2, training
    0.3-0.9, saving 0.9, done 1.0. Status changes go through
    update_training_status. This coroutine never raises. On any failure the
    session is marked ``failed``, its ``end_time`` is recorded, and the error
    message is reported.

    Side effect: a resolved Hugging Face token is exported as the process-wide
    ``HF_TOKEN`` environment variable.
    """
    loop = asyncio.get_running_loop()

    async def _fail(message: str, error: Optional[str] = None) -> None:
        """Mark the session failed, stamp its end_time and notify clients."""
        failed_session = training_sessions.get(session_id, {})
        failed_session["status"] = "failed"
        # Same loop clock as start_time, so elapsed-time math stays consistent.
        failed_session["end_time"] = loop.time()
        if error is not None:
            failed_session["error"] = error
        await update_training_status(session_id, "failed", failed_session.get("progress", 0), message)

    try:
        session = training_sessions[session_id]
        session["status"] = "running"
        session["start_time"] = loop.time()

        # Resolve the HF token: an explicit token in the request wins over the environment.
        config_token = getattr(config, 'hf_token', None)
        env_token = (
            os.getenv('HF_TOKEN') or
            os.getenv('HUGGINGFACE_TOKEN') or
            os.getenv('HUGGINGFACE_HUB_TOKEN')
        )
        hf_token = config_token or env_token
        if hf_token:
            logger.info(f"Using Hugging Face token from {'config' if config_token else 'environment'}")
            # NOTE(review): this changes the token for the whole process, not just this session.
            os.environ['HF_TOKEN'] = hf_token
        else:
            logger.warning("No Hugging Face token found - private models may fail")

        # Handle existing student model for incremental training
        if config.existing_student_model and config.incremental_training:
            try:
                await update_training_status(session_id, "loading_student", 0.05, "Loading existing student model...")
                student_source = getattr(config, 'student_source', 'local')
                student_path = config.existing_student_model
                # 'space' is checked first: Space ids contain '/', so the Hub
                # heuristic below would otherwise capture them.
                if student_source == 'space':
                    logger.info(f"Loading student model from Hugging Face Space: {student_path}")
                    existing_student = await model_loader.load_trained_student_from_space(student_path)
                elif student_source == 'huggingface' or ('/' in student_path and not Path(student_path).exists()):
                    logger.info(f"Loading student model from Hugging Face: {student_path}")
                    existing_student = await model_loader.load_trained_student(student_path)
                else:
                    logger.info(f"Loading student model from local path: {student_path}")
                    existing_student = await model_loader.load_trained_student(student_path)
                logger.info(f"Successfully loaded existing student model: {existing_student.get('type', 'unknown')}")

                # NOTE(review): only the teacher list is taken from the loaded
                # student; its weights are not passed to create_student_model
                # below. Confirm whether incremental training should start from them.
                original_teachers = existing_student.get('original_teachers', [])
                new_teachers = [_teacher_path(model_info) for model_info in config.teacher_models]
                # Combine teachers (avoid duplicates, keep first-seen order)
                all_teachers = list(original_teachers)
                for teacher in new_teachers:
                    if teacher not in all_teachers:
                        all_teachers.append(teacher)
                logger.info(f"Incremental training: Original teachers: {original_teachers}")
                logger.info(f"Incremental training: New teachers: {new_teachers}")
                logger.info(f"Incremental training: All teachers: {all_teachers}")

                # Reuse the request's own entries where possible: dict entries
                # carry per-model token / trust_remote_code settings that would
                # be lost if replaced by bare paths.
                request_entries: Dict[str, Any] = {}
                for model_info in config.teacher_models:
                    request_entries.setdefault(_teacher_path(model_info), model_info)
                config.teacher_models = [
                    request_entries.get(teacher, teacher) if isinstance(teacher, str) else teacher
                    for teacher in all_teachers
                ]
            except Exception as e:
                logger.error(f"Error loading existing student model: {e}")
                await _fail(f"Failed to load existing student: {str(e)}")
                return

        # Load teacher models
        await update_training_status(session_id, "loading_models", 0.1, "Loading teacher models...")
        teacher_models = []
        # training_params may override the request's top-level trust_remote_code flag.
        trust_remote_code = config.training_params.get(
            'trust_remote_code', getattr(config, 'trust_remote_code', False)
        )
        total_models = len(config.teacher_models)
        # Teacher loading fills the 0.1-0.2 band so progress never moves
        # backwards when student initialization reports 0.2.
        load_band = 0.1
        for i, model_info in enumerate(config.teacher_models):
            model_path = _teacher_path(model_info)
            if isinstance(model_info, str):
                model_token = hf_token
                model_trust_code = trust_remote_code
            else:
                model_token = model_info.get('token') or hf_token
                model_trust_code = model_info.get('trust_remote_code', trust_remote_code)
            progress = 0.1 + (i * load_band / total_models)
            try:
                await update_training_status(
                    session_id,
                    "loading_models",
                    progress,
                    f"Loading model {i+1}/{total_models}: {model_path}..."
                )
                logger.info(f"Loading model {model_path} with trust_remote_code={model_trust_code}")
                # Special handling for known problematic models
                if model_path == 'Wan-AI/Wan2.2-TI2V-5B':
                    logger.info(f"Detected ti2v model {model_path}, forcing trust_remote_code=True")
                    model_trust_code = True
                elif model_path == 'deepseek-ai/DeepSeek-V3.1-Base':
                    logger.warning(f"Skipping {model_path}: Requires GPU with FP8 quantization support")
                    await update_training_status(
                        session_id,
                        "loading_models",
                        progress,
                        f"Skipping {model_path}: Requires GPU with FP8 quantization"
                    )
                    continue
                model = await model_loader.load_model(
                    model_path,
                    token=model_token,
                    trust_remote_code=model_trust_code
                )
                teacher_models.append(model)
                logger.info(f"Successfully loaded model: {model_path}")
                progress = 0.1 + ((i + 1) * load_band / total_models)
                await update_training_status(
                    session_id,
                    "loading_models",
                    progress,
                    f"Loaded {i+1}/{total_models} models successfully"
                )
            except Exception as e:
                error_msg = f"Failed to load model {model_path}: {str(e)}"
                logger.error(error_msg)
                error_str = str(e).lower()
                # Architectures unknown to transformers often ship custom
                # modelling code: retry once with trust_remote_code=True.
                if not model_trust_code and ('ti2v' in error_str or 'does not recognize this architecture' in error_str):
                    try:
                        logger.info(f"Retrying {model_path} with trust_remote_code=True")
                        await update_training_status(
                            session_id,
                            "loading_models",
                            progress,
                            f"Retrying {model_path} with trust_remote_code=True..."
                        )
                        model = await model_loader.load_model(
                            model_path,
                            token=model_token,
                            trust_remote_code=True
                        )
                        teacher_models.append(model)
                        logger.info(f"Successfully loaded model on retry: {model_path}")
                        progress = 0.1 + ((i + 1) * load_band / total_models)
                        await update_training_status(
                            session_id,
                            "loading_models",
                            progress,
                            f"Loaded {i+1}/{total_models} models successfully (retry)"
                        )
                        continue
                    except Exception as retry_e:
                        logger.error(f"Retry also failed for {model_path}: {str(retry_e)}")
                        error_msg = f"Failed even with trust_remote_code=True: {str(retry_e)}"
                suggestions = _model_load_suggestions(error_str)
                if suggestions:
                    error_msg += f". Suggestions: {'; '.join(suggestions)}"
                await _fail(error_msg)
                return

        # Every teacher may have been skipped (or none were given).
        if not teacher_models:
            await _fail("No teacher models could be loaded; nothing to distill from")
            return

        # Initialize student model
        await update_training_status(session_id, "initializing_student", 0.2, "Initializing student model...")
        student_model = await distillation_trainer.create_student_model(
            teacher_models, config.student_config
        )

        # Run distillation training
        await update_training_status(session_id, "training", 0.3, "Starting knowledge distillation...")

        async def progress_callback(step: int, total_steps: int, loss: float, metrics: Dict[str, Any]):
            """Map training steps onto the 0.3-0.9 progress band."""
            fraction = step / total_steps if total_steps else 0.0
            loss_text = f"{loss:.4f}" if loss is not None else "n/a"
            await update_training_status(
                session_id, "training", 0.3 + fraction * 0.6,
                f"Training step {step}/{total_steps}, Loss: {loss_text}",
                current_step=step, loss=loss
            )

        trained_model = await distillation_trainer.train(
            student_model, teacher_models, config.training_params, progress_callback
        )

        # Save trained model with metadata
        await update_training_status(session_id, "saving", 0.9, "Saving trained model...")
        model_dir = Path("models") / f"distilled_model_{session_id}"
        model_dir.mkdir(parents=True, exist_ok=True)
        weights_path = model_dir / "pytorch_model.safetensors"
        training_metadata = {
            'session_id': session_id,
            'teacher_models': [_teacher_path(model_info) for model_info in config.teacher_models],
            'strategy': config.distillation_strategy,
            'training_params': config.training_params,
            'incremental_training': config.incremental_training,
            'existing_student_model': config.existing_student_model
        }
        await distillation_trainer.save_model(trained_model, str(weights_path), training_metadata)

        # Complete training
        session["status"] = "completed"
        session["progress"] = 1.0
        session["end_time"] = loop.time()
        session["model_path"] = weights_path
        session["training_metadata"] = training_metadata
        await update_training_status(session_id, "completed", 1.0, "Training completed successfully!")
        logger.info(f"Training session {session_id} completed successfully")
    except Exception as e:
        logger.error(f"Training session {session_id} failed: {str(e)}")
        await _fail(f"Training failed: {str(e)}", error=str(e))
async def update_training_status(
    session_id: str,
    status: str,
    progress: float,
    message: str,
    current_step: Optional[int] = None,
    loss: Optional[float] = None
):
    """Update training status and notify the connected WebSocket client.

    Writes status/progress/message (and step/loss when given) into the session.
    The ETA is a linear extrapolation of elapsed loop time over the completed
    fraction and is reset to ``"0m 0s"`` once progress reaches 1.0. The session
    is then sent to its WebSocket as a ``training_update`` message. Unknown
    sessions are ignored. A client whose send fails is dropped, unless it has
    already been replaced by a newer connection.
    """
    session = training_sessions.get(session_id)
    if session is None:
        return
    session["status"] = status
    session["progress"] = progress
    session["message"] = message
    if current_step is not None:
        session["current_step"] = current_step
    if loss is not None:
        session["loss"] = loss

    # Calculate ETA
    if session.get("start_time") and progress > 0:
        if progress < 1.0:
            elapsed = asyncio.get_event_loop().time() - session["start_time"]
            eta_seconds = (elapsed / progress) * (1.0 - progress)
            session["eta"] = f"{int(eta_seconds // 60)}m {int(eta_seconds % 60)}s"
        else:
            # Finished: do not keep reporting the last non-zero estimate.
            session["eta"] = "0m 0s"

    # Notify WebSocket clients
    connection = active_connections.get(session_id)
    if connection is None:
        return
    try:
        safe_session_data = safe_json_serialize(session)
        await connection.send_json({
            "type": "training_update",
            "data": safe_session_data
        })
    except Exception as e:
        logger.warning(f"Failed to send WebSocket update: {e}")
        # Remove the failed client only; a reconnect may already have replaced it.
        if active_connections.get(session_id) is connection:
            del active_connections[session_id]
async def get_training_progress(session_id: str):
    """Return the current TrainingStatus of *session_id*.

    Raises:
        HTTPException: 404 when the session id is unknown.
    """
    session = training_sessions.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Training session not found")
    # These keys are always present on sessions created by start_training.
    required = {key: session[key] for key in ("status", "progress", "current_step", "total_steps")}
    optional = {
        "loss": session.get("loss"),
        "eta": session.get("eta"),
        "message": session.get("message", ""),
    }
    return TrainingStatus(session_id=session_id, **required, **optional)
async def download_model(session_id: str):
    """Download trained model.

    If ``model_path`` is a single file, it is sent as-is. If it is a directory,
    it is zipped into a temporary archive, and the archive is deleted after the
    response has been sent. When the session has no ``model_path``, the known
    naming conventions under ``models/`` are probed.

    Raises:
        HTTPException: 404 if the session or its model files are missing, 400
            if training has not completed, 500 for any other failure.
    """
    import tempfile
    import zipfile

    try:
        if session_id not in training_sessions:
            raise HTTPException(status_code=404, detail="Training session not found")
        session = training_sessions[session_id]
        if session["status"] != "completed":
            raise HTTPException(status_code=400, detail="Training not completed")

        model_path = session.get("model_path")
        if not model_path:
            # Try to find model in models directory
            models_dir = Path("models")
            possible_paths = [
                models_dir / f"distilled_model_{session_id}",
                models_dir / f"distilled_model_{session_id}.safetensors",
                models_dir / f"model_{session_id}",
                models_dir / f"student_model_{session_id}"
            ]
            model_path = next((str(path) for path in possible_paths if path.exists()), None)

        if not model_path or not Path(model_path).exists():
            raise HTTPException(status_code=404, detail="Model file not found. The model may not have been saved properly.")

        model_dir = Path(model_path)
        if model_dir.is_file():
            return FileResponse(
                model_path,
                media_type="application/octet-stream",
                filename=f"distilled_model_{session_id}.safetensors"
            )

        # Reserve a temp path and close its descriptor immediately; zipfile
        # reopens the file by name (which also works on Windows).
        fd, zip_path = tempfile.mkstemp(suffix=".zip")
        os.close(fd)
        try:
            with zipfile.ZipFile(zip_path, 'w') as zipf:
                for file_path in model_dir.rglob('*'):
                    if file_path.is_file():
                        zipf.write(file_path, file_path.relative_to(model_dir))
        except Exception:
            os.unlink(zip_path)
            raise

        # Delete the temporary archive once the response body has been sent.
        cleanup = BackgroundTasks()
        cleanup.add_task(os.unlink, zip_path)
        return FileResponse(
            zip_path,
            media_type="application/zip",
            filename=f"distilled_model_{session_id}.zip",
            background=cleanup
        )
    except HTTPException:
        # Preserve 404/400 instead of re-wrapping them as 500.
        raise
    except Exception as e:
        logger.error(f"Error downloading model: {e}")
        raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}")
async def upload_to_huggingface(
    session_id: str,
    repo_name: str = Form(...),
    description: str = Form(""),
    private: bool = Form(False),
    hf_token: str = Form(...)
):
    """Publish a completed session's model to the Hugging Face Hub.

    Creates (or reuses) ``repo_name``, uploads the files of the model
    directory, then uploads a generated ``README.md`` model card. A top-level
    README in the model directory is not uploaded because the generated card
    replaces it.

    Args:
        session_id: Id of the completed training session.
        repo_name: Target repository id in ``namespace/model-name`` form.
        description: Optional free text for the model card.
        private: Whether a newly created repository should be private.
        hf_token: Hugging Face token used for repo creation and uploads.

    Returns:
        Dict with ``success``, ``repo_url``, ``uploaded_files`` (each repo
        path at most once) and ``message``.

    Raises:
        HTTPException: 404 for an unknown session or missing model; 400 for an
            unfinished session, a malformed repo name or a repo-creation
            failure; 401/403 for token or permission problems; 500 otherwise.
    """
    try:
        if session_id not in training_sessions:
            raise HTTPException(status_code=404, detail="Training session not found")
        session = training_sessions[session_id]
        if session["status"] != "completed":
            raise HTTPException(status_code=400, detail="Training not completed")
        model_path = session.get("model_path")
        if not model_path or not Path(model_path).exists():
            raise HTTPException(status_code=404, detail="Model file not found")

        try:
            from huggingface_hub import HfApi, create_repo
        except ImportError:
            raise HTTPException(status_code=500, detail="huggingface_hub not installed")

        api = HfApi(token=hf_token)

        # Both the namespace and the model name must be non-empty.
        username, _, model_name = repo_name.partition('/')
        if not username or not model_name:
            raise HTTPException(status_code=400, detail="Repository name must be in format 'username/model-name'")

        try:
            repo_url = create_repo(
                repo_id=repo_name,
                token=hf_token,
                private=private,
                exist_ok=True
            )
            logger.info(f"Created/accessed repository: {repo_url}")
        except Exception as e:
            error_msg = str(e)
            # Map the Hub error text onto a matching HTTP status for the client.
            if "403" in error_msg or "Forbidden" in error_msg:
                raise HTTPException(
                    status_code=403,
                    detail=f"Permission denied. Please check: 1) Your token has 'Write' permissions, 2) You own the namespace '{username}', 3) The repository name is correct. Error: {error_msg}"
                ) from e
            if "401" in error_msg or "Unauthorized" in error_msg:
                raise HTTPException(
                    status_code=401,
                    detail=f"Invalid token. Please check your Hugging Face token. Error: {error_msg}"
                ) from e
            raise HTTPException(status_code=400, detail=f"Failed to create repository: {error_msg}") from e

        model_path_obj = Path(model_path)
        # NOTE(review): when model_path is a single file, every file in its
        # parent directory is uploaded. Confirm the parent is a per-model directory.
        model_dir = model_path_obj.parent if model_path_obj.is_file() else model_path_obj

        def upload(local_path, path_in_repo):
            api.upload_file(
                path_or_fileobj=str(local_path),
                path_in_repo=path_in_repo,
                repo_id=repo_name,
                token=hf_token
            )

        uploaded_files = []
        # Repo paths already attempted, so nothing is uploaded twice. The
        # top-level README.md is reserved for the generated model card.
        attempted = {"README.md"}

        # Core model artifacts first.
        for file_name in ('pytorch_model.safetensors', 'config.json', 'model.py', 'training_history.json'):
            file_path = model_dir / file_name
            if not file_path.exists():
                continue
            attempted.add(file_name)
            try:
                upload(file_path, file_name)
                uploaded_files.append(file_name)
                logger.info(f"Uploaded {file_name}")
            except Exception as e:
                logger.warning(f"Failed to upload {file_name}: {e}")

        # Then every other file in the tree. Repo paths always use '/'.
        for file_path in model_dir.rglob('*'):
            if not file_path.is_file():
                continue
            relative_path = file_path.relative_to(model_dir).as_posix()
            if relative_path in attempted:
                continue
            attempted.add(relative_path)
            try:
                upload(file_path, relative_path)
                uploaded_files.append(relative_path)
                logger.info(f"Uploaded additional file: {relative_path}")
            except Exception as e:
                logger.warning(f"Failed to upload {relative_path}: {e}")

        config_info = session.get("config") or {}
        teacher_models = _hf_teacher_model_names(config_info.get("teacher_models", []))
        readme_content = _hf_model_card(repo_name, description, config_info, teacher_models)
        api.upload_file(
            path_or_fileobj=readme_content.encode("utf-8"),
            path_in_repo="README.md",
            repo_id=repo_name,
            token=hf_token
        )
        uploaded_files.append("README.md")

        return {
            "success": True,
            "repo_url": f"https://huggingface.co/{repo_name}",
            "uploaded_files": uploaded_files,
            "message": f"Model successfully uploaded to {repo_name}"
        }
    except HTTPException:
        # Keep the intended 4xx status instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error uploading to Hugging Face: {e}")
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")


def _hf_teacher_model_names(teacher_models_raw):
    """Normalise configured teacher entries (ids, dicts with ``path``, other objects) to strings."""
    names = []
    for model in teacher_models_raw:
        if isinstance(model, str):
            names.append(model)
        elif isinstance(model, dict):
            names.append(str(model.get('path', model)))
        else:
            names.append(str(model))
    return names


def _hf_model_card(repo_name, description, config_info, teacher_models):
    """Render the README.md model card (YAML front matter plus Markdown) for an upload."""
    training_params = config_info.get('training_params', {})
    base_model = teacher_models[0] if teacher_models else 'unknown'
    tokenizer_source = teacher_models[0] if teacher_models else 'bert-base-uncased'
    lines = [
        "---",
        "license: apache-2.0",
        "tags:",
        "- knowledge-distillation",
        "- pytorch",
        "- transformers",
        f"base_model: {base_model}",
        "---",
        "",
        f"# {repo_name}",
        "",
        "This model was created using knowledge distillation from the following teacher model(s):",
        "",
        "\n".join(f"- {model}" for model in teacher_models),
        "",
        "## Model Description",
        "",
        description if description else 'A distilled model created using multi-modal knowledge distillation.',
        "",
        "## Training Details",
        "",
        f"- **Teacher Models**: {', '.join(teacher_models)}",
        f"- **Distillation Strategy**: {config_info.get('distillation_strategy', 'ensemble')}",
        f"- **Training Steps**: {training_params.get('max_steps', 'unknown')}",
        f"- **Learning Rate**: {training_params.get('learning_rate', 'unknown')}",
        "",
        "## Usage",
        "",
        "```python",
        "from transformers import AutoModel, AutoTokenizer",
        "",
        f'model = AutoModel.from_pretrained("{repo_name}")',
        f'tokenizer = AutoTokenizer.from_pretrained("{tokenizer_source}")',
        "```",
        "",
        "## Created with",
        "",
        "This model was created using the Multi-Modal Knowledge Distillation platform.",
        "",
    ]
    return "\n".join(lines)
async def validate_repo_name(request: Dict[str, Any]):
    """Check that a repo id is well formed and that the token may use its namespace.

    ``request`` is expected to carry ``repo_name`` (``namespace/model-name``)
    and ``hf_token`` strings. The namespace is accepted when it equals the
    token owner's username or one of the organisations listed by
    ``HfApi.whoami()``.

    Returns:
        Dict with ``valid`` (bool) plus ``message`` and ``username`` on
        success, or ``error`` (and ``suggested_name`` on a namespace
        mismatch) on failure. Never raises.
    """
    try:
        repo_name = request.get('repo_name', '').strip()
        hf_token = request.get('hf_token', '').strip()
        if not repo_name or not hf_token:
            return {"valid": False, "error": "Repository name and token are required"}

        # Reject "name", "user/" and "/name": both parts must be non-empty.
        username, _, model_name = repo_name.partition('/')
        if not username or not model_name:
            return {"valid": False, "error": "Repository name must be in format 'username/model-name'"}

        try:
            from huggingface_hub import HfApi
            api = HfApi(token=hf_token)
            user_info = api.whoami()
            token_username = user_info.get('name', '')
            org_names = {
                org.get('name')
                for org in (user_info.get('orgs') or [])
                if isinstance(org, dict)
            }

            if username != token_username and username not in org_names:
                return {
                    "valid": False,
                    "error": f"Username mismatch. Token belongs to '{token_username}' but trying to create repo under '{username}'. Use '{token_username}/{model_name}' instead.",
                    "suggested_name": f"{token_username}/{model_name}"
                }

            return {
                "valid": True,
                "message": f"Repository name '{repo_name}' is valid for your account",
                "username": token_username
            }
        except Exception as e:
            return {"valid": False, "error": f"Token validation failed: {str(e)}"}
    except Exception as e:
        return {"valid": False, "error": f"Validation error: {str(e)}"}
async def test_space(request: Dict[str, Any]):
    """Check that a Hugging Face Space exists and report the model weight files it contains.

    ``request`` must provide ``space_name`` (``owner/space-name``) and may
    provide ``hf_token`` for private Spaces. Model files are repo paths ending
    in ``.safetensors``, ``.bin`` or ``.pt``.

    Returns:
        Dict with ``success`` and either ``error`` or ``space_info``,
        ``models`` and ``message``. If the Space exists but its files cannot
        be listed, the call still succeeds with an empty ``models`` list (the
        listing error is logged). Never raises.
    """
    try:
        space_name = request.get('space_name', '').strip()
        hf_token = request.get('hf_token', '').strip()
        if not space_name:
            return {"success": False, "error": "Space name is required"}
        if '/' not in space_name:
            return {"success": False, "error": "Space name must be in format 'username/space-name'"}

        try:
            from huggingface_hub import HfApi
            api = HfApi(token=hf_token or None)

            # Existence check; the returned metadata itself is not needed.
            try:
                api.space_info(space_name)
                logger.info(f"Found Space: {space_name}")
            except Exception as e:
                return {"success": False, "error": f"Space not found or not accessible: {str(e)}"}

            try:
                files = api.list_repo_files(space_name, repo_type="space")
                model_files = [f for f in files if f.endswith(('.safetensors', '.bin', '.pt'))]
                has_models_dir = any(f.startswith('models/') for f in files)
            except Exception as e:
                # Best effort: the Space exists, so report it without a file list,
                # but keep the reason visible in the logs.
                logger.warning(f"Could not list files of Space {space_name}: {e}")
                return {
                    "success": True,
                    "space_info": {"name": space_name},
                    "models": [],
                    "message": f"Space {space_name} exists but file listing not available (might be private)"
                }

            return {
                "success": True,
                "space_info": {
                    "name": space_name,
                    "model_files": model_files,
                    "models_directory": has_models_dir,
                    "total_files": len(files)
                },
                "models": model_files,
                "message": f"Space {space_name} is accessible"
            }
        except Exception as e:
            return {"success": False, "error": f"Error accessing Hugging Face: {str(e)}"}
    except Exception as e:
        logger.error(f"Error testing Space: {e}")
        return {"success": False, "error": f"Test failed: {str(e)}"}
async def list_trained_students():
    """List trained student models under ``models/`` that can be retrained.

    A sub-directory counts as a trained student when its config JSON (an exact
    ``config.json`` is preferred over other ``*config.json`` files) has a
    truthy ``is_student_model``. Directories that cannot be read or parsed are
    logged and skipped.

    Returns:
        ``{"trained_students": [...]}`` with one summary dict per model, in
        sorted directory order.

    Raises:
        HTTPException: 500 if the models directory itself cannot be scanned.
    """
    try:
        models_dir = Path("models")
        trained_students = []
        if models_dir.exists():
            for model_dir in sorted(models_dir.iterdir()):
                if not model_dir.is_dir():
                    continue
                try:
                    model_info = _read_trained_student_info(model_dir)
                except Exception as e:
                    logger.warning(f"Error reading model {model_dir}: {e}")
                    continue
                if model_info is not None:
                    trained_students.append(model_info)
        return {"trained_students": trained_students}
    except Exception as e:
        logger.error(f"Error listing trained students: {e}")
        raise HTTPException(status_code=500, detail=str(e))


def _find_model_json(model_dir, file_name):
    """Return ``model_dir/file_name`` if it exists, else the first sorted ``*file_name`` match, else None."""
    exact = model_dir / file_name
    if exact.is_file():
        return exact
    # Fallback for prefixed names; sorted for a deterministic choice.
    matches = sorted(model_dir.glob(f"*{file_name}"))
    return matches[0] if matches else None


def _read_trained_student_info(model_dir):
    """Return the summary dict for *model_dir*, or None if it is not a trained student model."""
    config_file = _find_model_json(model_dir, "config.json")
    if config_file is None:
        return None
    with open(config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)
    if not config.get('is_student_model', False):
        return None

    history = {}
    history_file = _find_model_json(model_dir, "training_history.json")
    if history_file is not None:
        with open(history_file, 'r', encoding='utf-8') as f:
            history = json.load(f)
    sessions = history.get('training_sessions') or []

    return {
        "id": model_dir.name,
        "name": model_dir.name,
        "path": str(model_dir),
        "type": "trained_student",
        "created_at": config.get('created_at', 'unknown'),
        "architecture": config.get('architecture', 'unknown'),
        "modalities": config.get('modalities', ['text']),
        "can_be_retrained": config.get('can_be_retrained', True),
        "original_teachers": history.get('retraining_info', {}).get('original_teachers', []),
        "training_sessions": len(sessions),
        "last_training": sessions[-1].get('timestamp', 'unknown') if sessions else 'unknown'
    }
async def list_models():
    """Describe every regular file in ``uploads/`` as a ``ModelInfo``.

    Files whose metadata cannot be obtained from ``model_loader`` are logged
    and left out of the result.
    """
    found = []
    upload_root = Path("uploads")
    if not upload_root.exists():
        return found

    for entry in upload_root.iterdir():
        if not entry.is_file():
            continue
        try:
            details = await model_loader.get_model_info(str(entry))
            found.append(ModelInfo(
                name=entry.stem,
                size=entry.stat().st_size,
                format=entry.suffix[1:],  # extension without the leading dot
                modality=details.get("modality", "unknown"),
                architecture=details.get("architecture"),
            ))
        except Exception as e:
            logger.warning(f"Error getting info for {entry}: {e}")
    return found
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Stream training updates for *session_id* to a WebSocket client.

    The socket is registered in ``active_connections`` so status updates can
    be pushed to it, and the current session state (if any) is sent right
    after the handshake. Incoming client messages are read and discarded to
    keep the connection open until the client disconnects.

    Args:
        websocket: The client connection.
        session_id: Training session the client subscribes to.
    """
    await websocket.accept()
    active_connections[session_id] = websocket
    try:
        if session_id in training_sessions:
            # Serialize through the same helper the status-update path uses
            # (presumably tolerant of non-JSON session values — confirm).
            await websocket.send_json({
                "type": "training_update",
                "data": safe_json_serialize(training_sessions[session_id])
            })
        # Keep connection alive
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"WebSocket error for session {session_id}: {e}")
    finally:
        # Only unregister this socket: a newer client for the same session
        # may have replaced it in active_connections meanwhile.
        if active_connections.get(session_id) is websocket:
            del active_connections[session_id]
# ==================== NEW ADVANCED ENDPOINTS ====================
# Token Management Endpoints
async def token_management_page(request: Request):
    """Render the HTML page for managing stored Hugging Face tokens."""
    context = {"request": request}
    return templates.TemplateResponse("token-management.html", context)
async def save_token(
    name: str = Form(...),
    token: str = Form(...),
    token_type: str = Form("read"),
    description: str = Form(""),
    is_default: bool = Form(False)
):
    """Store a Hugging Face token through the token manager.

    Args:
        name: Name the token is saved under.
        token: The token value.
        token_type: Token category; defaults to ``"read"``.
        description: Optional free-text description.
        is_default: Whether to mark the token as the default one.

    Returns:
        ``{"success": True, "message": ...}`` on success.

    Raises:
        HTTPException: 400 if the token manager reports failure, 500 on
            unexpected errors.
    """
    try:
        success = token_manager.save_token(name, token, token_type, description, is_default)
        if success:
            return {"success": True, "message": f"Token '{name}' saved successfully"}
        raise HTTPException(status_code=400, detail="Failed to save token")
    except HTTPException:
        # Keep the 400 instead of re-wrapping it as a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error saving token: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def list_tokens():
    """Return ``{"tokens": [...]}`` with every token known to the token manager.

    Raises:
        HTTPException: 500 if the token manager fails.
    """
    try:
        return {"tokens": token_manager.list_tokens()}
    except Exception as e:
        logger.error(f"Error listing tokens: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def delete_token(token_name: str):
    """Delete the stored token named *token_name*.

    Returns:
        ``{"success": True, "message": ...}`` on success.

    Raises:
        HTTPException: 404 if the token manager reports no such token, 500 on
            unexpected errors.
    """
    try:
        success = token_manager.delete_token(token_name)
        if success:
            return {"success": True, "message": f"Token '{token_name}' deleted"}
        raise HTTPException(status_code=404, detail="Token not found")
    except HTTPException:
        # Keep the 404 instead of re-wrapping it as a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error deleting token: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def set_default_token(token_name: str):
    """Mark the stored token named *token_name* as the default token.

    Returns:
        ``{"success": True, "message": ...}`` on success.

    Raises:
        HTTPException: 404 if the token manager reports no such token, 500 on
            unexpected errors.
    """
    try:
        success = token_manager.set_default_token(token_name)
        if success:
            return {"success": True, "message": f"Token '{token_name}' set as default"}
        raise HTTPException(status_code=404, detail="Token not found")
    except HTTPException:
        # Keep the 404 instead of re-wrapping it as a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error setting default token: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def validate_token(token: str = Form(...)):
    """Return the token manager's validation result for the submitted token.

    Raises:
        HTTPException: 500 if validation itself fails.
    """
    try:
        return token_manager.validate_token(token)
    except Exception as e:
        logger.error(f"Error validating token: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_token_for_task(task_type: str):
    """Report which Hugging Face token the token manager would pick for *task_type*.

    The chosen token value is matched against the saved tokens to recover its
    metadata. When no saved token matches, a synthetic entry is reported
    whose description attributes the token to environment variables. Only
    metadata is returned, never the token value itself.

    Raises:
        HTTPException: 404 when no token is available for the task, 500 on
            unexpected errors.
    """
    try:
        chosen = token_manager.get_token_for_task(task_type)
        if not chosen:
            raise HTTPException(status_code=404, detail=f"No suitable token found for task: {task_type}")

        saved_tokens = token_manager.list_tokens()
        # First saved entry whose stored value equals the chosen token.
        token_info = next(
            (entry for entry in saved_tokens
             if token_manager.get_token(entry['name']) == chosen),
            None,
        )
        if not token_info:
            # Not a saved token; the (Arabic) description reads
            # "token from environment variables for the task: <task_type>".
            token_info = {
                'name': f'{task_type}_token',
                'type': task_type,
                'description': f'رمز من متغيرات البيئة للمهمة: {task_type}',
                'last_used': None,
                'usage_count': 0
            }

        type_info = token_manager.token_types.get(token_info['type'], {})
        details = {
            "token_name": token_info['name'],
            "type": token_info['type'],
            "type_name": type_info.get('name', token_info['type']),
            "description": token_info['description'],
            "security_level": type_info.get('security_level', 'medium'),
            "recommended_for": type_info.get('recommended_for', 'general'),
            "last_used": token_info.get('last_used'),
            "usage_count": token_info.get('usage_count', 0),
        }
        return {"success": True, "task_type": task_type, "token_info": details}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting token for task {task_type}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Medical Dataset Endpoints
async def medical_datasets_page(request: Request):
    """Render the HTML page for browsing and loading medical datasets."""
    context = {"request": request}
    return templates.TemplateResponse("medical-datasets.html", context)
async def list_medical_datasets():
    """Return ``{"datasets": [...]}`` listing the medical datasets the manager supports.

    Raises:
        HTTPException: 500 if the dataset manager fails.
    """
    try:
        return {"datasets": medical_dataset_manager.list_supported_datasets()}
    except Exception as e:
        logger.error(f"Error listing medical datasets: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def load_medical_dataset(
    dataset_name: str = Form(...),
    streaming: bool = Form(True),
    split: str = Form("train")
):
    """Load a medical dataset and return a short summary of it.

    A token selected for the ``'medical'`` task is used when available;
    otherwise the token manager's default token is used.

    Returns:
        ``{"success": True, "dataset_info": {...}}`` with name, size, sample
        count and streaming flag.

    Raises:
        HTTPException: 500 if loading fails.
    """
    try:
        hf_token = token_manager.get_token_for_task('medical')
        if not hf_token:
            logger.warning("No suitable token found for medical datasets, trying default")
            hf_token = token_manager.get_token()

        loaded = await medical_dataset_manager.load_dataset(
            dataset_name=dataset_name,
            streaming=streaming,
            split=split,
            token=hf_token,
        )
        cfg = loaded['config']
        summary = {
            "name": cfg['name'],
            "size_gb": cfg['size_gb'],
            "num_samples": cfg['num_samples'],
            "streaming": loaded['streaming'],
        }
        return {"success": True, "dataset_info": summary}
    except Exception as e:
        logger.error(f"Error loading medical dataset: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Memory and Performance Endpoints
async def get_memory_info():
    """Return the memory manager's current memory report unchanged.

    Raises:
        HTTPException: 500 if the memory manager fails.
    """
    try:
        return memory_manager.get_memory_info()
    except Exception as e:
        logger.error(f"Error getting memory info: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_performance_info():
    """Combine memory status, memory recommendations and CPU optimizer state into one report.

    Raises:
        HTTPException: 500 if any of the underlying managers fail.
    """
    try:
        snapshot = memory_manager.get_memory_info()
        advice = memory_manager.get_memory_recommendations()
        report = {
            "memory": snapshot,
            "recommendations": advice,
            "cpu_cores": cpu_optimizer.cpu_count,
            "optimizations_applied": cpu_optimizer.optimizations_applied,
        }
        return report
    except Exception as e:
        logger.error(f"Error getting performance info: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def force_memory_cleanup():
    """Ask the memory manager to run its forced cleanup and confirm completion.

    Raises:
        HTTPException: 500 if the cleanup raises.
    """
    try:
        memory_manager.force_cleanup()
    except Exception as e:
        logger.error(f"Error during memory cleanup: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "message": "Memory cleanup completed"}
# Google Models Support
async def list_google_models():
    """Return the fixed catalogue of Google models offered by the UI as ``{"models": [...]}``."""
    try:
        medsiglip = {
            "name": "google/medsiglip-448",
            "description": "Medical SigLIP model for medical image-text understanding",
            "type": "vision-language",
            "size_gb": 1.1,
            "modality": "multimodal",
            "medical_specialized": True,
        }
        gemma = {
            "name": "google/gemma-3n-E4B-it",
            "description": "Gemma 3 model for instruction following",
            "type": "language",
            "size_gb": 8.5,
            "modality": "text",
            "medical_specialized": False,
        }
        return {"models": [medsiglip, gemma]}
    except Exception as e:
        logger.error(f"Error listing Google models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Database Management API Endpoints
async def get_all_databases():
    """Return all configured databases together with the current selection.

    Raises:
        HTTPException: 500 if the database manager fails.
    """
    try:
        configured = platform_db_manager.get_all_databases()
        chosen = platform_db_manager.get_selected_databases()
    except Exception as e:
        logger.error(f"Error getting databases: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "success": True,
        "databases": configured,
        "selected": chosen,
        "total": len(configured),
    }
async def search_databases(request: DatabaseSearchRequest):
    """Search Hugging Face datasets matching ``request.query`` (at most ``request.limit`` hits).

    Raises:
        HTTPException: 500 if the search fails.
    """
    try:
        query = request.query
        hits = await platform_db_manager.search_huggingface_datasets(
            query=query,
            limit=request.limit,
        )
        return {"success": True, "results": hits, "count": len(hits), "query": query}
    except Exception as e:
        logger.error(f"Error searching databases: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def add_database(database_info: DatabaseInfo):
    """Add a database described by *database_info* to the platform configuration.

    Returns:
        ``{"success": True, "message": ...}`` on success.

    Raises:
        HTTPException: 400 if the database manager reports failure, 500 on
            unexpected errors.
    """
    try:
        success = await platform_db_manager.add_database(database_info.dict())
        if success:
            return {
                "success": True,
                "message": f"Database {database_info.dataset_id} added successfully"
            }
        raise HTTPException(status_code=400, detail="Failed to add database")
    except HTTPException:
        # Keep the 400 instead of re-wrapping it as a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error adding database: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def validate_database(dataset_id: str):
    """Run the database manager's validation on *dataset_id* and wrap the result.

    Raises:
        HTTPException: 500 if validation raises.
    """
    try:
        outcome = await platform_db_manager.validate_dataset(dataset_id)
    except Exception as e:
        logger.error(f"Error validating database: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "validation": outcome, "dataset_id": dataset_id}
async def select_databases(request: DatabaseSelectionRequest):
    """Select each id in ``request.database_ids`` and report the per-id outcome.

    Returns:
        Dict with one ``{"database_id", "success"}`` entry per requested id
        and the resulting selection.

    Raises:
        HTTPException: 500 on unexpected errors.
    """
    try:
        outcome = [
            {"database_id": db_id, "success": platform_db_manager.select_database(db_id)}
            for db_id in request.database_ids
        ]
        return {
            "success": True,
            "results": outcome,
            "selected": platform_db_manager.get_selected_databases(),
        }
    except Exception as e:
        logger.error(f"Error selecting databases: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def remove_database(database_id: str):
    """Remove *database_id* from the platform configuration.

    Returns:
        ``{"success": True, "message": ...}`` on success.

    Raises:
        HTTPException: 400 if the database manager reports failure, 500 on
            unexpected errors.
    """
    try:
        success = platform_db_manager.remove_database(database_id)
        if success:
            return {
                "success": True,
                "message": f"Database {database_id} removed successfully"
            }
        raise HTTPException(status_code=400, detail="Failed to remove database")
    except HTTPException:
        # Keep the 400 instead of re-wrapping it as a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error removing database: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_database_info(database_id: str):
    """Return the stored details of one configured database.

    Returns:
        ``{"success": True, "database": ...}``.

    Raises:
        HTTPException: 404 if the database manager has no entry for the id,
            500 on unexpected errors.
    """
    try:
        database_info = platform_db_manager.get_database_info(database_id)
        if database_info:
            return {
                "success": True,
                "database": database_info
            }
        raise HTTPException(status_code=404, detail="Database not found")
    except HTTPException:
        # Keep the 404 instead of re-wrapping it as a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error getting database info: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_databases_by_category(category: str):
    """Return the configured databases that belong to *category*.

    Raises:
        HTTPException: 500 if the database manager fails.
    """
    try:
        matching = platform_db_manager.get_databases_by_category(category)
    except Exception as e:
        logger.error(f"Error getting databases by category: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "success": True,
        "databases": matching,
        "category": category,
        "count": len(matching),
    }
async def load_selected_databases(max_samples: int = 1000):
    """Load up to *max_samples* samples from each selected database.

    Raises:
        HTTPException: 500 if loading fails.
    """
    try:
        datasets = await platform_db_manager.load_selected_datasets(max_samples)
    except Exception as e:
        logger.error(f"Error loading selected databases: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "success": True,
        "loaded_datasets": datasets,
        "total_datasets": len(datasets),
    }
# Models Management API Endpoints
async def get_all_models():
    """Return every configured model plus the currently selected teachers and student.

    Raises:
        HTTPException: 500 if the models manager fails.
    """
    try:
        every_model = models_manager.get_all_models()
        chosen_teachers = models_manager.get_selected_teachers()
        chosen_student = models_manager.get_selected_student()
    except Exception as e:
        logger.error(f"Error getting models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "success": True,
        "models": every_model,
        "selected_teachers": chosen_teachers,
        "selected_student": chosen_student,
        "total": len(every_model),
    }
async def get_teacher_models():
    """Return the configured teacher models and which of them are selected.

    Raises:
        HTTPException: 500 if the models manager fails.
    """
    try:
        pool = models_manager.get_teacher_models()
        chosen = models_manager.get_selected_teachers()
    except Exception as e:
        logger.error(f"Error getting teacher models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "teachers": pool, "selected": chosen, "total": len(pool)}
async def get_student_models():
    """Return the configured student models and the currently selected one.

    Raises:
        HTTPException: 500 if the models manager fails.
    """
    try:
        pool = models_manager.get_student_models()
        chosen = models_manager.get_selected_student()
    except Exception as e:
        logger.error(f"Error getting student models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "students": pool, "selected": chosen, "total": len(pool)}
async def search_models(request: ModelSearchRequest):
    """Search Hugging Face models by ``request.query``, limit and model type.

    Raises:
        HTTPException: 500 if the search fails.
    """
    try:
        query = request.query
        hits = await models_manager.search_huggingface_models(
            query=query,
            limit=request.limit,
            model_type=request.model_type,
        )
        return {"success": True, "results": hits, "count": len(hits), "query": query}
    except Exception as e:
        logger.error(f"Error searching models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def add_model(model_info: Dict[str, Any]):
    """Add a model described by *model_info* to the configuration.

    Args:
        model_info: Model description passed through to the models manager;
            its ``model_id`` (if present) is echoed in the success message.

    Returns:
        ``{"success": True, "message": ...}`` on success.

    Raises:
        HTTPException: 400 if the models manager reports failure, 500 on
            unexpected errors.
    """
    try:
        success = await models_manager.add_model(model_info)
        if success:
            return {
                "success": True,
                "message": f"Model {model_info.get('model_id')} added successfully"
            }
        raise HTTPException(status_code=400, detail="Failed to add model")
    except HTTPException:
        # Keep the 400 instead of re-wrapping it as a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error adding model: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def validate_model(model_id: str):
    """Run the models manager's validation on *model_id* and wrap the result.

    Raises:
        HTTPException: 500 if validation raises.
    """
    try:
        outcome = await models_manager.validate_model(model_id)
    except Exception as e:
        logger.error(f"Error validating model: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "validation": outcome, "model_id": model_id}
async def select_models(request: ModelSelectionRequest):
    """Select the requested teacher models and, if given, the student model.

    Returns:
        Dict with a per-model outcome list (teachers first, then the student)
        and the resulting teacher/student selection.

    Raises:
        HTTPException: 500 on unexpected errors.
    """
    try:
        results = [
            {
                "model_id": teacher_id,
                "type": "teacher",
                "success": models_manager.select_teacher(teacher_id),
            }
            for teacher_id in request.teacher_models
        ]

        student_id = request.student_model
        if student_id is not None:
            results.append({
                "model_id": student_id,
                "type": "student",
                "success": models_manager.select_student(student_id),
            })

        return {
            "success": True,
            "results": results,
            "selected_teachers": models_manager.get_selected_teachers(),
            "selected_student": models_manager.get_selected_student(),
        }
    except Exception as e:
        logger.error(f"Error selecting models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def remove_model(model_id: str):
    """Remove *model_id* from the model configuration.

    Returns:
        ``{"success": True, "message": ...}`` on success.

    Raises:
        HTTPException: 400 if the models manager reports failure, 500 on
            unexpected errors.
    """
    try:
        success = models_manager.remove_model(model_id)
        if success:
            return {
                "success": True,
                "message": f"Model {model_id} removed successfully"
            }
        raise HTTPException(status_code=400, detail="Failed to remove model")
    except HTTPException:
        # Keep the 400 instead of re-wrapping it as a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error removing model: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_model_info(model_id: str):
    """Return the stored details of one configured model.

    Returns:
        ``{"success": True, "model": ...}``.

    Raises:
        HTTPException: 404 if the models manager has no entry for the id,
            500 on unexpected errors.
    """
    try:
        model_info = models_manager.get_model_info(model_id)
        if model_info:
            return {
                "success": True,
                "model": model_info
            }
        raise HTTPException(status_code=404, detail="Model not found")
    except HTTPException:
        # Keep the 404 instead of re-wrapping it as a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error getting model info: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Serve the app directly with uvicorn; $PORT overrides the default 7860.
    server_port = int(os.getenv("PORT", 7860))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=server_port,
        reload=False,
        log_level="info",
    )