#!/usr/bin/env python3 """ Module de chargement et gestion du modèle MLflow. Ce module encapsule la logique de chargement du modèle depuis Hugging Face Hub via MLflow, avec gestion des erreurs et versioning. """ from typing import Any, Optional from fastapi import HTTPException from huggingface_hub import hf_hub_download # Configuration HF_MODEL_REPO = "ASI-Engineer/employee-turnover-model" MODEL_FILENAME = "model/model.pkl" # Cache global du modèle _model_cache: Optional[Any] = None def load_model(force_reload: bool = False) -> Any: """ Charge le modèle depuis Hugging Face Hub via MLflow. Cette fonction implémente un système de cache pour éviter de recharger le modèle à chaque appel. Le modèle est chargé une seule fois au démarrage de l'application et mis en cache. Args: force_reload: Si True, force le rechargement du modèle même s'il est en cache. Returns: Le modèle MLflow chargé et prêt pour l'inférence. Raises: HTTPException: 500 si le modèle ne peut pas être chargé. Examples: >>> model = load_model() >>> # Utiliser le modèle pour prédiction >>> predictions = model.predict(X) """ global _model_cache # Retourner le modèle en cache si disponible if _model_cache is not None and not force_reload: return _model_cache try: import joblib print(f"🔄 Chargement du modèle depuis HF Hub: {HF_MODEL_REPO}") # Télécharger le modèle depuis Hugging Face Hub model_path = hf_hub_download( repo_id=HF_MODEL_REPO, filename=MODEL_FILENAME, repo_type="model" ) print(f"📦 Modèle téléchargé: {model_path}") # Charger le modèle avec joblib model = joblib.load(model_path) # Mettre en cache _model_cache = model print(f"✅ Modèle chargé avec succès: {type(model).__name__}") return model except Exception as e: error_msg = f"❌ Erreur lors du chargement du modèle: {str(e)}" print(error_msg) raise HTTPException( status_code=500, detail={ "error": "Model loading failed", "message": str(e), "model_repo": HF_MODEL_REPO, "solution": "Vérifiez que le modèle est disponible sur HF Hub et correctement entraîné", }, ) def get_model_info() -> dict: """ Retourne les informations sur le modèle chargé. Returns: Dict contenant les métadonnées du modèle. Raises: HTTPException: 500 si le modèle n'est pas chargé. """ try: model = load_model() return { "status": "✅ Modèle chargé", "model_type": type(model).__name__, "hf_hub_repo": HF_MODEL_REPO, "model_file": MODEL_FILENAME, "cached": _model_cache is not None, } except Exception as e: raise HTTPException( status_code=500, detail={"error": "Model info unavailable", "message": str(e)}, ) def load_preprocessing_artifacts(run_id: str) -> dict: """ Charge les artifacts de preprocessing (scaler, encoders) depuis MLflow. Args: run_id: ID du run MLflow contenant les artifacts. Returns: Dict contenant les artifacts de preprocessing. Raises: HTTPException: 500 si les artifacts ne peuvent pas être chargés. Note: Cette fonction sera implémentée quand les preprocessing artifacts seront disponibles dans le modèle HF Hub. """ raise NotImplementedError( "Le chargement des preprocessing artifacts sera implémenté " "lors de l'intégration complète avec MLflow" ) if __name__ == "__main__": # Test de chargement du modèle print("=" * 80) print("TEST DE CHARGEMENT DU MODÈLE") print("=" * 80) try: model = load_model() print("\n✅ Test réussi!") print(f"Type de modèle: {type(model).__name__}") # Afficher les infos info = get_model_info() print("\nInformations du modèle:") for key, value in info.items(): print(f" {key}: {value}") except Exception as e: print(f"\n❌ Test échoué: {e}")