ASI-Engineer commited on
Commit
bffe28b
·
verified ·
1 Parent(s): 01b325b

Upload folder using huggingface_hub

Browse files
Files changed (14) hide show
  1. .env.production +13 -0
  2. Dockerfile +37 -0
  3. README.md +207 -74
  4. README_HF.md +37 -21
  5. app.py +217 -103
  6. requirements.txt +103 -9
  7. src/__init__.py +1 -0
  8. src/auth.py +99 -0
  9. src/config.py +64 -0
  10. src/logger.py +223 -0
  11. src/models.py +153 -0
  12. src/preprocessing.py +243 -0
  13. src/rate_limit.py +40 -0
  14. src/schemas.py +232 -0
.env.production ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Production environment variables for HuggingFace Spaces
2
+
3
+ # Security
4
+ DEBUG=false
5
+ API_KEY=${HF_SPACE_API_KEY}
6
+
7
+ # API Configuration
8
+ API_VERSION=2.1.0
9
+ LOG_LEVEL=INFO
10
+
11
+ # HuggingFace Model
12
+ HF_MODEL_REPO=ASI-Engineer/employee-turnover-model
13
+ MODEL_FILENAME=model/model.pkl
Dockerfile ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Installer les dépendances système
6
+ RUN apt-get update && apt-get install -y \
7
+ curl \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+ # Copier les fichiers de dépendances
11
+ COPY requirements.txt .
12
+
13
+ # Installer les dépendances Python
14
+ RUN pip install --no-cache-dir -r requirements.txt
15
+
16
+ # Copier le code de l'application
17
+ COPY app.py .
18
+ COPY src/ ./src/
19
+ COPY .env.example .env
20
+
21
+ # Créer le dossier logs
22
+ RUN mkdir -p logs
23
+
24
+ # Exposer le port
25
+ EXPOSE 8000
26
+
27
+ # Variables d'environnement par défaut
28
+ ENV DEBUG=false
29
+ ENV LOG_LEVEL=INFO
30
+ ENV API_KEY=change-me-in-production
31
+
32
+ # Healthcheck
33
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
34
+ CMD curl -f http://localhost:8000/health || exit 1
35
+
36
+ # Commande de démarrage
37
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]
README.md CHANGED
@@ -1,106 +1,239 @@
1
- ---
2
- title: OC P5 - API ML Déployée
3
- emoji: 🎯
4
- colorFrom: blue
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.9.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- # 🎯 Employee Turnover Prediction - DEV Environment
14
-
15
- Interface Gradio pour tester le modèle de prédiction de départ des employés (turnover).
16
-
17
- ## 🚀 Modèle ML
18
-
19
- - **Algorithme**: XGBoost optimisé avec RandomizedSearchCV
20
- - **Équilibrage**: SMOTE pour gérer le déséquilibre de classes (ratio 5:1)
21
- - **Tracking**: MLflow pour versioning et reproductibilité
22
- - **Métriques**: F1-Score optimisé (0.51), Accuracy 79%
23
- - **Stockage**: [Hugging Face Hub](https://huggingface.co/ASI-Engineer/employee-turnover-model)
24
-
25
- ## 📊 Fonctionnalités
26
-
27
- - **Status Checker**: Vérifier l'état du modèle et les métriques
28
- - **API Simple**: Interface Gradio pour tests rapides
29
- - **Chargement automatique**: Modèle téléchargé depuis HF Hub au démarrage
30
-
31
- ## 🔧 Architecture
32
-
33
- ```python
34
- # Chargement du modèle depuis HF Hub
35
- model_path = hf_hub_download(
36
- repo_id="ASI-Engineer/employee-turnover-model",
37
- filename="model/model.pkl"
38
- )
39
- model = mlflow.sklearn.load_model(str(Path(model_path).parent))
40
  ```
41
 
42
- ## 🛠️ Installation & Développement
43
 
44
  ### Prérequis
45
  - Python 3.12+
46
- - Poetry (gestionnaire de dépendances)
 
47
 
48
- ### Installation avec Poetry
49
 
50
  ```bash
51
- # Installer Poetry (si pas déjà fait)
52
- curl -sSL https://install.python-poetry.org | python3 -
 
53
 
54
- # Installer les dépendances
55
  poetry install
56
 
57
- # Activer l'environnement virtuel
58
- poetry shell
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # Lancer le pipeline d'entraînement
61
- poetry run python main.py
62
 
63
- # Lancer l'interface Gradio
64
- poetry run python app.py
 
 
 
 
65
  ```
66
 
67
- ### Requirements.txt pour HF Spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- Le fichier `requirements.txt` est **minimal et optimisé** pour HF Spaces (seulement gradio, huggingface-hub, joblib).
70
 
71
- Il est **généré automatiquement** par le CI/CD lors des déploiements.
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- Pour le générer manuellement :
74
  ```bash
75
- ./scripts/export_requirements.sh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  ```
77
 
78
- ### Tests et Linting
79
 
80
  ```bash
81
- # Formater le code
82
- poetry run black .
83
 
84
- # Linter
85
- poetry run flake8 .
86
 
87
- # Tests
88
- poetry run pytest --cov=ml_model tests/
89
  ```
90
 
91
- ## 📈 Métriques
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- - **F1-Score**: 0.5136
94
- - **Accuracy**: 79%
95
- - **Données**: 1470 échantillons, 50 features
96
- - **Classes**: {0: 1233, 1: 237} - Ratio 5.20:1
97
 
98
- ## 🔗 Liens
 
 
 
 
99
 
100
- - **Modèle**: [employee-turnover-model](https://huggingface.co/ASI-Engineer/employee-turnover-model)
101
- - **GitHub**: [OC_P5](https://github.com/chaton59/OC_P5)
102
- - **CI/CD**: GitHub Actions avec déploiement automatique
 
103
 
104
- Ce Space est synchronisé automatiquement via CI/CD depuis la branche `dev` du repository GitHub.
105
 
106
- **Repository**: [chaton59/OC_P5](https://github.com/chaton59/OC_P5)
 
 
1
+ # 🚀 Employee Turnover Prediction API - v2.1.0
2
+
3
+ ## 📊 Vue d'ensemble
4
+
5
+ API REST de prédiction du turnover des employés basée sur un modèle XGBoost avec SMOTE.
6
+
7
+ **✨ Nouveautés v2.1.0** :
8
+ - 📝 Logging structuré JSON
9
+ - 🛡️ Rate limiting (20 req/min par IP)
10
+ - ⚡ Gestion d'erreurs améliorée
11
+ - 📊 Monitoring des performances
12
+ - 🔐 Authentification API Key
13
+
14
+ ## 🏗️ Architecture
15
+
16
+ ```
17
+ OC_P5/
18
+ ├── app.py # Point d'entrée FastAPI
19
+ ├── src/
20
+ │ ├── auth.py # Authentification API Key
21
+ │ ├── config.py # Configuration centralisée
22
+ │ ├── logger.py # Logging structuré (NOUVEAU)
23
+ │ ├── models.py # Chargement modèle HF Hub
24
+ │ ├── preprocessing.py # Pipeline preprocessing
25
+ │ ├── rate_limit.py # Rate limiting (NOUVEAU)
26
+ │ └── schemas.py # Validation Pydantic
27
+ ├── tests/ # Suite pytest (33 tests, 88% couverture)
28
+ ├── logs/ # Logs JSON (NOUVEAU)
29
+ │ ├── api.log # Tous les logs
30
+ │ └── error.log # Erreurs uniquement
31
+ ├── docs/ # Documentation
32
+ ├── ml_model/ # Scripts training
33
+ └── data/ # Données sources
 
 
 
 
 
 
34
  ```
35
 
36
+ ## 🚀 Installation
37
 
38
  ### Prérequis
39
  - Python 3.12+
40
+ - Poetry 1.7+
41
+ - Git
42
 
43
+ ### Setup rapide
44
 
45
  ```bash
46
+ # 1. Cloner le repo
47
+ git clone https://github.com/chaton59/OC_P5.git
48
+ cd OC_P5
49
 
50
+ # 2. Installer les dépendances
51
  poetry install
52
 
53
+ # 3. Configurer l'environnement
54
+ cp .env.example .env
55
+ # Éditer .env avec vos valeurs
56
+
57
+ # 4. Lancer l'API
58
+ poetry run uvicorn app:app --reload
59
+
60
+ # 5. Accéder à la documentation
61
+ # http://localhost:8000/docs
62
+ ```
63
+
64
+ ## 📝 Configuration (.env)
65
+
66
+ ```bash
67
+ # Mode développement (désactive auth + active logs détaillés)
68
+ DEBUG=true
69
 
70
+ # API Key (requis en production)
71
+ API_KEY=your-secret-key-here
72
 
73
+ # Logging (DEBUG, INFO, WARNING, ERROR, CRITICAL)
74
+ LOG_LEVEL=INFO
75
+
76
+ # HuggingFace Model
77
+ HF_MODEL_REPO=ASI-Engineer/employee-turnover-model
78
+ MODEL_FILENAME=model/model.pkl
79
  ```
80
 
81
+ ## 🔒 Authentification
82
+
83
+ ### Mode DEBUG (développement)
84
+ ```bash
85
+ # L'API Key n'est PAS requise
86
+ curl http://localhost:8000/predict -H "Content-Type: application/json" -d '{...}'
87
+ ```
88
+
89
+ ### Mode PRODUCTION
90
+ ```bash
91
+ # L'API Key est REQUISE
92
+ curl http://localhost:8000/predict \
93
+ -H "X-API-Key: your-secret-key" \
94
+ -H "Content-Type: application/json" \
95
+ -d '{...}'
96
+ ```
97
 
98
+ ## 📡 Endpoints
99
 
100
+ ### 🏥 Health Check
101
+ ```bash
102
+ GET /health
103
+
104
+ # Réponse
105
+ {
106
+ "status": "healthy",
107
+ "model_loaded": true,
108
+ "model_type": "Pipeline",
109
+ "version": "2.1.0"
110
+ }
111
+ ```
112
 
113
+ ### 🔮 Prédiction
114
  ```bash
115
+ POST /predict
116
+ Content-Type: application/json
117
+ X-API-Key: your-key (en production)
118
+
119
+ # Exemple payload (voir docs/API_GUIDE.md pour tous les champs)
120
+ {
121
+ "satisfaction_employee_environnement": 3,
122
+ "satisfaction_employee_nature_travail": 4,
123
+ "satisfaction_employee_equipe": 5,
124
+ "satisfaction_employee_equilibre_pro_perso": 3,
125
+ "note_evaluation_actuelle": 85,
126
+ "annees_depuis_la_derniere_promotion": 2,
127
+ "nombre_formations_realisees": 3,
128
+ ...
129
+ }
130
+
131
+ # Réponse
132
+ {
133
+ "prediction": 0, # 0 = reste, 1 = part
134
+ "probability_0": 0.85, # Probabilité de rester
135
+ "probability_1": 0.15, # Probabilité de partir
136
+ "risk_level": "Low" # Low, Medium, High
137
+ }
138
+ ```
139
+
140
+ ## 📊 Logging
141
+
142
+ ### Logs structurés JSON
143
+
144
+ **Fichiers** :
145
+ - `logs/api.log` : Tous les logs
146
+ - `logs/error.log` : Erreurs uniquement
147
+
148
+ **Format** :
149
+ ```json
150
+ {
151
+ "timestamp": "2025-12-26T10:30:45",
152
+ "level": "INFO",
153
+ "logger": "employee_turnover_api",
154
+ "message": "Request POST /predict",
155
+ "method": "POST",
156
+ "path": "/predict",
157
+ "status_code": 200,
158
+ "duration_ms": 23.45,
159
+ "client_host": "127.0.0.1"
160
+ }
161
+ ```
162
+
163
+ ## 🛡️ Rate Limiting
164
+
165
+ **Configuration** :
166
+ - **Développement** : Désactivé (DEBUG=true)
167
+ - **Production** : 20 requêtes/minute par IP ou API Key
168
+
169
+ **En cas de dépassement** :
170
+ ```json
171
+ {
172
+ "error": "Rate limit exceeded",
173
+ "message": "20 per 1 minute"
174
+ }
175
  ```
176
 
177
+ ## Tests
178
 
179
  ```bash
180
+ # Tous les tests
181
+ poetry run pytest tests/ -v
182
 
183
+ # Avec couverture
184
+ poetry run pytest tests/ --cov --cov-report=html
185
 
186
+ # Voir rapport HTML
187
+ open htmlcov/index.html
188
  ```
189
 
190
+ **Résultats** :
191
+ - ✅ 33 tests passés
192
+ - 📊 88% de couverture globale
193
+
194
+ ## 🚀 Déploiement
195
+
196
+ ### Variables d'environnement requises
197
+ ```bash
198
+ DEBUG=false
199
+ API_KEY=<votre-clé-sécurisée>
200
+ LOG_LEVEL=INFO
201
+ ```
202
+
203
+ ### HuggingFace Spaces
204
+ Prêt pour déploiement avec `app.py` et `requirements.txt`
205
+
206
+ ## 📚 Documentation
207
+
208
+ - **API Interactive** : http://localhost:8000/docs
209
+ - **ReDoc** : http://localhost:8000/redoc
210
+ - **Guide complet** : [docs/API_GUIDE.md](docs/API_GUIDE.md)
211
+ - **Standards** : [docs/standards.md](docs/standards.md)
212
+ - **Couverture tests** : [docs/TEST_COVERAGE.md](docs/TEST_COVERAGE.md)
213
+
214
+ ## 📦 Dépendances principales
215
+
216
+ - **FastAPI** 0.115.14 : Framework web
217
+ - **Pydantic** 2.12.5 : Validation données
218
+ - **XGBoost** 2.1.3 : Modèle ML
219
+ - **SlowAPI** 0.1.9 : Rate limiting
220
+ - **python-json-logger** 4.0.0 : Logs structurés
221
+ - **pytest** 9.0.2 : Tests
222
 
223
+ ## 🔄 Changelog
 
 
 
224
 
225
+ ### v2.1.0 (26 décembre 2025)
226
+ - ✨ Système de logging structuré JSON
227
+ - 🛡️ Rate limiting avec SlowAPI
228
+ - ⚡ Amélioration gestion d'erreurs
229
+ - 📊 Monitoring des performances
230
 
231
+ ### v2.0.0 (26 décembre 2025)
232
+ - Suite de tests complète (33 tests)
233
+ - 🔐 Authentification API Key
234
+ - 📊 88% de couverture de code
235
 
236
+ ## 👥 Auteurs
237
 
238
+ - **Projet** : OpenClassrooms P5
239
+ - **Repo** : [github.com/chaton59/OC_P5](https://github.com/chaton59/OC_P5)
README_HF.md CHANGED
@@ -1,33 +1,49 @@
1
  ---
2
- title: Employee Turnover Prediction - DEV
3
- emoji: 🎯
4
  colorFrom: blue
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.9.1
8
- app_file: app.py
9
- pinned: false
10
  ---
11
 
12
- # 🎯 Employee Turnover Prediction - Environment DEV
13
 
14
- Interface de test pour prédire le risque de départ des employés.
15
 
16
- ## 🚀 Modèle
17
 
18
- - **Algorithme**: XGBoost avec RandomizedSearchCV
19
- - **Équilibrage**: SMOTE pour classes déséquilibrées (ratio 5:1)
20
- - **Tracking**: MLflow pour versioning et reproductibilité
21
- - **Métriques**: Optimisé pour F1-Score
 
 
22
 
23
- ## 📊 Utilisation
24
 
25
- 1. Ajustez les paramètres de l'employé (satisfaction, évaluation, projets, etc.)
26
- 2. Cliquez sur "Prédire le risque de départ"
27
- 3. Obtenez la probabilité de turnover et les recommandations
28
 
29
- ## 🔧 Développement
30
 
31
- Ce Space est synchronisé automatiquement via CI/CD depuis la branche `dev` du repository GitHub.
 
 
32
 
33
- **Repository**: [chaton59/OC_P5](https://github.com/chaton59/OC_P5)
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Employee Turnover Prediction API
3
+ emoji: 👔
4
  colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: true
8
+ license: mit
9
+ app_port: 8000
10
  ---
11
 
12
+ # Employee Turnover Prediction API 🚀
13
 
14
+ API de prédiction du turnover des employés avec XGBoost + SMOTE.
15
 
16
+ ## 🎯 Fonctionnalités
17
 
18
+ - Prédiction de turnover (0 = reste, 1 = part)
19
+ - 📊 Probabilités et niveau de risque (Low/Medium/High)
20
+ - 🔐 Authentification API Key
21
+ - 📝 Logs structurés JSON
22
+ - 🛡️ Rate limiting (20 req/min)
23
+ - 📚 Documentation OpenAPI/Swagger
24
 
25
+ ## 🔗 Endpoints
26
 
27
+ - **Docs** : `/docs` - Documentation interactive
28
+ - **Health** : `/health` - Status de l'API
29
+ - **Predict** : `/predict` - Prédiction de turnover
30
 
31
+ ## 🚀 Utilisation
32
 
33
+ ```bash
34
+ # Health check
35
+ curl https://asi-engineer-employee-turnover-api.hf.space/health
36
 
37
+ # Prédiction
38
+ curl -X POST https://asi-engineer-employee-turnover-api.hf.space/predict \
39
+ -H "Content-Type: application/json" \
40
+ -d '{
41
+ "satisfaction_employee_environnement": 3,
42
+ "satisfaction_employee_nature_travail": 4,
43
+ ...
44
+ }'
45
+ ```
46
+
47
+ ## 📚 Documentation complète
48
+
49
+ Voir [GitHub Repository](https://github.com/chaton59/OC_P5) pour la documentation complète.
app.py CHANGED
@@ -1,138 +1,252 @@
1
  #!/usr/bin/env python3
2
  """
3
- Interface Gradio pour tester le modèle Employee Turnover en production.
4
 
5
- Déploiement sur Hugging Face Spaces pour tests rapides.
6
- Version de démonstration - Interface complète en développement.
 
 
 
7
  """
8
- import gradio as gr
9
- from huggingface_hub import hf_hub_download
10
 
11
- # Configuration
12
- HF_MODEL_REPO = "ASI-Engineer/employee-turnover-model"
 
 
13
 
 
 
 
 
 
 
 
14
 
15
- def load_model():
 
 
 
 
 
 
16
  """
17
- Charge le modèle depuis Hugging Face Hub.
18
 
19
- En production (HF Spaces), charge uniquement depuis HF Hub.
20
- Le fallback MLflow local n'est disponible qu'en développement local.
21
  """
 
 
 
 
 
22
  try:
23
- import joblib
 
 
24
 
25
- # Download model pickle from HF Hub
26
- model_path = hf_hub_download(
27
- repo_id=HF_MODEL_REPO, filename="model/model.pkl", repo_type="model"
28
- )
29
- model = joblib.load(model_path)
30
- print(f"✅ Modèle chargé depuis HF Hub: {HF_MODEL_REPO}")
31
- return model, "HF Hub"
32
  except Exception as e:
33
- print(f"❌ Erreur chargement depuis HF Hub: {e}")
34
- return None, "Error"
35
-
36
-
37
- # Charger le modèle au démarrage
38
- try:
39
- model, model_source = load_model()
40
- MODEL_LOADED = model is not None
41
- except Exception as e:
42
- print(f"❌ Erreur lors du chargement du modèle: {e}")
43
- MODEL_LOADED = False
44
- model = None
45
- model_source = "Error"
46
-
47
-
48
- def get_model_info():
49
- """Retourne les informations sur le modèle."""
50
- if not MODEL_LOADED:
51
- return {
52
- "status": "❌ Modèle non disponible",
53
- "error": "Le modèle n'a pas pu être chargé",
54
- "solution": "Vérifiez que le modèle est bien enregistré sur HF Hub ou entraîné localement",
55
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- try:
58
- info = {
59
- "status": "✅ Modèle chargé avec succès",
60
- "source": model_source,
61
- "model_type": type(model).__name__,
62
- "features": "~50 features (après preprocessing)",
63
- "algorithme": "XGBoost + SMOTE",
64
- "hf_hub_repo": HF_MODEL_REPO,
65
- }
66
-
67
- info["info"] = "Interface de prédiction en développement - API FastAPI à venir"
68
- return info
69
 
70
- except Exception as e:
71
- return {"status": "✅ Modèle chargé (info limitées)", "error": str(e)}
72
 
 
 
 
 
 
 
 
 
73
 
74
- # Interface Gradio
75
- with gr.Blocks( # type: ignore[attr-defined]
76
- title="Employee Turnover Prediction - DEV", theme=gr.themes.Soft() # type: ignore[attr-defined]
77
- ) as demo:
78
- gr.Markdown("# 🎯 Prédiction du Turnover - Employee Attrition") # type: ignore[attr-defined]
79
- gr.Markdown("## Environment DEV - Test de déploiement CI/CD") # type: ignore[attr-defined]
80
 
81
- gr.Markdown( # type: ignore[attr-defined]
82
- """
83
- ### 📊 Statut du projet
84
 
85
- Ce Space est synchronisé automatiquement depuis GitHub (branche `dev`).
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- **Actuellement disponible :**
88
- - ✅ Pipeline d'entraînement MLflow complet (`main.py`)
89
- - ✅ Déploiement automatique CI/CD (GitHub Actions → HF Spaces)
90
- - ✅ Tests unitaires et linting automatisés
91
 
92
- **En développement :**
93
- - 🚧 Interface de prédiction interactive
94
- - 🚧 API FastAPI avec endpoints de prédiction
95
- - 🚧 Intégration PostgreSQL pour tracking des prédictions
96
  """
97
- )
 
 
 
 
 
98
 
99
- with gr.Row(): # type: ignore[attr-defined]
100
- with gr.Column(): # type: ignore[attr-defined]
101
- gr.Markdown("### 🔍 Informations sur le modèle") # type: ignore[attr-defined]
102
- check_btn = gr.Button("📊 Vérifier le statut du modèle", variant="primary") # type: ignore[attr-defined]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- with gr.Column(): # type: ignore[attr-defined]
105
- model_output = gr.JSON(label="Statut") # type: ignore[attr-defined]
106
 
107
- check_btn.click(fn=get_model_info, inputs=[], outputs=model_output)
 
 
 
 
 
 
 
 
 
108
 
109
- gr.Markdown("---") # type: ignore[attr-defined]
110
 
111
- gr.Markdown( # type: ignore[attr-defined]
112
- """
113
- ### 🛠️ Prochaines étapes (selon etapes.txt)
114
 
115
- 1. **Étape 3** : Développement API FastAPI
116
- - Endpoints de prédiction avec validation Pydantic
117
- - Chargement dynamique des preprocessing artifacts (scaler, encoders)
118
- - Documentation Swagger/OpenAPI automatique
119
 
120
- 2. **Étape 4** : Intégration PostgreSQL
121
- - Stockage des inputs/outputs des prédictions
122
- - Traçabilité complète des requêtes
123
 
124
- 3. **Étape 5** : Tests unitaires et fonctionnels
125
- - Tests des endpoints API
126
- - Tests de charge et performance
127
- - Couverture de code avec pytest-cov
128
 
129
- ### 📚 Documentation
130
- - **Repository GitHub** : [chaton59/OC_P5](https://github.com/chaton59/OC_P5)
131
- - **MLflow Tracking** : Disponible en local (`./scripts/start_mlflow.sh`)
132
- - **Métriques** : F1-Score optimisé, gestion classes déséquilibrées (SMOTE)
 
 
 
 
133
  """
134
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
 
137
  if __name__ == "__main__":
138
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ API FastAPI pour le modèle Employee Turnover.
4
 
5
+ Cette API expose le modèle de prédiction de départ des employés avec :
6
+ - Validation stricte des inputs via Pydantic
7
+ - Preprocessing automatique
8
+ - Health check pour monitoring
9
+ - Documentation OpenAPI/Swagger automatique
10
  """
11
+ import time
12
+ from contextlib import asynccontextmanager
13
 
14
+ from fastapi import Depends, FastAPI, HTTPException, Request
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from slowapi import _rate_limit_exceeded_handler
17
+ from slowapi.errors import RateLimitExceeded
18
 
19
+ from src.auth import verify_api_key
20
+ from src.config import get_settings
21
+ from src.logger import logger, log_model_load, log_request
22
+ from src.models import get_model_info, load_model
23
+ from src.preprocessing import preprocess_for_prediction
24
+ from src.rate_limit import limiter
25
+ from src.schemas import EmployeeInput, HealthCheck, PredictionOutput
26
 
27
+ # Charger la configuration
28
+ settings = get_settings()
29
+ API_VERSION = settings.API_VERSION
30
+
31
+
32
+ @asynccontextmanager
33
+ async def lifespan(app: FastAPI):
34
  """
35
+ Gestion du cycle de vie de l'application.
36
 
37
+ Charge le modèle au démarrage et le garde en cache.
 
38
  """
39
+ logger.info(
40
+ "🚀 Démarrage de l'API Employee Turnover...", extra={"version": API_VERSION}
41
+ )
42
+
43
+ start_time = time.time()
44
  try:
45
+ # Pré-charger le modèle au démarrage
46
+ model = load_model()
47
+ duration_ms = (time.time() - start_time) * 1000
48
 
49
+ model_type = type(model).__name__
50
+ log_model_load(model_type, duration_ms, True)
51
+ logger.info(" Modèle chargé avec succès")
 
 
 
 
52
  except Exception as e:
53
+ duration_ms = (time.time() - start_time) * 1000
54
+ log_model_load("Unknown", duration_ms, False)
55
+ logger.error("Le modèle n'a pas pu être chargé", extra={"error": str(e)})
56
+
57
+ yield # L'application tourne
58
+
59
+ logger.info("🛑 Arrêt de l'API")
60
+
61
+
62
+ # Créer l'application FastAPI
63
+ app = FastAPI(
64
+ title="Employee Turnover Prediction API",
65
+ description="API de prédiction du turnover des employés avec XGBoost + SMOTE",
66
+ version=API_VERSION,
67
+ lifespan=lifespan,
68
+ docs_url="/docs",
69
+ redoc_url="/redoc",
70
+ )
71
+
72
+ # Ajouter rate limiting
73
+ app.state.limiter = limiter
74
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
75
+
76
+ # Configurer CORS (autoriser tous les domaines en dev)
77
+ app.add_middleware(
78
+ CORSMiddleware,
79
+ allow_origins=["*"],
80
+ allow_credentials=True,
81
+ allow_methods=["*"],
82
+ allow_headers=["*"],
83
+ )
84
+
85
+
86
+ # Middleware de logging des requêtes
87
+ @app.middleware("http")
88
+ async def log_requests(request: Request, call_next):
89
+ """
90
+ Middleware pour logger toutes les requêtes HTTP.
91
+ """
92
+ start_time = time.time()
93
 
94
+ # Traiter la requête
95
+ response = await call_next(request)
 
 
 
 
 
 
 
 
 
 
96
 
97
+ # Calculer la durée
98
+ duration_ms = (time.time() - start_time) * 1000
99
 
100
+ # Logger
101
+ log_request(
102
+ method=request.method,
103
+ path=request.url.path,
104
+ status_code=response.status_code,
105
+ duration_ms=duration_ms,
106
+ client_host=request.client.host if request.client else None,
107
+ )
108
 
109
+ return response
 
 
 
 
 
110
 
 
 
 
111
 
112
+ @app.get("/", tags=["Root"])
113
+ async def root():
114
+ """
115
+ Endpoint racine avec informations sur l'API.
116
+ """
117
+ return {
118
+ "message": "Employee Turnover Prediction API",
119
+ "version": API_VERSION,
120
+ "docs": "/docs",
121
+ "health": "/health",
122
+ "predict": "/predict (POST)",
123
+ }
124
 
 
 
 
 
125
 
126
+ @app.get("/health", response_model=HealthCheck, tags=["Monitoring"])
127
+ async def health_check():
 
 
128
  """
129
+ Health check endpoint pour monitoring.
130
+
131
+ Vérifie que l'API est opérationnelle et que le modèle est chargé.
132
+
133
+ Returns:
134
+ HealthCheck: Status de l'API et du modèle.
135
 
136
+ Raises:
137
+ HTTPException: 503 si le modèle n'est pas disponible.
138
+ """
139
+ try:
140
+ model_info = get_model_info()
141
+
142
+ return HealthCheck(
143
+ status="healthy",
144
+ model_loaded=model_info.get("cached", False),
145
+ model_type=model_info.get("model_type", "Unknown"),
146
+ version=API_VERSION,
147
+ )
148
+ except Exception as e:
149
+ raise HTTPException(
150
+ status_code=503,
151
+ detail={
152
+ "status": "unhealthy",
153
+ "error": "Model not available",
154
+ "message": str(e),
155
+ },
156
+ )
157
 
 
 
158
 
159
+ @app.post(
160
+ "/predict",
161
+ response_model=PredictionOutput,
162
+ tags=["Prediction"],
163
+ dependencies=[Depends(verify_api_key)] if settings.is_api_key_required else [],
164
+ )
165
+ @limiter.limit("20/minute")
166
+ async def predict(request: Request, employee: EmployeeInput):
167
+ """
168
+ Endpoint de prédiction du turnover d'un employé.
169
 
170
+ **PROTÉGÉ PAR API KEY** : Requiert le header `X-API-Key` en production.
171
 
172
+ Prend en entrée les données d'un employé, applique le preprocessing
173
+ et retourne la prédiction avec les probabilités.
 
174
 
175
+ Args:
176
+ employee: Données de l'employé validées par Pydantic.
 
 
177
 
178
+ Returns:
179
+ PredictionOutput: Prédiction et probabilités.
 
180
 
181
+ Raises:
182
+ HTTPException: 401 si API key invalide ou manquante.
183
+ HTTPException: 500 si erreur lors de la prédiction.
 
184
 
185
+ Examples:
186
+ ```bash
187
+ # Avec authentification
188
+ curl -X POST http://localhost:8000/predict \\
189
+ -H "X-API-Key: your-secret-key" \\
190
+ -H "Content-Type: application/json" \\
191
+ -d '{...}'
192
+ ```
193
  """
194
+ try:
195
+ # 1. Charger le modèle
196
+ model = load_model()
197
+
198
+ # 2. Préprocessing
199
+ X = preprocess_for_prediction(employee)
200
+
201
+ # 3. Prédiction
202
+ prediction = int(model.predict(X)[0])
203
+
204
+ # 4. Probabilités (si le modèle supporte predict_proba)
205
+ try:
206
+ probabilities = model.predict_proba(X)[0]
207
+ prob_0 = float(probabilities[0])
208
+ prob_1 = float(probabilities[1])
209
+ except AttributeError:
210
+ # Si le modèle ne supporte pas predict_proba
211
+ prob_0 = 1.0 if prediction == 0 else 0.0
212
+ prob_1 = 1.0 if prediction == 1 else 0.0
213
+
214
+ # 5. Niveau de risque
215
+ if prob_1 < 0.3:
216
+ risk_level = "Low"
217
+ elif prob_1 < 0.7:
218
+ risk_level = "Medium"
219
+ else:
220
+ risk_level = "High"
221
+
222
+ return PredictionOutput(
223
+ prediction=prediction,
224
+ probability_0=prob_0,
225
+ probability_1=prob_1,
226
+ risk_level=risk_level,
227
+ )
228
+
229
+ except Exception:
230
+ logger.exception("Unexpected error during prediction")
231
+ raise HTTPException(
232
+ status_code=500,
233
+ detail={
234
+ "error": "Prediction failed",
235
+ "message": "An unexpected error occurred. Please contact support.",
236
+ },
237
+ )
238
 
239
 
240
  if __name__ == "__main__":
241
+ import uvicorn
242
+
243
+ print("🚀 Lancement de l'API en mode développement...")
244
+ print("📖 Documentation : http://localhost:8000/docs")
245
+
246
+ uvicorn.run(
247
+ "app:app",
248
+ host="0.0.0.0",
249
+ port=8000,
250
+ reload=True,
251
+ log_level="info",
252
+ )
requirements.txt CHANGED
@@ -1,9 +1,103 @@
1
- # Minimal requirements for HF Spaces deployment
2
- # Only the dependencies needed for app.py and model loading
3
- gradio>=5.9.0
4
- huggingface-hub>=0.27.0
5
- joblib>=1.4.0
6
- scikit-learn>=1.6.0
7
- imbalanced-learn>=0.13.0
8
- xgboost>=2.1.0
9
- numpy>=2.0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ alembic==1.17.2 ; python_version >= "3.12" and python_version < "4.0"
2
+ annotated-types==0.7.0 ; python_version >= "3.12" and python_version < "4.0"
3
+ anyio==4.12.0 ; python_version >= "3.12" and python_version < "4.0"
4
+ blinker==1.9.0 ; python_version >= "3.12" and python_version < "4.0"
5
+ cachetools==6.2.4 ; python_version >= "3.12" and python_version < "4.0"
6
+ certifi==2025.11.12 ; python_version >= "3.12" and python_version < "4.0"
7
+ cffi==2.0.0 ; python_version >= "3.12" and python_version < "4.0" and platform_python_implementation != "PyPy"
8
+ charset-normalizer==3.4.4 ; python_version >= "3.12" and python_version < "4.0"
9
+ click==8.3.1 ; python_version >= "3.12" and python_version < "4.0"
10
+ cloudpickle==3.1.2 ; python_version >= "3.12" and python_version < "4.0"
11
+ colorama==0.4.6 ; python_version >= "3.12" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32")
12
+ contourpy==1.3.3 ; python_version >= "3.12" and python_version < "4.0"
13
+ cryptography==46.0.3 ; python_version >= "3.12" and python_version < "4.0"
14
+ cycler==0.12.1 ; python_version >= "3.12" and python_version < "4.0"
15
+ databricks-sdk==0.76.0 ; python_version >= "3.12" and python_version < "4.0"
16
+ deprecated==1.3.1 ; python_version >= "3.12" and python_version < "4.0"
17
+ docker==7.1.0 ; python_version >= "3.12" and python_version < "4.0"
18
+ fastapi==0.115.14 ; python_version >= "3.12" and python_version < "4.0"
19
+ filelock==3.20.1 ; python_version >= "3.12" and python_version < "4.0"
20
+ flask-cors==6.0.2 ; python_version >= "3.12" and python_version < "4.0"
21
+ flask==3.1.2 ; python_version >= "3.12" and python_version < "4.0"
22
+ fonttools==4.61.1 ; python_version >= "3.12" and python_version < "4.0"
23
+ fsspec==2025.12.0 ; python_version >= "3.12" and python_version < "4.0"
24
+ gitdb==4.0.12 ; python_version >= "3.12" and python_version < "4.0"
25
+ gitpython==3.1.45 ; python_version >= "3.12" and python_version < "4.0"
26
+ google-auth==2.45.0 ; python_version >= "3.12" and python_version < "4.0"
27
+ graphene==3.4.3 ; python_version >= "3.12" and python_version < "4.0"
28
+ graphql-core==3.2.7 ; python_version >= "3.12" and python_version < "4.0"
29
+ graphql-relay==3.2.0 ; python_version >= "3.12" and python_version < "4.0"
30
+ greenlet==3.3.0 ; python_version >= "3.12" and python_version < "4.0" and (platform_machine == "aarch64" or platform_machine == "ppc64le" or platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "win32" or platform_machine == "WIN32")
31
+ gunicorn==23.0.0 ; python_version >= "3.12" and python_version < "4.0" and platform_system != "Windows"
32
+ h11==0.16.0 ; python_version >= "3.12" and python_version < "4.0"
33
+ hf-xet==1.2.0 ; python_version >= "3.12" and python_version < "4.0" and (platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "arm64" or platform_machine == "aarch64")
34
+ httpcore==1.0.9 ; python_version >= "3.12" and python_version < "4.0"
35
+ httptools==0.7.1 ; python_version >= "3.12" and python_version < "4.0"
36
+ httpx==0.28.1 ; python_version >= "3.12" and python_version < "4.0"
37
+ huey==2.5.5 ; python_version >= "3.12" and python_version < "4.0"
38
+ huggingface-hub==1.2.3 ; python_version >= "3.12" and python_version < "4.0"
39
+ idna==3.11 ; python_version >= "3.12" and python_version < "4.0"
40
+ imbalanced-learn==0.13.0 ; python_version >= "3.12" and python_version < "4.0"
41
+ importlib-metadata==8.7.1 ; python_version >= "3.12" and python_version < "4.0"
42
+ itsdangerous==2.2.0 ; python_version >= "3.12" and python_version < "4.0"
43
+ jinja2==3.1.6 ; python_version >= "3.12" and python_version < "4.0"
44
+ joblib==1.5.3 ; python_version >= "3.12" and python_version < "4.0"
45
+ kiwisolver==1.4.9 ; python_version >= "3.12" and python_version < "4.0"
46
+ limits==5.6.0 ; python_version >= "3.12" and python_version < "4.0"
47
+ mako==1.3.10 ; python_version >= "3.12" and python_version < "4.0"
48
+ markupsafe==3.0.3 ; python_version >= "3.12" and python_version < "4.0"
49
+ matplotlib==3.10.8 ; python_version >= "3.12" and python_version < "4.0"
50
+ mlflow-skinny==3.8.1 ; python_version >= "3.12" and python_version < "4.0"
51
+ mlflow-tracing==3.8.1 ; python_version >= "3.12" and python_version < "4.0"
52
+ mlflow==3.8.1 ; python_version >= "3.12" and python_version < "4.0"
53
+ numpy==2.4.0 ; python_version >= "3.12" and python_version < "4.0"
54
+ nvidia-nccl-cu12==2.28.9 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine != "aarch64"
55
+ opentelemetry-api==1.39.1 ; python_version >= "3.12" and python_version < "4.0"
56
+ opentelemetry-proto==1.39.1 ; python_version >= "3.12" and python_version < "4.0"
57
+ opentelemetry-sdk==1.39.1 ; python_version >= "3.12" and python_version < "4.0"
58
+ opentelemetry-semantic-conventions==0.60b1 ; python_version >= "3.12" and python_version < "4.0"
59
+ packaging==25.0 ; python_version >= "3.12" and python_version < "4.0"
60
+ pandas==2.3.3 ; python_version >= "3.12" and python_version < "4.0"
61
+ pillow==12.0.0 ; python_version >= "3.12" and python_version < "4.0"
62
+ protobuf==6.33.2 ; python_version >= "3.12" and python_version < "4.0"
63
+ pyarrow==22.0.0 ; python_version >= "3.12" and python_version < "4.0"
64
+ pyasn1-modules==0.4.2 ; python_version >= "3.12" and python_version < "4.0"
65
+ pyasn1==0.6.1 ; python_version >= "3.12" and python_version < "4.0"
66
+ pycparser==2.23 ; python_version >= "3.12" and python_version < "4.0" and platform_python_implementation != "PyPy" and implementation_name != "PyPy"
67
+ pydantic-core==2.41.5 ; python_version >= "3.12" and python_version < "4.0"
68
+ pydantic==2.12.5 ; python_version >= "3.12" and python_version < "4.0"
69
+ pyparsing==3.3.1 ; python_version >= "3.12" and python_version < "4.0"
70
+ python-dateutil==2.9.0.post0 ; python_version >= "3.12" and python_version < "4.0"
71
+ python-dotenv==1.2.1 ; python_version >= "3.12" and python_version < "4.0"
72
+ python-json-logger==4.0.0 ; python_version >= "3.12" and python_version < "4.0"
73
+ pytz==2025.2 ; python_version >= "3.12" and python_version < "4.0"
74
+ pywin32==311 ; python_version >= "3.12" and python_version < "4.0" and sys_platform == "win32"
75
+ pyyaml==6.0.3 ; python_version >= "3.12" and python_version < "4.0"
76
+ requests==2.32.5 ; python_version >= "3.12" and python_version < "4.0"
77
+ rsa==4.9.1 ; python_version >= "3.12" and python_version < "4.0"
78
+ scikit-learn==1.6.1 ; python_version >= "3.12" and python_version < "4.0"
79
+ scipy==1.16.3 ; python_version >= "3.12" and python_version < "4.0"
80
+ shellingham==1.5.4 ; python_version >= "3.12" and python_version < "4.0"
81
+ six==1.17.0 ; python_version >= "3.12" and python_version < "4.0"
82
+ sklearn-compat==0.1.5 ; python_version >= "3.12" and python_version < "4.0"
83
+ slowapi==0.1.9 ; python_version >= "3.12" and python_version < "4.0"
84
+ smmap==5.0.2 ; python_version >= "3.12" and python_version < "4.0"
85
+ sqlalchemy==2.0.45 ; python_version >= "3.12" and python_version < "4.0"
86
+ sqlparse==0.5.5 ; python_version >= "3.12" and python_version < "4.0"
87
+ starlette==0.46.2 ; python_version >= "3.12" and python_version < "4.0"
88
+ threadpoolctl==3.6.0 ; python_version >= "3.12" and python_version < "4.0"
89
+ tqdm==4.67.1 ; python_version >= "3.12" and python_version < "4.0"
90
+ typer-slim==0.21.0 ; python_version >= "3.12" and python_version < "4.0"
91
+ typing-extensions==4.15.0 ; python_version >= "3.12" and python_version < "4.0"
92
+ typing-inspection==0.4.2 ; python_version >= "3.12" and python_version < "4.0"
93
+ tzdata==2025.3 ; python_version >= "3.12" and python_version < "4.0"
94
+ urllib3==2.6.2 ; python_version >= "3.12" and python_version < "4.0"
95
+ uvicorn==0.32.1 ; python_version >= "3.12" and python_version < "4.0"
96
+ uvloop==0.22.1 ; python_version >= "3.12" and python_version < "4.0" and sys_platform != "win32" and sys_platform != "cygwin" and platform_python_implementation != "PyPy"
97
+ waitress==3.0.2 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Windows"
98
+ watchfiles==1.1.1 ; python_version >= "3.12" and python_version < "4.0"
99
+ websockets==15.0.1 ; python_version >= "3.12" and python_version < "4.0"
100
+ werkzeug==3.1.4 ; python_version >= "3.12" and python_version < "4.0"
101
+ wrapt==2.0.1 ; python_version >= "3.12" and python_version < "4.0"
102
+ xgboost==2.1.4 ; python_version >= "3.12" and python_version < "4.0"
103
+ zipp==3.23.0 ; python_version >= "3.12" and python_version < "4.0"
src/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ """Module src pour l'API FastAPI."""
src/auth.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module d'authentification pour l'API.
4
+
5
+ Fournit un système de vérification de clé API via header HTTP.
6
+ """
7
+ from fastapi import Header, HTTPException, status
8
+ from fastapi.security import APIKeyHeader
9
+
10
+ from src.config import get_settings
11
+
12
+ # Schéma pour la documentation Swagger
13
+ api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
14
+
15
+
16
+ async def verify_api_key(x_api_key: str = Header(None)) -> str:
17
+ """
18
+ Vérifie que la clé API fournie est valide.
19
+
20
+ Cette fonction est utilisée comme dépendance FastAPI (Depends).
21
+ Elle vérifie le header HTTP "X-API-Key" et compare avec la clé configurée.
22
+
23
+ Args:
24
+ x_api_key: Clé API fournie dans le header HTTP.
25
+
26
+ Returns:
27
+ str: La clé API validée.
28
+
29
+ Raises:
30
+ HTTPException: 401 si la clé est manquante ou invalide.
31
+
32
+ Comment ça marche :
33
+ 1. FastAPI extrait automatiquement le header "X-API-Key"
34
+ 2. La fonction compare avec la clé configurée dans .env
35
+ 3. Si valide → continue, sinon → erreur 401
36
+
37
+ Exemple d'utilisation :
38
+ ```python
39
+ @app.post("/predict", dependencies=[Depends(verify_api_key)])
40
+ async def predict(...):
41
+ # Cette route est protégée !
42
+ ```
43
+
44
+ Exemple de requête curl :
45
+ ```bash
46
+ curl -X POST http://localhost:8000/predict \\
47
+ -H "X-API-Key: your-secret-key" \\
48
+ -H "Content-Type: application/json" \\
49
+ -d '{...}'
50
+ ```
51
+ """
52
+ settings = get_settings()
53
+
54
+ # En mode DEBUG, on peut désactiver l'auth
55
+ if settings.DEBUG:
56
+ return "debug-mode-no-auth-required"
57
+
58
+ # Vérifier que la clé est fournie
59
+ if not x_api_key:
60
+ raise HTTPException(
61
+ status_code=status.HTTP_401_UNAUTHORIZED,
62
+ detail={
63
+ "error": "API Key missing",
64
+ "message": "Le header 'X-API-Key' est requis pour accéder à cette ressource",
65
+ "solution": "Ajoutez le header: -H 'X-API-Key: votre-cle-api'",
66
+ },
67
+ headers={"WWW-Authenticate": "ApiKey"},
68
+ )
69
+
70
+ # Vérifier que la clé est correcte
71
+ if x_api_key != settings.API_KEY:
72
+ raise HTTPException(
73
+ status_code=status.HTTP_401_UNAUTHORIZED,
74
+ detail={
75
+ "error": "Invalid API Key",
76
+ "message": "La clé API fournie est invalide",
77
+ "solution": "Vérifiez votre clé API ou contactez l'administrateur",
78
+ },
79
+ headers={"WWW-Authenticate": "ApiKey"},
80
+ )
81
+
82
+ return x_api_key
83
+
84
+
85
+ def get_api_key_dependency():
86
+ """
87
+ Retourne la dépendance d'authentification si nécessaire.
88
+
89
+ Permet de conditionner l'authentification selon la config.
90
+
91
+ Returns:
92
+ Depends(verify_api_key) si auth requise, None sinon.
93
+ """
94
+ settings = get_settings()
95
+ if settings.is_api_key_required:
96
+ from fastapi import Depends
97
+
98
+ return Depends(verify_api_key)
99
+ return None
src/config.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module de configuration de l'application.
4
+
5
+ Charge les variables d'environnement depuis .env et fournit
6
+ une interface pour accéder à la configuration de manière sécurisée.
7
+ """
8
+ import os
9
+ from functools import lru_cache
10
+
11
+ from dotenv import load_dotenv
12
+
13
+ # Charger .env au démarrage du module
14
+ load_dotenv()
15
+
16
+
17
+ class Settings:
18
+ """
19
+ Configuration de l'application.
20
+
21
+ Toutes les valeurs sensibles (API keys, etc.) sont chargées depuis
22
+ les variables d'environnement ou le fichier .env.
23
+ """
24
+
25
+ # ===== SÉCURITÉ =====
26
+ API_KEY: str = os.getenv("API_KEY", "dev-key-change-me-in-production")
27
+
28
+ # ===== API =====
29
+ API_VERSION: str = os.getenv("API_VERSION", "1.0.0")
30
+ API_HOST: str = os.getenv("API_HOST", "0.0.0.0")
31
+ API_PORT: int = int(os.getenv("API_PORT", "8000"))
32
+
33
+ # ===== MODÈLE =====
34
+ HF_MODEL_REPO: str = os.getenv(
35
+ "HF_MODEL_REPO", "ASI-Engineer/employee-turnover-model"
36
+ )
37
+ MODEL_FILENAME: str = os.getenv("MODEL_FILENAME", "model/model.pkl")
38
+
39
+ # ===== ENVIRONNEMENT =====
40
+ DEBUG: bool = os.getenv("DEBUG", "False").lower() == "true"
41
+ LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
42
+
43
+ @property
44
+ def is_api_key_required(self) -> bool:
45
+ """
46
+ Vérifie si l'API key est requise.
47
+
48
+ Returns:
49
+ False en mode DEBUG, True en production.
50
+ """
51
+ return not self.DEBUG
52
+
53
+
54
+ @lru_cache()
55
+ def get_settings() -> Settings:
56
+ """
57
+ Retourne l'instance singleton des settings.
58
+
59
+ Le décorateur @lru_cache() assure qu'on ne crée qu'une seule instance.
60
+
61
+ Returns:
62
+ Settings: Configuration de l'application.
63
+ """
64
+ return Settings()
src/logger.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module de logging structuré pour l'API Employee Turnover.
4
+
5
+ Fournit un système de logging centralisé avec :
6
+ - Logs structurés en JSON
7
+ - Rotation automatique des fichiers
8
+ - Niveaux de log configurables
9
+ - Intégration FastAPI
10
+ """
11
+ import logging
12
+ import sys
13
+ from pathlib import Path
14
+ from typing import Any, Dict
15
+
16
+ from pythonjsonlogger import jsonlogger
17
+
18
+ from src.config import get_settings
19
+
20
+ settings = get_settings()
21
+
22
+ # Créer le dossier logs s'il n'existe pas
23
+ LOG_DIR = Path("logs")
24
+ LOG_DIR.mkdir(exist_ok=True)
25
+
26
+ # Fichiers de logs
27
+ LOG_FILE = LOG_DIR / "api.log"
28
+ ERROR_LOG_FILE = LOG_DIR / "error.log"
29
+
30
+
31
+ class CustomJsonFormatter(jsonlogger.JsonFormatter):
32
+ """
33
+ Formatter JSON personnalisé avec champs supplémentaires.
34
+ """
35
+
36
+ def add_fields(
37
+ self,
38
+ log_record: Dict[str, Any],
39
+ record: logging.LogRecord,
40
+ message_dict: Dict[str, Any],
41
+ ) -> None:
42
+ """
43
+ Ajoute des champs personnalisés aux logs JSON.
44
+ """
45
+ super().add_fields(log_record, record, message_dict)
46
+
47
+ # Ajouter des métadonnées
48
+ log_record["level"] = record.levelname
49
+ log_record["logger"] = record.name
50
+ log_record["module"] = record.module
51
+ log_record["function"] = record.funcName
52
+ log_record["line"] = record.lineno
53
+
54
+ # Timestamp ISO 8601
55
+ if not log_record.get("timestamp"):
56
+ log_record["timestamp"] = self.formatTime(record, self.datefmt)
57
+
58
+
59
+ def setup_logger(name: str = "employee_turnover_api") -> logging.Logger:
60
+ """
61
+ Configure et retourne un logger structuré.
62
+
63
+ Args:
64
+ name: Nom du logger.
65
+
66
+ Returns:
67
+ Logger configuré avec handlers console et fichiers.
68
+
69
+ Examples:
70
+ >>> logger = setup_logger()
71
+ >>> logger.info("API démarrée", extra={"version": "2.0.0"})
72
+ """
73
+ logger = logging.getLogger(name)
74
+
75
+ # Éviter duplication si déjà configuré
76
+ if logger.handlers:
77
+ return logger
78
+
79
+ # Niveau de log depuis configuration
80
+ log_level = getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO)
81
+ logger.setLevel(log_level)
82
+
83
+ # === HANDLER CONSOLE (stdout) ===
84
+ console_handler = logging.StreamHandler(sys.stdout)
85
+ console_handler.setLevel(log_level)
86
+
87
+ # Format simple pour la console en dev, JSON en prod
88
+ if settings.DEBUG:
89
+ console_format = logging.Formatter(
90
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
91
+ datefmt="%Y-%m-%d %H:%M:%S",
92
+ )
93
+ else:
94
+ console_format = CustomJsonFormatter(
95
+ "%(timestamp)s %(level)s %(name)s %(message)s"
96
+ )
97
+
98
+ console_handler.setFormatter(console_format)
99
+ logger.addHandler(console_handler)
100
+
101
+ # === HANDLER FICHIER (tous les logs) ===
102
+ file_handler = logging.FileHandler(LOG_FILE, encoding="utf-8")
103
+ file_handler.setLevel(log_level)
104
+ file_handler.setFormatter(
105
+ CustomJsonFormatter("%(timestamp)s %(level)s %(name)s %(message)s")
106
+ )
107
+ logger.addHandler(file_handler)
108
+
109
+ # === HANDLER ERREURS UNIQUEMENT ===
110
+ error_handler = logging.FileHandler(ERROR_LOG_FILE, encoding="utf-8")
111
+ error_handler.setLevel(logging.ERROR)
112
+ error_handler.setFormatter(
113
+ CustomJsonFormatter("%(timestamp)s %(level)s %(name)s %(message)s")
114
+ )
115
+ logger.addHandler(error_handler)
116
+
117
+ # Éviter propagation au root logger
118
+ logger.propagate = False
119
+
120
+ return logger
121
+
122
+
123
+ def log_request(
124
+ method: str,
125
+ path: str,
126
+ status_code: int,
127
+ duration_ms: float,
128
+ **kwargs: Any,
129
+ ) -> None:
130
+ """
131
+ Log une requête HTTP avec métadonnées.
132
+
133
+ Args:
134
+ method: Méthode HTTP (GET, POST...).
135
+ path: Chemin de l'endpoint.
136
+ status_code: Code de statut HTTP.
137
+ duration_ms: Durée de la requête en millisecondes.
138
+ **kwargs: Métadonnées additionnelles.
139
+
140
+ Examples:
141
+ >>> log_request("POST", "/predict", 200, 45.3, user_id="123")
142
+ """
143
+ logger = logging.getLogger("employee_turnover_api")
144
+
145
+ log_data = {
146
+ "method": method,
147
+ "path": path,
148
+ "status_code": status_code,
149
+ "duration_ms": round(duration_ms, 2),
150
+ **kwargs,
151
+ }
152
+
153
+ # Niveau selon status code
154
+ if status_code >= 500:
155
+ logger.error(f"Request {method} {path}", extra=log_data)
156
+ elif status_code >= 400:
157
+ logger.warning(f"Request {method} {path}", extra=log_data)
158
+ else:
159
+ logger.info(f"Request {method} {path}", extra=log_data)
160
+
161
+
162
+ def log_prediction(
163
+ employee_id: str | None,
164
+ prediction: int,
165
+ probability: float,
166
+ risk_level: str,
167
+ duration_ms: float,
168
+ ) -> None:
169
+ """
170
+ Log une prédiction effectuée.
171
+
172
+ Args:
173
+ employee_id: ID de l'employé (optionnel).
174
+ prediction: Prédiction (0 ou 1).
175
+ probability: Probabilité de turnover.
176
+ risk_level: Niveau de risque ("low", "medium", "high").
177
+ duration_ms: Durée du preprocessing + pr��diction.
178
+
179
+ Examples:
180
+ >>> log_prediction("EMP123", 1, 0.87, "high", 23.4)
181
+ """
182
+ logger = logging.getLogger("employee_turnover_api")
183
+
184
+ logger.info(
185
+ "Prediction made",
186
+ extra={
187
+ "employee_id": employee_id,
188
+ "prediction": prediction,
189
+ "probability": round(probability, 4),
190
+ "risk_level": risk_level,
191
+ "duration_ms": round(duration_ms, 2),
192
+ },
193
+ )
194
+
195
+
196
+ def log_model_load(model_type: str, duration_ms: float, success: bool) -> None:
197
+ """
198
+ Log le chargement du modèle.
199
+
200
+ Args:
201
+ model_type: Type de modèle chargé.
202
+ duration_ms: Durée du chargement.
203
+ success: Si le chargement a réussi.
204
+
205
+ Examples:
206
+ >>> log_model_load("XGBoost Pipeline", 1234.5, True)
207
+ """
208
+ logger = logging.getLogger("employee_turnover_api")
209
+
210
+ log_data = {
211
+ "model_type": model_type,
212
+ "duration_ms": round(duration_ms, 2),
213
+ "success": success,
214
+ }
215
+
216
+ if success:
217
+ logger.info("Model loaded successfully", extra=log_data)
218
+ else:
219
+ logger.error("Model loading failed", extra=log_data)
220
+
221
+
222
+ # Créer le logger global
223
+ logger = setup_logger()
src/models.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module de chargement et gestion du modèle MLflow.
4
+
5
+ Ce module encapsule la logique de chargement du modèle depuis Hugging Face Hub
6
+ via MLflow, avec gestion des erreurs et versioning.
7
+ """
8
+ from typing import Any, Optional
9
+
10
+ from fastapi import HTTPException
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ # Configuration
14
+ HF_MODEL_REPO = "ASI-Engineer/employee-turnover-model"
15
+ MODEL_FILENAME = "model/model.pkl"
16
+
17
+ # Cache global du modèle
18
+ _model_cache: Optional[Any] = None
19
+
20
+
21
+ def load_model(force_reload: bool = False) -> Any:
22
+ """
23
+ Charge le modèle depuis Hugging Face Hub via MLflow.
24
+
25
+ Cette fonction implémente un système de cache pour éviter de recharger
26
+ le modèle à chaque appel. Le modèle est chargé une seule fois au démarrage
27
+ de l'application et mis en cache.
28
+
29
+ Args:
30
+ force_reload: Si True, force le rechargement du modèle même s'il est en cache.
31
+
32
+ Returns:
33
+ Le modèle MLflow chargé et prêt pour l'inférence.
34
+
35
+ Raises:
36
+ HTTPException: 500 si le modèle ne peut pas être chargé.
37
+
38
+ Examples:
39
+ >>> model = load_model()
40
+ >>> # Utiliser le modèle pour prédiction
41
+ >>> predictions = model.predict(X)
42
+ """
43
+ global _model_cache
44
+
45
+ # Retourner le modèle en cache si disponible
46
+ if _model_cache is not None and not force_reload:
47
+ return _model_cache
48
+
49
+ try:
50
+ import joblib
51
+
52
+ print(f"🔄 Chargement du modèle depuis HF Hub: {HF_MODEL_REPO}")
53
+
54
+ # Télécharger le modèle depuis Hugging Face Hub
55
+ model_path = hf_hub_download(
56
+ repo_id=HF_MODEL_REPO, filename=MODEL_FILENAME, repo_type="model"
57
+ )
58
+
59
+ print(f"📦 Modèle téléchargé: {model_path}")
60
+
61
+ # Charger le modèle avec joblib
62
+ model = joblib.load(model_path)
63
+
64
+ # Mettre en cache
65
+ _model_cache = model
66
+
67
+ print(f"✅ Modèle chargé avec succès: {type(model).__name__}")
68
+ return model
69
+
70
+ except Exception as e:
71
+ error_msg = f"❌ Erreur lors du chargement du modèle: {str(e)}"
72
+ print(error_msg)
73
+ raise HTTPException(
74
+ status_code=500,
75
+ detail={
76
+ "error": "Model loading failed",
77
+ "message": str(e),
78
+ "model_repo": HF_MODEL_REPO,
79
+ "solution": "Vérifiez que le modèle est disponible sur HF Hub et correctement entraîné",
80
+ },
81
+ )
82
+
83
+
84
+ def get_model_info() -> dict:
85
+ """
86
+ Retourne les informations sur le modèle chargé.
87
+
88
+ Returns:
89
+ Dict contenant les métadonnées du modèle.
90
+
91
+ Raises:
92
+ HTTPException: 500 si le modèle n'est pas chargé.
93
+ """
94
+ try:
95
+ model = load_model()
96
+
97
+ return {
98
+ "status": "✅ Modèle chargé",
99
+ "model_type": type(model).__name__,
100
+ "hf_hub_repo": HF_MODEL_REPO,
101
+ "model_file": MODEL_FILENAME,
102
+ "cached": _model_cache is not None,
103
+ }
104
+
105
+ except Exception as e:
106
+ raise HTTPException(
107
+ status_code=500,
108
+ detail={"error": "Model info unavailable", "message": str(e)},
109
+ )
110
+
111
+
112
+ def load_preprocessing_artifacts(run_id: str) -> dict:
113
+ """
114
+ Charge les artifacts de preprocessing (scaler, encoders) depuis MLflow.
115
+
116
+ Args:
117
+ run_id: ID du run MLflow contenant les artifacts.
118
+
119
+ Returns:
120
+ Dict contenant les artifacts de preprocessing.
121
+
122
+ Raises:
123
+ HTTPException: 500 si les artifacts ne peuvent pas être chargés.
124
+
125
+ Note:
126
+ Cette fonction sera implémentée quand les preprocessing artifacts
127
+ seront disponibles dans le modèle HF Hub.
128
+ """
129
+ raise NotImplementedError(
130
+ "Le chargement des preprocessing artifacts sera implémenté "
131
+ "lors de l'intégration complète avec MLflow"
132
+ )
133
+
134
+
135
+ if __name__ == "__main__":
136
+ # Test de chargement du modèle
137
+ print("=" * 80)
138
+ print("TEST DE CHARGEMENT DU MODÈLE")
139
+ print("=" * 80)
140
+
141
+ try:
142
+ model = load_model()
143
+ print("\n✅ Test réussi!")
144
+ print(f"Type de modèle: {type(model).__name__}")
145
+
146
+ # Afficher les infos
147
+ info = get_model_info()
148
+ print("\nInformations du modèle:")
149
+ for key, value in info.items():
150
+ print(f" {key}: {value}")
151
+
152
+ except Exception as e:
153
+ print(f"\n❌ Test échoué: {e}")
src/preprocessing.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module de preprocessing pour transformer les données d'entrée avant prédiction.
4
+
5
+ Ce module applique les mêmes transformations que le pipeline d'entraînement :
6
+ - Feature engineering (ratios, moyennes)
7
+ - Encoding (OneHot, Ordinal)
8
+ - Scaling (StandardScaler)
9
+ """
10
+ import numpy as np
11
+ import pandas as pd
12
+ from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
13
+
14
+ from src.schemas import EmployeeInput
15
+
16
+
17
+ def create_input_dataframe(employee: EmployeeInput) -> pd.DataFrame:
18
+ """
19
+ Convertit un objet EmployeeInput Pydantic en DataFrame pandas.
20
+
21
+ Args:
22
+ employee: Données validées d'un employé.
23
+
24
+ Returns:
25
+ DataFrame avec une seule ligne contenant toutes les features.
26
+ """
27
+ data = {
28
+ # SONDAGE
29
+ "nombre_participation_pee": [employee.nombre_participation_pee],
30
+ "nb_formations_suivies": [employee.nb_formations_suivies],
31
+ "nombre_employee_sous_responsabilite": [
32
+ employee.nombre_employee_sous_responsabilite
33
+ ],
34
+ "distance_domicile_travail": [employee.distance_domicile_travail],
35
+ "niveau_education": [employee.niveau_education],
36
+ "domaine_etude": [employee.domaine_etude],
37
+ "ayant_enfants": [employee.ayant_enfants],
38
+ "frequence_deplacement": [employee.frequence_deplacement],
39
+ "annees_depuis_la_derniere_promotion": [
40
+ employee.annees_depuis_la_derniere_promotion
41
+ ],
42
+ "annes_sous_responsable_actuel": [employee.annes_sous_responsable_actuel],
43
+ # EVALUATION
44
+ "satisfaction_employee_environnement": [
45
+ employee.satisfaction_employee_environnement
46
+ ],
47
+ "note_evaluation_precedente": [employee.note_evaluation_precedente],
48
+ "niveau_hierarchique_poste": [employee.niveau_hierarchique_poste],
49
+ "satisfaction_employee_nature_travail": [
50
+ employee.satisfaction_employee_nature_travail
51
+ ],
52
+ "satisfaction_employee_equipe": [employee.satisfaction_employee_equipe],
53
+ "satisfaction_employee_equilibre_pro_perso": [
54
+ employee.satisfaction_employee_equilibre_pro_perso
55
+ ],
56
+ "note_evaluation_actuelle": [employee.note_evaluation_actuelle],
57
+ "heure_supplementaires": [employee.heure_supplementaires],
58
+ "augementation_salaire_precedente": [employee.augementation_salaire_precedente],
59
+ # SIRH
60
+ "age": [employee.age],
61
+ "genre": [employee.genre],
62
+ "revenu_mensuel": [employee.revenu_mensuel],
63
+ "statut_marital": [employee.statut_marital],
64
+ "departement": [employee.departement],
65
+ "poste": [employee.poste],
66
+ "nombre_experiences_precedentes": [employee.nombre_experiences_precedentes],
67
+ "nombre_heures_travailless": [employee.nombre_heures_travailless],
68
+ "annee_experience_totale": [employee.annee_experience_totale],
69
+ "annees_dans_l_entreprise": [employee.annees_dans_l_entreprise],
70
+ "annees_dans_le_poste_actuel": [employee.annees_dans_le_poste_actuel],
71
+ }
72
+
73
+ return pd.DataFrame(data)
74
+
75
+
76
+ def engineer_features(df: pd.DataFrame) -> pd.DataFrame:
77
+ """
78
+ Applique le feature engineering (mêmes transformations que l'entraînement).
79
+
80
+ Args:
81
+ df: DataFrame avec les colonnes brutes.
82
+
83
+ Returns:
84
+ DataFrame avec les features engineered ajoutées.
85
+ """
86
+ df = df.copy()
87
+
88
+ # Ratios (+ 1 pour éviter division par zéro)
89
+ df["revenu_par_anciennete"] = df["revenu_mensuel"] / (
90
+ df["annees_dans_l_entreprise"] + 1
91
+ )
92
+ df["experience_par_anciennete"] = df["annee_experience_totale"] / (
93
+ df["annees_dans_l_entreprise"] + 1
94
+ )
95
+ df["promo_par_anciennete"] = df["annees_depuis_la_derniere_promotion"] / (
96
+ df["annees_dans_l_entreprise"] + 1
97
+ )
98
+
99
+ # Moyenne de satisfaction
100
+ df["satisfaction_moyenne"] = df[
101
+ [
102
+ "satisfaction_employee_environnement",
103
+ "satisfaction_employee_nature_travail",
104
+ "satisfaction_employee_equipe",
105
+ "satisfaction_employee_equilibre_pro_perso",
106
+ ]
107
+ ].mean(axis=1)
108
+
109
+ return df
110
+
111
+
112
+ def encode_and_scale(df: pd.DataFrame) -> pd.DataFrame:
113
+ """
114
+ Encode les variables catégorielles et scale les numériques.
115
+ IMPORTANT: Doit correspondre EXACTEMENT au pipeline d'entraînement.
116
+
117
+ Args:
118
+ df: DataFrame avec features engineered.
119
+
120
+ Returns:
121
+ DataFrame transformé avec 50 colonnes (comme training).
122
+ """
123
+ df = df.copy()
124
+
125
+ # === ENCODING ===
126
+
127
+ # NOTE: ayant_enfants et heure_supplementaires sont SUPPRIMÉS
128
+ # (ne font pas partie des features du modèle d'entraînement)
129
+ cols_to_drop = ["ayant_enfants", "heure_supplementaires"]
130
+ df = df.drop(columns=[col for col in cols_to_drop if col in df.columns])
131
+
132
+ # OneHot pour variables catégorielles non-ordonnées
133
+ # IMPORTANT: Utiliser les mêmes catégories que lors de l'entraînement
134
+ cat_non_ord = ["genre", "statut_marital", "departement", "poste", "domaine_etude"]
135
+
136
+ # Définir toutes les catégories possibles (depuis training data)
137
+ categories_dict = {
138
+ "genre": ["F", "M"],
139
+ "statut_marital": ["Célibataire", "Divorcé(e)", "Marié(e)"],
140
+ "departement": ["Commercial", "Consulting", "Ressources Humaines"],
141
+ "poste": [
142
+ "Assistant de Direction",
143
+ "Cadre Commercial",
144
+ "Consultant",
145
+ "Directeur Technique",
146
+ "Manager",
147
+ "Représentant Commercial",
148
+ "Ressources Humaines",
149
+ "Senior Manager",
150
+ "Tech Lead",
151
+ ],
152
+ "domaine_etude": [
153
+ "Autre",
154
+ "Entrepreunariat",
155
+ "Infra & Cloud",
156
+ "Marketing",
157
+ "Ressources Humaines",
158
+ "Transformation Digitale",
159
+ ],
160
+ }
161
+
162
+ onehot = OneHotEncoder(
163
+ sparse_output=False,
164
+ handle_unknown="ignore",
165
+ categories=[categories_dict[col] for col in cat_non_ord],
166
+ )
167
+
168
+ encoded_non_ord = pd.DataFrame(
169
+ onehot.fit_transform(df[cat_non_ord]),
170
+ columns=onehot.get_feature_names_out(cat_non_ord),
171
+ index=df.index,
172
+ )
173
+
174
+ # Ordinal pour fréquence déplacement
175
+ ordinal = OrdinalEncoder(categories=[["Aucun", "Occasionnel", "Frequent"]])
176
+ df["frequence_deplacement"] = ordinal.fit_transform(
177
+ df[["frequence_deplacement"]]
178
+ ).flatten()
179
+
180
+ # Supprimer les colonnes catégorielles originales
181
+ df = df.drop(columns=cat_non_ord)
182
+
183
+ # Concaténer les encodages OneHot
184
+ df = pd.concat([df, encoded_non_ord], axis=1)
185
+
186
+ # === SCALING ===
187
+
188
+ # Colonnes numériques à scaler
189
+ quantitative_cols = df.select_dtypes(include=[np.number]).columns.tolist()
190
+
191
+ # Retirer les colonnes OneHot du scaling (elles sont déjà 0/1)
192
+ cols_to_scale = [
193
+ col
194
+ for col in quantitative_cols
195
+ if df[col].nunique() > 2 # Exclut colonnes binaires (0/1)
196
+ ]
197
+
198
+ # Appliquer le scaling uniquement s'il y a des colonnes
199
+ if cols_to_scale:
200
+ scaler = StandardScaler()
201
+ df[cols_to_scale] = scaler.fit_transform(df[cols_to_scale])
202
+
203
+ return df
204
+
205
+
206
+ def preprocess_for_prediction(employee: EmployeeInput) -> np.ndarray:
207
+ """
208
+ Pipeline complet de preprocessing pour une prédiction.
209
+
210
+ Args:
211
+ employee: Données validées d'un employé.
212
+
213
+ Returns:
214
+ Array numpy transformé prêt pour model.predict().
215
+
216
+ Examples:
217
+ >>> from src.schemas import EmployeeInput
218
+ >>> employee = EmployeeInput(...)
219
+ >>> X = preprocess_for_prediction(employee)
220
+ >>> prediction = model.predict(X)
221
+ """
222
+ # 1. Créer DataFrame
223
+ df = create_input_dataframe(employee)
224
+
225
+ # 2. Feature engineering
226
+ df = engineer_features(df)
227
+
228
+ # 3. Encoding et scaling
229
+ df = encode_and_scale(df)
230
+
231
+ # 4. Convertir en numpy array (le modèle attend un array)
232
+ return df.values
233
+
234
+
235
+ # TODO: Implémenter le chargement des artifacts sauvegardés
236
+ # def load_preprocessing_artifacts(run_id: str) -> dict:
237
+ # """
238
+ # Charge les encoders et scaler depuis MLflow.
239
+ #
240
+ # Returns:
241
+ # dict avec keys: 'onehot_encoder', 'ordinal_encoder', 'scaler'
242
+ # """
243
+ # pass
src/rate_limit.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module de rate limiting pour protéger l'API contre les abus.
4
+
5
+ Utilise SlowAPI pour limiter le nombre de requêtes par IP/utilisateur.
6
+ """
7
+ from slowapi import Limiter
8
+ from slowapi.util import get_remote_address
9
+
10
+ from src.config import get_settings
11
+
12
+ settings = get_settings()
13
+
14
+ # Créer le limiter avec stratégie par IP
15
+ limiter = Limiter(
16
+ key_func=get_remote_address,
17
+ default_limits=["100/minute"] if not settings.DEBUG else [],
18
+ storage_uri="memory://", # En production: utiliser Redis
19
+ strategy="fixed-window",
20
+ )
21
+
22
+
23
+ def get_rate_limit_key(request):
24
+ """
25
+ Fonction pour obtenir la clé de rate limiting.
26
+
27
+ En production, on pourrait utiliser l'API Key au lieu de l'IP.
28
+
29
+ Args:
30
+ request: Requête FastAPI.
31
+
32
+ Returns:
33
+ Clé unique pour identifier l'utilisateur.
34
+ """
35
+ # Priorité: API Key > IP
36
+ api_key = request.headers.get("X-API-Key")
37
+ if api_key:
38
+ return f"api_key:{api_key}"
39
+
40
+ return get_remote_address(request)
src/schemas.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Schémas Pydantic pour validation des données d'entrée de l'API.
4
+
5
+ Ces schémas correspondent aux colonnes brutes du dataset avant preprocessing,
6
+ permettant une validation stricte des inputs avec messages d'erreur clairs.
7
+ """
8
+ from enum import Enum
9
+ from typing import Literal
10
+
11
+ from pydantic import BaseModel, Field, field_validator
12
+
13
+
14
+ # Enums pour les valeurs catégorielles
15
+ class GenreEnum(str, Enum):
16
+ """Genre de l'employé."""
17
+
18
+ M = "M"
19
+ F = "F"
20
+
21
+
22
+ class StatutMaritalEnum(str, Enum):
23
+ """Statut marital de l'employé."""
24
+
25
+ CELIBATAIRE = "Célibataire"
26
+ MARIE = "Marié(e)"
27
+ DIVORCE = "Divorcé(e)"
28
+
29
+
30
+ class DepartementEnum(str, Enum):
31
+ """Département de l'employé."""
32
+
33
+ COMMERCIAL = "Commercial"
34
+ CONSULTING = "Consulting"
35
+
36
+
37
+ class DomaineEtudeEnum(str, Enum):
38
+ """Domaine d'études de l'employé."""
39
+
40
+ INFRA_CLOUD = "Infra & Cloud"
41
+ TRANSFORMATION_DIGITALE = "Transformation Digitale"
42
+ AUTRE = "Autre"
43
+
44
+
45
+ class FrequenceDeplacementEnum(str, Enum):
46
+ """Fréquence des déplacements professionnels."""
47
+
48
+ AUCUN = "Aucun"
49
+ OCCASIONNEL = "Occasionnel"
50
+ FREQUENT = "Frequent"
51
+
52
+
53
+ class EmployeeInput(BaseModel):
54
+ """
55
+ Schéma de validation pour les données d'entrée d'un employé.
56
+
57
+ Tous les champs correspondent aux colonnes brutes des 3 fichiers CSV
58
+ (sondage, eval, sirh) avant preprocessing.
59
+ """
60
+
61
+ # === Données SONDAGE ===
62
+ nombre_participation_pee: int = Field(
63
+ ..., ge=0, description="Nombre de participations au PEE"
64
+ )
65
+ nb_formations_suivies: int = Field(
66
+ ..., ge=0, le=10, description="Nombre de formations suivies"
67
+ )
68
+ nombre_employee_sous_responsabilite: int = Field(
69
+ ..., ge=0, description="Nombre d'employés sous responsabilité"
70
+ )
71
+ distance_domicile_travail: int = Field(
72
+ ..., ge=0, le=50, description="Distance domicile-travail en km"
73
+ )
74
+ niveau_education: int = Field(
75
+ ..., ge=1, le=5, description="Niveau d'éducation (1-5)"
76
+ )
77
+ domaine_etude: DomaineEtudeEnum = Field(..., description="Domaine d'études")
78
+ ayant_enfants: Literal["Y", "N"] = Field(..., description="A des enfants (Y/N)")
79
+ frequence_deplacement: FrequenceDeplacementEnum = Field(
80
+ ..., description="Fréquence des déplacements"
81
+ )
82
+ annees_depuis_la_derniere_promotion: int = Field(
83
+ ..., ge=0, description="Années depuis la dernière promotion"
84
+ )
85
+ annes_sous_responsable_actuel: int = Field(
86
+ ..., ge=0, description="Années sous le responsable actuel"
87
+ )
88
+
89
+ # === Données EVALUATION ===
90
+ satisfaction_employee_environnement: int = Field(
91
+ ..., ge=1, le=4, description="Satisfaction environnement (1-4)"
92
+ )
93
+ note_evaluation_precedente: int = Field(
94
+ ..., ge=1, le=5, description="Note évaluation précédente (1-5)"
95
+ )
96
+ niveau_hierarchique_poste: int = Field(
97
+ ..., ge=1, le=5, description="Niveau hiérarchique (1-5)"
98
+ )
99
+ satisfaction_employee_nature_travail: int = Field(
100
+ ..., ge=1, le=4, description="Satisfaction nature du travail (1-4)"
101
+ )
102
+ satisfaction_employee_equipe: int = Field(
103
+ ..., ge=1, le=4, description="Satisfaction équipe (1-4)"
104
+ )
105
+ satisfaction_employee_equilibre_pro_perso: int = Field(
106
+ ..., ge=1, le=4, description="Satisfaction équilibre pro/perso (1-4)"
107
+ )
108
+ note_evaluation_actuelle: int = Field(
109
+ ..., ge=1, le=5, description="Note évaluation actuelle (1-5)"
110
+ )
111
+ heure_supplementaires: Literal["Oui", "Non"] = Field(
112
+ ..., description="Fait des heures supplémentaires"
113
+ )
114
+ augementation_salaire_precedente: float = Field(
115
+ ..., ge=0, le=100, description="Augmentation salaire précédente (%)"
116
+ )
117
+
118
+ # === Données SIRH ===
119
+ age: int = Field(..., ge=18, le=70, description="Âge de l'employé")
120
+ genre: GenreEnum = Field(..., description="Genre")
121
+ revenu_mensuel: float = Field(..., ge=1000, description="Revenu mensuel (€)")
122
+ statut_marital: StatutMaritalEnum = Field(..., description="Statut marital")
123
+ departement: DepartementEnum = Field(..., description="Département")
124
+ poste: str = Field(..., min_length=3, description="Intitulé du poste")
125
+ nombre_experiences_precedentes: int = Field(
126
+ ..., ge=0, description="Nombre d'expériences précédentes"
127
+ )
128
+ nombre_heures_travailless: int = Field(
129
+ ..., ge=35, le=80, description="Nombre d'heures travaillées par semaine"
130
+ )
131
+ annee_experience_totale: int = Field(
132
+ ..., ge=0, description="Années d'expérience totale"
133
+ )
134
+ annees_dans_l_entreprise: int = Field(
135
+ ..., ge=0, description="Années dans l'entreprise"
136
+ )
137
+ annees_dans_le_poste_actuel: int = Field(
138
+ ..., ge=0, description="Années dans le poste actuel"
139
+ )
140
+
141
+ @field_validator("augementation_salaire_precedente")
142
+ @classmethod
143
+ def validate_augmentation(cls, v: float) -> float:
144
+ """Nettoie le format de l'augmentation (enlève % si présent)."""
145
+ if isinstance(v, str):
146
+ v = float(v.replace(" %", "").replace("%", ""))
147
+ return v
148
+
149
+ class Config:
150
+ """Configuration Pydantic."""
151
+
152
+ json_schema_extra = {
153
+ "example": {
154
+ # Exemple basé sur la première ligne des CSV
155
+ "nombre_participation_pee": 0,
156
+ "nb_formations_suivies": 0,
157
+ "nombre_employee_sous_responsabilite": 1,
158
+ "distance_domicile_travail": 1,
159
+ "niveau_education": 2,
160
+ "domaine_etude": "Infra & Cloud",
161
+ "ayant_enfants": "Y",
162
+ "frequence_deplacement": "Occasionnel",
163
+ "annees_depuis_la_derniere_promotion": 0,
164
+ "annes_sous_responsable_actuel": 5,
165
+ "satisfaction_employee_environnement": 2,
166
+ "note_evaluation_precedente": 3,
167
+ "niveau_hierarchique_poste": 2,
168
+ "satisfaction_employee_nature_travail": 4,
169
+ "satisfaction_employee_equipe": 1,
170
+ "satisfaction_employee_equilibre_pro_perso": 1,
171
+ "note_evaluation_actuelle": 3,
172
+ "heure_supplementaires": "Oui",
173
+ "augementation_salaire_precedente": 11.0,
174
+ "age": 41,
175
+ "genre": "F",
176
+ "revenu_mensuel": 5993.0,
177
+ "statut_marital": "Célibataire",
178
+ "departement": "Commercial",
179
+ "poste": "Cadre Commercial",
180
+ "nombre_experiences_precedentes": 8,
181
+ "nombre_heures_travailless": 80,
182
+ "annee_experience_totale": 8,
183
+ "annees_dans_l_entreprise": 6,
184
+ "annees_dans_le_poste_actuel": 4,
185
+ }
186
+ }
187
+
188
+
189
+ class PredictionOutput(BaseModel):
190
+ """Schéma de sortie pour les prédictions."""
191
+
192
+ prediction: int = Field(..., description="Classe prédite (0=reste, 1=part)")
193
+ probability_0: float = Field(
194
+ ..., ge=0, le=1, description="Probabilité de rester (classe 0)"
195
+ )
196
+ probability_1: float = Field(
197
+ ..., ge=0, le=1, description="Probabilité de partir (classe 1)"
198
+ )
199
+ risk_level: str = Field(..., description="Niveau de risque (Low/Medium/High)")
200
+
201
+ class Config:
202
+ """Configuration Pydantic."""
203
+
204
+ json_schema_extra = {
205
+ "example": {
206
+ "prediction": 1,
207
+ "probability_0": 0.35,
208
+ "probability_1": 0.65,
209
+ "risk_level": "High",
210
+ }
211
+ }
212
+
213
+
214
+ class HealthCheck(BaseModel):
215
+ """Schéma pour le endpoint health check."""
216
+
217
+ status: str = Field(..., description="Status de l'API")
218
+ model_loaded: bool = Field(..., description="Modèle chargé ou non")
219
+ model_type: str = Field(..., description="Type du modèle")
220
+ version: str = Field(..., description="Version de l'API")
221
+
222
+ class Config:
223
+ """Configuration Pydantic."""
224
+
225
+ json_schema_extra = {
226
+ "example": {
227
+ "status": "healthy",
228
+ "model_loaded": True,
229
+ "model_type": "Pipeline",
230
+ "version": "1.0.0",
231
+ }
232
+ }