Meet2304 committed
Commit f255e67 · verified · 1 Parent(s): 9ff89e9

Upload 2 files

Files changed (2):
  1. app.py +450 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,450 @@
+ """
+ Project Phoenix - Cervical Cancer Cell Classification API
+ Flask application for running inference on a ConvNeXt V2 model from Hugging Face
+ with explainability features (GRAD-CAM).
+ """
+
+ import os
+ import io
+ import base64
+ import numpy as np
+ import cv2
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple
+
+ # Flask
+ from flask import Flask, request, jsonify
+ from flask_cors import CORS
+ from werkzeug.utils import secure_filename
+
+ # Deep Learning
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from PIL import Image
+ import torchvision.transforms as T
+
+ # Transformers
+ from transformers import (
+     ConvNextV2ForImageClassification,
+     AutoImageProcessor
+ )
+
+ # GRAD-CAM
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
+ # ========== CONFIGURATION ==========
+
+ # Update this with your Hugging Face model ID
+ # Example: "Meet2304/convnextv2-cervical-cell-classification"
+ HF_MODEL_ID = os.getenv("HF_MODEL_ID", "Meet2304/convnextv2-cervical-cell-classification")
+
+ # Class names
+ CLASS_NAMES = [
+     'im_Dyskeratotic',
+     'im_Koilocytotic',
+     'im_Metaplastic',
+     'im_Parabasal',
+     'im_Superficial-Intermediate'
+ ]
+
+ # Display names (cleaner for UI)
+ DISPLAY_NAMES = [
+     'Dyskeratotic',
+     'Koilocytotic',
+     'Metaplastic',
+     'Parabasal',
+     'Superficial-Intermediate'
+ ]
+
+ # Image preprocessing
+ IMG_SIZE = 224
+ ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'bmp'}
+ MAX_FILE_SIZE = 10 * 1024 * 1024  # 10 MB
+
+ # Device
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # ========== FLASK APP SETUP ==========
+
+ app = Flask(__name__)
+ CORS(app)  # Enable CORS for the Next.js frontend
+
+ app.config['MAX_CONTENT_LENGTH'] = MAX_FILE_SIZE
+
+ # ========== MODEL LOADING ==========
+
+ print("Loading model from Hugging Face...")
+ print(f"Model ID: {HF_MODEL_ID}")
+ print(f"Device: {DEVICE}")
+
+ # Load image processor
+ processor = AutoImageProcessor.from_pretrained(HF_MODEL_ID)
+ print("✓ Processor loaded")
+
+ # Load model
+ model = ConvNextV2ForImageClassification.from_pretrained(HF_MODEL_ID)
+ model = model.to(DEVICE)
+ model.eval()
+ print("✓ Model loaded and set to evaluation mode")
+
+ print("Model configuration:")
+ print(f"  - Number of classes: {model.config.num_labels}")
+ print(f"  - Image size: {model.config.image_size}")
+ print(f"  - Total parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+ # ========== HELPER FUNCTIONS ==========
+
+ def allowed_file(filename: str) -> bool:
+     """Check if the file extension is allowed."""
+     return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+
+
+ def preprocess_image(image: Image.Image) -> Tuple[torch.Tensor, np.ndarray]:
+     """
+     Preprocess an image for model input.
+
+     Returns:
+         Tuple of (preprocessed_tensor, original_image_array)
+     """
+     # Store original for visualization
+     original_image = np.array(image.convert('RGB'))
+
+     # Preprocess with the model's processor (force RGB so RGBA/greyscale uploads work)
+     inputs = processor(images=image.convert('RGB'), return_tensors="pt")
+     pixel_values = inputs['pixel_values'].to(DEVICE)
+
+     return pixel_values, original_image
+
+
+ def predict_image(pixel_values: torch.Tensor, top_k: int = 5) -> Dict:
+     """
+     Make a prediction on a preprocessed image.
+
+     Args:
+         pixel_values: Preprocessed image tensor
+         top_k: Number of top predictions to return
+
+     Returns:
+         Dictionary with prediction results
+     """
+     model.eval()
+     with torch.no_grad():
+         outputs = model(pixel_values)
+         logits = outputs.logits
+
+     # Get probabilities
+     probabilities = F.softmax(logits, dim=-1)[0]
+
+     # Get top-k predictions
+     top_probs, top_indices = torch.topk(probabilities, k=min(top_k, len(CLASS_NAMES)))
+
+     # Get predicted class
+     predicted_class_idx = logits.argmax(-1).item()
+     predicted_class_name = DISPLAY_NAMES[predicted_class_idx]
+     predicted_confidence = probabilities[predicted_class_idx].item()
+
+     # Prepare results
+     results = {
+         'predicted_class': predicted_class_name,
+         'predicted_class_raw': CLASS_NAMES[predicted_class_idx],
+         'predicted_idx': predicted_class_idx,
+         'confidence': float(predicted_confidence),
+         'top_k_predictions': [
+             {
+                 'class': DISPLAY_NAMES[idx],
+                 'class_raw': CLASS_NAMES[idx],
+                 'probability': float(prob)
+             }
+             for idx, prob in zip(top_indices, top_probs)
+         ],
+         'all_probabilities': {
+             DISPLAY_NAMES[i]: float(prob)
+             for i, prob in enumerate(probabilities)
+         }
+     }
+
+     return results
+
+
+ class ConvNeXtGradCAMWrapper(nn.Module):
+     """Wrapper around ConvNextV2ForImageClassification to make it compatible with GRAD-CAM."""
+
+     def __init__(self, model):
+         super().__init__()
+         self.model = model
+
+     def forward(self, x):
+         outputs = self.model(pixel_values=x)
+         return outputs.logits
+
+
+ def get_target_layers(model):
+     """Get the target layers for GRAD-CAM from the ConvNeXt model."""
+     return [model.convnextv2.encoder.stages[-1].layers[-1]]
+
+
+ def apply_gradcam(
+     pixel_values: torch.Tensor,
+     original_image: np.ndarray,
+     target_class: Optional[int] = None
+ ) -> Dict:
+     """
+     Apply GRAD-CAM to visualize model attention.
+
+     Args:
+         pixel_values: Preprocessed image tensor
+         original_image: Original image as numpy array
+         target_class: Target class index (None for predicted class)
+
+     Returns:
+         Dictionary with GRAD-CAM visualization and metadata
+     """
+     # Wrap the model
+     wrapped_model = ConvNeXtGradCAMWrapper(model)
+
+     # Get target layers
+     target_layers = get_target_layers(model)
+
+     # Initialize GRAD-CAM
+     cam = GradCAM(model=wrapped_model, target_layers=target_layers)
+
+     # Get prediction
+     model.eval()
+     with torch.no_grad():
+         outputs = model(pixel_values)
+         logits = outputs.logits
+         predicted_class = logits.argmax(-1).item()
+         probabilities = F.softmax(logits, dim=-1)[0]
+
+     # Use predicted class if target not specified
+     if target_class is None:
+         target_class = predicted_class
+
+     # Create target for GRAD-CAM
+     targets = [ClassifierOutputTarget(target_class)]
+
+     # Generate GRAD-CAM
+     grayscale_cam = cam(input_tensor=pixel_values, targets=targets)
+     grayscale_cam = grayscale_cam[0, :]
+
+     # Resize original image to match CAM dimensions
+     cam_h, cam_w = grayscale_cam.shape
+     rgb_image_for_overlay = cv2.resize(original_image, (cam_w, cam_h)).astype(np.float32) / 255.0
+
+     # Create visualization
+     visualization = show_cam_on_image(
+         rgb_image_for_overlay,
+         grayscale_cam,
+         use_rgb=True,
+         colormap=cv2.COLORMAP_JET
+     )
+
+     return {
+         'grayscale_cam': grayscale_cam,
+         'visualization': visualization,
+         'predicted_class': predicted_class,
+         'target_class': target_class,
+         'confidence': float(probabilities[predicted_class].item())
+     }
+
+
+ def encode_image_to_base64(image_array: np.ndarray) -> str:
+     """Convert a numpy array to a base64-encoded PNG data URL."""
+     # Convert to PIL Image
+     if image_array.dtype != np.uint8:
+         image_array = (image_array * 255).astype(np.uint8)
+
+     img = Image.fromarray(image_array)
+
+     # Save to bytes buffer
+     buffer = io.BytesIO()
+     img.save(buffer, format='PNG')
+     buffer.seek(0)
+
+     # Encode to base64
+     img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
+     return f"data:image/png;base64,{img_base64}"
+
+
+ # ========== API ENDPOINTS ==========
+
+ @app.route('/health', methods=['GET'])
+ def health_check():
+     """Health check endpoint."""
+     return jsonify({
+         'status': 'healthy',
+         'model_loaded': model is not None,
+         'device': str(DEVICE),
+         'model_id': HF_MODEL_ID
+     })
+
+
+ @app.route('/predict', methods=['POST'])
+ def predict():
+     """
+     Predict cervical cell classification.
+
+     Expects:
+         - image file in multipart/form-data
+         - Optional: top_k parameter for number of predictions
+
+     Returns:
+         JSON with prediction results
+     """
+     # Check if an image file is present
+     if 'image' not in request.files:
+         return jsonify({'error': 'No image file provided'}), 400
+
+     file = request.files['image']
+
+     # Check if a file is selected
+     if file.filename == '':
+         return jsonify({'error': 'No file selected'}), 400
+
+     # Check file extension
+     if not allowed_file(file.filename):
+         return jsonify({
+             'error': f'File type not allowed. Allowed types: {", ".join(ALLOWED_EXTENSIONS)}'
+         }), 400
+
+     try:
+         # Get top_k parameter (default: 5)
+         top_k = int(request.form.get('top_k', 5))
+
+         # Load and preprocess image
+         image = Image.open(file.stream)
+         pixel_values, original_image = preprocess_image(image)
+
+         # Make prediction
+         results = predict_image(pixel_values, top_k=top_k)
+
+         return jsonify({
+             'success': True,
+             'prediction': results
+         })
+
+     except Exception as e:
+         return jsonify({
+             'success': False,
+             'error': str(e)
+         }), 500
+
+
+ @app.route('/predict_with_explainability', methods=['POST'])
+ def predict_with_explainability():
+     """
+     Predict cervical cell classification with GRAD-CAM visualization.
+
+     Expects:
+         - image file in multipart/form-data
+         - Optional: top_k parameter for number of predictions
+         - Optional: target_class parameter for GRAD-CAM visualization
+
+     Returns:
+         JSON with prediction results and GRAD-CAM visualization
+     """
+     # Check if an image file is present
+     if 'image' not in request.files:
+         return jsonify({'error': 'No image file provided'}), 400
+
+     file = request.files['image']
+
+     # Check if a file is selected
+     if file.filename == '':
+         return jsonify({'error': 'No file selected'}), 400
+
+     # Check file extension
+     if not allowed_file(file.filename):
+         return jsonify({
+             'error': f'File type not allowed. Allowed types: {", ".join(ALLOWED_EXTENSIONS)}'
+         }), 400
+
+     try:
+         # Get parameters
+         top_k = int(request.form.get('top_k', 5))
+         target_class = request.form.get('target_class')
+         if target_class is not None:
+             target_class = int(target_class)
+
+         # Load and preprocess image
+         image = Image.open(file.stream)
+         pixel_values, original_image = preprocess_image(image)
+
+         # Make prediction
+         prediction_results = predict_image(pixel_values, top_k=top_k)
+
+         # Apply GRAD-CAM
+         gradcam_results = apply_gradcam(pixel_values, original_image, target_class)
+
+         # Encode visualizations as base64
+         visualization_base64 = encode_image_to_base64(gradcam_results['visualization'])
+         original_image_base64 = encode_image_to_base64(original_image)
+
+         return jsonify({
+             'success': True,
+             'prediction': prediction_results,
+             'explainability': {
+                 'method': 'GRAD-CAM',
+                 'target_class': DISPLAY_NAMES[gradcam_results['target_class']],
+                 'target_class_idx': gradcam_results['target_class'],
+                 'visualization': visualization_base64,
+                 'original_image': original_image_base64
+             }
+         })
+
+     except Exception as e:
+         return jsonify({
+             'success': False,
+             'error': str(e)
+         }), 500
+
+
+ @app.route('/classes', methods=['GET'])
+ def get_classes():
+     """Get the list of available classes."""
+     return jsonify({
+         'classes': [
+             {
+                 'idx': i,
+                 'name': display_name,
+                 'raw_name': raw_name
+             }
+             for i, (display_name, raw_name) in enumerate(zip(DISPLAY_NAMES, CLASS_NAMES))
+         ]
+     })
+
+
+ @app.route('/', methods=['GET'])
+ def index():
+     """Root endpoint with API information."""
+     return jsonify({
+         'name': 'Project Phoenix - Cervical Cancer Cell Classification API',
+         'version': '1.0.0',
+         'model': HF_MODEL_ID,
+         'device': str(DEVICE),
+         'endpoints': {
+             '/health': 'GET - Health check',
+             '/predict': 'POST - Predict cell classification',
+             '/predict_with_explainability': 'POST - Predict with GRAD-CAM visualization',
+             '/classes': 'GET - Get available classes'
+         },
+         'supported_formats': list(ALLOWED_EXTENSIONS),
+         'max_file_size': f'{MAX_FILE_SIZE // (1024 * 1024)}MB'
+     })
+
+
+ # ========== MAIN ==========
+
+ if __name__ == '__main__':
+     # Get port from environment variable or default to 5000
+     port = int(os.getenv('PORT', 5000))
+
+     # Run the app
+     app.run(
+         host='0.0.0.0',
+         port=port,
+         debug=os.getenv('FLASK_ENV') == 'development'
+     )
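A minimal client sketch for the /predict endpoint above, assuming the server runs locally on the default port 5000 and that sample_cell.png is a stand-in test image (both are illustrative assumptions, not part of the commit):

import requests

API_URL = "http://localhost:5000"  # assumed local deployment

# "image" is the multipart field the endpoint reads; top_k is the optional form parameter
with open("sample_cell.png", "rb") as f:  # hypothetical test image
    resp = requests.post(f"{API_URL}/predict", files={"image": f}, data={"top_k": 3})

result = resp.json()
if result.get("success"):
    pred = result["prediction"]
    print(pred["predicted_class"], pred["confidence"])
    for entry in pred["top_k_predictions"]:
        print(entry["class"], entry["probability"])
else:
    print("Error:", result.get("error"))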
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch
+ safetensors
+ scikit-learn
+ transformers
+ numpy
+ pillow
+ flask
+ flask-cors
+ torchvision
+ opencv-python-headless
+ grad-cam
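The /predict_with_explainability endpoint returns its GRAD-CAM overlay as a data:image/png;base64,... URL (see encode_image_to_base64 in app.py). A sketch of how a client might decode and save that overlay, under the same local-server and test-image assumptions as above:

import base64
import requests

API_URL = "http://localhost:5000"  # assumed local deployment

with open("sample_cell.png", "rb") as f:  # hypothetical test image
    resp = requests.post(f"{API_URL}/predict_with_explainability", files={"image": f})

payload = resp.json()
if payload.get("success"):
    # The visualization is a data URL; strip the "data:image/png;base64," prefix
    b64_png = payload["explainability"]["visualization"].split(",", 1)[1]
    with open("gradcam_overlay.png", "wb") as out:
        out.write(base64.b64decode(b64_png))
    print("Saved overlay for class:", payload["explainability"]["target_class"])
else:
    print("Error:", payload.get("error"))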