File size: 4,357 Bytes
bffe28b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python3
"""
Module de chargement et gestion du modèle MLflow.

Ce module encapsule la logique de chargement du modèle depuis Hugging Face Hub
via MLflow, avec gestion des erreurs et versioning.
"""
from typing import Any, Optional

from fastapi import HTTPException
from huggingface_hub import hf_hub_download

# Configuration
HF_MODEL_REPO = "ASI-Engineer/employee-turnover-model"
MODEL_FILENAME = "model/model.pkl"

# Cache global du modèle
_model_cache: Optional[Any] = None


def load_model(force_reload: bool = False) -> Any:
    """
    Charge le modèle depuis Hugging Face Hub via MLflow.

    Cette fonction implémente un système de cache pour éviter de recharger
    le modèle à chaque appel. Le modèle est chargé une seule fois au démarrage
    de l'application et mis en cache.

    Args:
        force_reload: Si True, force le rechargement du modèle même s'il est en cache.

    Returns:
        Le modèle MLflow chargé et prêt pour l'inférence.

    Raises:
        HTTPException: 500 si le modèle ne peut pas être chargé.

    Examples:
        >>> model = load_model()
        >>> # Utiliser le modèle pour prédiction
        >>> predictions = model.predict(X)
    """
    global _model_cache

    # Retourner le modèle en cache si disponible
    if _model_cache is not None and not force_reload:
        return _model_cache

    try:
        import joblib

        print(f"🔄 Chargement du modèle depuis HF Hub: {HF_MODEL_REPO}")

        # Télécharger le modèle depuis Hugging Face Hub
        model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO, filename=MODEL_FILENAME, repo_type="model"
        )

        print(f"📦 Modèle téléchargé: {model_path}")

        # Charger le modèle avec joblib
        model = joblib.load(model_path)

        # Mettre en cache
        _model_cache = model

        print(f"✅ Modèle chargé avec succès: {type(model).__name__}")
        return model

    except Exception as e:
        error_msg = f"❌ Erreur lors du chargement du modèle: {str(e)}"
        print(error_msg)
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Model loading failed",
                "message": str(e),
                "model_repo": HF_MODEL_REPO,
                "solution": "Vérifiez que le modèle est disponible sur HF Hub et correctement entraîné",
            },
        )


def get_model_info() -> dict:
    """
    Retourne les informations sur le modèle chargé.

    Returns:
        Dict contenant les métadonnées du modèle.

    Raises:
        HTTPException: 500 si le modèle n'est pas chargé.
    """
    try:
        model = load_model()

        return {
            "status": "✅ Modèle chargé",
            "model_type": type(model).__name__,
            "hf_hub_repo": HF_MODEL_REPO,
            "model_file": MODEL_FILENAME,
            "cached": _model_cache is not None,
        }

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={"error": "Model info unavailable", "message": str(e)},
        )


def load_preprocessing_artifacts(run_id: str) -> dict:
    """
    Charge les artifacts de preprocessing (scaler, encoders) depuis MLflow.

    Args:
        run_id: ID du run MLflow contenant les artifacts.

    Returns:
        Dict contenant les artifacts de preprocessing.

    Raises:
        HTTPException: 500 si les artifacts ne peuvent pas être chargés.

    Note:
        Cette fonction sera implémentée quand les preprocessing artifacts
        seront disponibles dans le modèle HF Hub.
    """
    raise NotImplementedError(
        "Le chargement des preprocessing artifacts sera implémenté "
        "lors de l'intégration complète avec MLflow"
    )


if __name__ == "__main__":
    # Test de chargement du modèle
    print("=" * 80)
    print("TEST DE CHARGEMENT DU MODÈLE")
    print("=" * 80)

    try:
        model = load_model()
        print("\n✅ Test réussi!")
        print(f"Type de modèle: {type(model).__name__}")

        # Afficher les infos
        info = get_model_info()
        print("\nInformations du modèle:")
        for key, value in info.items():
            print(f"  {key}: {value}")

    except Exception as e:
        print(f"\n❌ Test échoué: {e}")