ishanprogs commited on
Commit
73266fe
·
verified ·
1 Parent(s): b06a2df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -319
app.py CHANGED
@@ -1,5 +1,4 @@
1
- # app.py (Complete Final Version - Fixed SyntaxError in overlap loop)
2
-
3
  import gradio as gr
4
  import torch
5
  import clip
@@ -23,422 +22,244 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
23
  logger = logging.getLogger(__name__)
24
 
25
  # --- Constants ---
26
- # Damage segmentation classes (Order MUST match the training of 'model_best.pt')
27
  DAMAGE_CLASSES = ['Cracked', 'Scratch', 'Flaking', 'Broken part', 'Corrosion', 'Dent', 'Paint chip', 'Missing part']
28
  NUM_DAMAGE_CLASSES = len(DAMAGE_CLASSES)
29
-
30
- # Part segmentation classes (Order MUST match the training of 'partdetection_yolobest.pt')
31
  CAR_PART_CLASSES = [
32
  "Quarter-panel", "Front-wheel", "Back-window", "Trunk", "Front-door",
33
  "Rocker-panel", "Grille", "Windshield", "Front-window", "Back-door",
34
  "Headlight", "Back-wheel", "Back-windshield", "Hood", "Fender",
35
- "Tail-light", "License-plate", "Front-bumper", "Back-bumper", "Mirror",
36
- "Roof"
37
  ]
38
  NUM_CAR_PART_CLASSES = len(CAR_PART_CLASSES)
39
 
40
-
41
- # Paths within the Hugging Face Space repository (relative to app.py)
42
  CLIP_TEXT_FEATURES_PATH = "./clip_text_features.pt"
43
- DAMAGE_MODEL_WEIGHTS_PATH = "./best.pt" # Your YOLOv8 damage model weights
44
- PART_MODEL_WEIGHTS_PATH = "./partdetection_yolobest.pt" # Your YOLOv8 part model weights
45
-
46
- # Default Prediction Thresholds (can be overridden by sliders)
47
  DEFAULT_DAMAGE_PRED_THRESHOLD = 0.4
48
  DEFAULT_PART_PRED_THRESHOLD = 0.3
49
 
50
  # --- Device Setup ---
51
- if torch.cuda.is_available():
52
- DEVICE = "cuda"
53
- logger.info("CUDA available, using GPU.")
54
- else:
55
- DEVICE = "cpu"
56
- logger.info("CUDA not available, using CPU.")
57
 
58
- # --- MODEL LOADING (Load models globally ONCE on startup) ---
59
  print("--- Initializing Models ---")
60
- clip_model = None
61
- clip_preprocess = None
62
- clip_text_features = None
63
- damage_model = None
64
- part_model = None
65
- clip_load_error_msg = None
66
- damage_load_error_msg = None
67
- part_load_error_msg = None
68
 
69
-
70
- # --- Load CLIP Model (Model 1) ---
71
  try:
72
  logger.info("Loading CLIP model (ViT-B/16)...")
73
  clip_model, clip_preprocess = clip.load("ViT-B/16", device=DEVICE, jit=False)
74
  clip_model.eval()
75
- logger.info("CLIP model loaded.")
76
-
77
- logger.info(f"Loading CLIP text features from {CLIP_TEXT_FEATURES_PATH}...")
78
- if not os.path.exists(CLIP_TEXT_FEATURES_PATH):
79
- raise FileNotFoundError(f"CLIP text features not found: {CLIP_TEXT_FEATURES_PATH}.")
80
  clip_text_features = torch.load(CLIP_TEXT_FEATURES_PATH, map_location=DEVICE)
81
- logger.info(f"CLIP text features loaded (dtype: {clip_text_features.dtype}).")
82
-
83
- except Exception as e:
84
- clip_load_error_msg = f"CLIP load error: {e}"
85
- logger.error(clip_load_error_msg, exc_info=True)
86
- clip_model = None
87
 
88
- # --- Load Damage Segmentation Model (Model 2 - YOLOv8) ---
89
  try:
90
- logger.info(f"Loading Damage Segmentation (YOLOv8) model from {DAMAGE_MODEL_WEIGHTS_PATH}...")
91
- if not os.path.exists(DAMAGE_MODEL_WEIGHTS_PATH):
92
- raise FileNotFoundError(f"Damage model weights not found: {DAMAGE_MODEL_WEIGHTS_PATH}.")
93
  damage_model = YOLO(DAMAGE_MODEL_WEIGHTS_PATH)
94
  damage_model.to(DEVICE)
95
- # Verify/Update class names
96
- loaded_damage_names = list(damage_model.names.values())
97
- if loaded_damage_names != DAMAGE_CLASSES:
98
- logger.warning(f"Mismatch: Defined DAMAGE_CLASSES vs names in {DAMAGE_MODEL_WEIGHTS_PATH}")
99
- DAMAGE_CLASSES = loaded_damage_names # Use names from model file
100
- logger.warning(f"Updated DAMAGE_CLASSES to: {DAMAGE_CLASSES}")
101
- logger.info("Damage Segmentation (YOLOv8) model loaded.")
102
- except Exception as e:
103
- damage_load_error_msg = f"Damage YOLO load error: {e}"
104
- logger.error(damage_load_error_msg, exc_info=True)
105
- damage_model = None
106
 
107
- # --- Load Part Segmentation Model (Model 3 - YOLOv8) ---
108
  try:
109
- logger.info(f"Loading Part Segmentation (YOLOv8) model from {PART_MODEL_WEIGHTS_PATH}...")
110
- if not os.path.exists(PART_MODEL_WEIGHTS_PATH):
111
- raise FileNotFoundError(f"Part model weights not found: {PART_MODEL_WEIGHTS_PATH}.")
112
  part_model = YOLO(PART_MODEL_WEIGHTS_PATH)
113
  part_model.to(DEVICE)
114
- # Verify/Update class names
115
- loaded_part_names = list(part_model.names.values())
116
- if loaded_part_names != CAR_PART_CLASSES:
117
- logger.warning(f"Mismatch: Defined CAR_PART_CLASSES vs names in {PART_MODEL_WEIGHTS_PATH}")
118
- CAR_PART_CLASSES = loaded_part_names # Use names from model file
119
- logger.warning(f"Updated CAR_PART_CLASSES to: {CAR_PART_CLASSES}")
120
- logger.info("Part Segmentation (YOLOv8) model loaded.")
121
- except Exception as e:
122
- part_load_error_msg = f"Part YOLO load error: {e}"
123
- logger.error(part_load_error_msg, exc_info=True)
124
- part_model = None
125
-
126
- print("--- Model loading process finished. ---")
127
- if clip_load_error_msg: print(f"CLIP STATUS: {clip_load_error_msg}")
128
- else: print("CLIP STATUS: Loaded OK.")
129
- if damage_load_error_msg: print(f"DAMAGE MODEL STATUS: {damage_load_error_msg}")
130
- else: print("DAMAGE MODEL STATUS: Loaded OK.")
131
- if part_load_error_msg: print(f"PART MODEL STATUS: {part_load_error_msg}")
132
- else: print("PART MODEL STATUS: Loaded OK.")
133
-
134
 
135
  # --- Prediction Functions ---
136
-
137
  def classify_image_clip(image_pil):
138
- """Classifies image using CLIP. Returns label and probability dictionary."""
139
- if clip_model is None or clip_text_features is None:
140
- logger.error(f"CLIP model or text features not loaded. Error: {clip_load_error_msg}")
141
- return "Error: CLIP Model Not Loaded", {"Error": 1.0}
142
-
143
- logger.info("Running CLIP classification...")
144
  try:
145
  if image_pil.mode != "RGB": image_pil = image_pil.convert("RGB")
146
- t1 = time.time()
147
  image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
148
  with torch.no_grad():
149
  image_features = clip_model.encode_image(image_input)
150
  image_features /= image_features.norm(dim=-1, keepdim=True)
151
  text_features_matched = clip_text_features
152
  if image_features.dtype != clip_text_features.dtype:
153
- logger.warning(f"CLIP Dtype mismatch! Image: {image_features.dtype}, Text: {clip_text_features.dtype}. Converting text...")
154
  text_features_matched = clip_text_features.to(image_features.dtype)
155
- logit_scale = clip_model.logit_scale.exp()
156
- similarity = (image_features @ text_features_matched.T) * logit_scale
157
  probs = similarity.softmax(dim=-1).squeeze().cpu()
158
- t2 = time.time()
159
- logger.info(f"CLIP processing time: {t2-t1:.2f}s")
160
-
161
- car_prob = probs[0].item(); not_car_prob = probs[1].item()
162
- predicted_label = "Car" if car_prob > not_car_prob else "Not Car"
163
- prob_dict = {"Car": f"{car_prob:.3f}", "Not Car": f"{not_car_prob:.3f}"}
164
- return predicted_label, prob_dict
165
- except Exception as e:
166
- logger.error(f"Error during CLIP prediction function: {e}", exc_info=True)
167
- traceback.print_exc()
168
- return "Error during CLIP processing", {"Error": 1.0}
169
 
170
- # --- CORRECTED process_car_image Function ---
171
  def process_car_image(image_np_bgr, damage_threshold, part_threshold):
172
- """
173
- Runs damage and part segmentation (YOLOv8), calculates overlap, visualizes.
174
- Returns: combined_image_rgb (numpy), assignment_text (string)
175
- """
176
- # Check model availability
177
- if damage_model is None: return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), f"Error: Damage model not loaded ({damage_load_error_msg})"
178
- if part_model is None: return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), f"Error: Part model not loaded ({part_load_error_msg})"
179
 
180
  final_assignments = []
181
- annotated_image_bgr = image_np_bgr.copy() # Work on a copy
182
  img_h, img_w = image_np_bgr.shape[:2]
183
  logger.info("Starting combined YOLO processing...")
184
- yolo_start_time = time.time()
185
- im_tensor_gpu_for_annotator = None # Initialize
186
 
187
  try:
188
- # --- Create the image tensor ONCE for the annotator ---
 
189
  try:
190
- im_tensor_gpu_for_annotator = torch.from_numpy(image_np_bgr).to(DEVICE) # Keep HWC
191
- if not isinstance(im_tensor_gpu_for_annotator, torch.Tensor) or im_tensor_gpu_for_annotator.ndim != 3:
192
- raise ValueError("Failed to create valid HWC image tensor for annotator.")
193
- logger.info(f"Created image tensor for annotator (Shape: {im_tensor_gpu_for_annotator.shape}, Device: {im_tensor_gpu_for_annotator.device})")
 
 
 
 
194
  except Exception as e_tensor:
195
- logger.error(f"Could not create image tensor: {e_tensor}. Mask visualization might fail.", exc_info=True)
196
- im_tensor_gpu_for_annotator = None
197
 
198
- # --- 1. Predict Damages (YOLOv8) ---
199
- logger.info(f" Running Damage Segmentation (Threshold: {damage_threshold})...")
200
  damage_results = damage_model.predict(image_np_bgr, verbose=False, device=DEVICE, conf=damage_threshold)
201
  damage_result = damage_results[0]
202
- logger.info(f" Found {len(damage_result.boxes)} potential damages.")
203
  damage_masks_raw = damage_result.masks.data if damage_result.masks is not None else torch.empty((0,0,0), device=DEVICE)
 
 
204
  damage_classes_ids_cpu = damage_result.boxes.cls.cpu().numpy().astype(int) if damage_result.boxes is not None else np.array([])
205
  damage_boxes_xyxy_cpu = damage_result.boxes.xyxy.cpu() if damage_result.boxes is not None else torch.empty((0,4))
206
 
207
- # --- 2. Predict Parts (YOLOv8) ---
208
- logger.info(f" Running Part Segmentation (Threshold: {part_threshold})...")
209
  part_results = part_model.predict(image_np_bgr, verbose=False, device=DEVICE, conf=part_threshold)
210
  part_result = part_results[0]
211
- logger.info(f" Found {len(part_result.boxes)} potential parts.")
212
  part_masks_raw = part_result.masks.data if part_result.masks is not None else torch.empty((0,0,0), device=DEVICE)
 
 
213
  part_classes_ids_cpu = part_result.boxes.cls.cpu().numpy().astype(int) if part_result.boxes is not None else np.array([])
214
  part_boxes_xyxy_cpu = part_result.boxes.xyxy.cpu() if part_result.boxes is not None else torch.empty((0,4))
215
 
216
- yolo_end_time = time.time()
217
- logger.info(f" YOLO predictions took {yolo_end_time - yolo_start_time:.2f}s")
218
-
219
  # --- 3. Resize Masks ---
220
  def resize_masks(masks_tensor, target_h, target_w):
221
- """Resizes masks tensor to target H, W using CPU numpy and OpenCV."""
222
- masks_np_bool = masks_tensor.cpu().numpy().astype(bool)
223
- if masks_np_bool.shape[0] == 0: return np.array([])
224
- if masks_np_bool.ndim == 3 and masks_np_bool.shape[1] == target_h and masks_np_bool.shape[2] == target_w: return masks_np_bool
225
- if masks_np_bool.ndim == 2: masks_np_bool = np.expand_dims(masks_np_bool, axis=0)
226
- if masks_np_bool.ndim != 3: logger.error(f"Unexpected mask dim: {masks_np_bool.ndim}"); return np.array([])
227
- resized_masks_list = []
228
- for i in range(masks_np_bool.shape[0]):
229
- mask = masks_np_bool[i]
230
- mask_resized = cv2.resize(mask.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST)
231
- resized_masks_list.append(mask_resized.astype(bool))
232
- return np.array(resized_masks_list)
233
-
234
- resize_start_time = time.time()
235
  damage_masks_np = resize_masks(damage_masks_raw, img_h, img_w)
236
  part_masks_np = resize_masks(part_masks_raw, img_h, img_w)
237
- resize_end_time = time.time()
238
- logger.info(f" Mask resizing took {resize_end_time - resize_start_time:.2f}s")
239
-
240
 
241
  # --- 4. Calculate Overlap ---
242
- logger.info(" Calculating overlap...")
243
- overlap_start_time = time.time()
244
- if damage_masks_np.shape[0] > 0 and part_masks_np.shape[0] > 0:
245
- overlap_threshold = 0.4
246
- # *** CORRECTED OVERLAP LOOP SYNTAX ***
247
- for i in range(len(damage_masks_np)): # Iterate through each detected damage
248
- damage_mask = damage_masks_np[i]
249
- damage_class_id = damage_classes_ids_cpu[i]
250
-
251
- # Try getting damage name, skip if ID is invalid
252
- try:
253
- damage_name = DAMAGE_CLASSES[damage_class_id]
254
- except IndexError:
255
- logger.warning(f"Invalid damage ID {damage_class_id} found during overlap check. Skipping this damage.")
256
- continue # Go to the next damage mask
257
-
258
- # Check damage area (only if name was valid)
259
- damage_area = np.sum(damage_mask)
260
- if damage_area < 10: # Skip tiny masks
261
- continue
262
-
263
- # Initialize for finding best overlapping part for this damage
264
- max_overlap = 0
265
- assigned_part_name = "Unknown / Outside Parts" # Default
266
-
267
- # Inner loop for parts
268
- for j in range(len(part_masks_np)):
269
- part_mask = part_masks_np[j]
270
- part_class_id = part_classes_ids_cpu[j]
271
- try:
272
- part_name = CAR_PART_CLASSES[part_class_id]
273
- except IndexError:
274
- logger.warning(f"Invalid part ID {part_class_id} during overlap check.")
275
- continue # Skip this part
276
-
277
- intersection = np.logical_and(damage_mask, part_mask)
278
- overlap_ratio = np.sum(intersection) / damage_area if damage_area > 0 else 0
279
-
280
- if overlap_ratio > max_overlap:
281
- max_overlap = overlap_ratio
282
- # Assign based on threshold condition within the inner loop
283
- if max_overlap >= overlap_threshold:
284
- assigned_part_name = part_name
285
- else:
286
- # If max overlap is < threshold even for best part, keep default
287
- assigned_part_name = "Unknown / Outside Parts"
288
-
289
-
290
- # Append assignment result after checking all parts for the current damage
291
- assignment_desc = f"{damage_name} in {assigned_part_name}"
292
- if assigned_part_name == "Unknown / Outside Parts" and max_overlap > 0: # Add detail if there was *some* overlap
293
- assignment_desc += f" (Max Overlap < {overlap_threshold*100:.0f}%)"
294
- elif assigned_part_name == "Unknown / Outside Parts":
295
- assignment_desc += " (No overlap)" # Clarify if zero overlap
296
-
297
- final_assignments.append(assignment_desc)
298
- # *** END OF CORRECTED OVERLAP LOOP ***
299
-
300
- # Handle cases with no damages or no parts
301
  elif damage_masks_np.shape[0] > 0: final_assignments.append(f"{len(damage_masks_np)} damages found, but no parts detected/matched above threshold {part_threshold}.")
302
  elif part_masks_np.shape[0] > 0: final_assignments.append(f"No damages detected above threshold {damage_threshold}.")
303
  else: final_assignments.append(f"No damages or parts detected above thresholds.")
304
- overlap_end_time = time.time(); logger.info(f" Overlap calculation took {overlap_end_time - overlap_start_time:.2f}s"); logger.info(f" Assignment results: {final_assignments}")
305
-
306
 
307
  # --- 5. Visualization using YOLO Annotator ---
308
- logger.info(" Visualizing results...")
309
- vis_start_time = time.time()
310
- annotator = Annotator(annotated_image_bgr, line_width=2, example=part_model.names)
311
 
312
- # Draw PART masks first (Greenish)
313
- if part_result.masks is not None and im_tensor_gpu_for_annotator is not None:
314
  try:
 
315
  colors_part = [(0, random.randint(100, 200), 0) for _ in part_classes_ids_cpu]
316
- annotator.masks(part_masks_raw, colors=colors_part, im_gpu=im_tensor_gpu_for_annotator, alpha=0.3)
 
 
 
317
  for box, cls_id in zip(part_boxes_xyxy_cpu, part_classes_ids_cpu):
318
- try: label = f"{CAR_PART_CLASSES[cls_id]}"; annotator.box_label(box, label=label, color=(0, 200, 0))
319
- except IndexError: logger.warning(f"Invalid part ID {cls_id} during drawing"); continue
320
- except Exception as e_part_vis: logger.error(f"Error drawing part masks/boxes: {e_part_vis}", exc_info=True)
321
- elif part_result.masks is not None:
322
- logger.warning("Could not draw part masks because image tensor creation failed.")
323
 
324
- # Draw DAMAGE masks second (Reddish)
325
- if damage_result.masks is not None and im_tensor_gpu_for_annotator is not None:
326
  try:
 
327
  colors_dmg = [(random.randint(100, 200), 0, 0) for _ in damage_classes_ids_cpu]
328
- annotator.masks(damage_masks_raw, colors=colors_dmg, im_gpu=im_tensor_gpu_for_annotator, alpha=0.4)
 
 
 
 
 
 
329
  for box, cls_id in zip(damage_boxes_xyxy_cpu, damage_classes_ids_cpu):
330
  try: label = f"{DAMAGE_CLASSES[cls_id]}"; annotator.box_label(box, label=label, color=(200, 0, 0))
331
- except IndexError: logger.warning(f"Invalid damage ID {cls_id} during drawing"); continue
332
- except Exception as e_dmg_vis: logger.error(f"Error drawing damage masks/boxes: {e_dmg_vis}", exc_info=True)
333
- elif damage_result.masks is not None:
334
- logger.warning("Could not draw damage masks because image tensor creation failed.")
335
 
336
- annotated_image_bgr = annotator.result() # Get final BGR image
337
- vis_end_time = time.time()
338
- logger.info(f" Visualization took {vis_end_time - vis_start_time:.2f}s")
339
 
340
  except Exception as e:
341
- logger.error(f"Error during combined processing: {e}", exc_info=True)
342
- traceback.print_exc()
343
  final_assignments.append("Error during segmentation/processing.")
344
 
345
- # --- Prepare output ---
346
  assignment_text = "\n".join(final_assignments) if final_assignments else "No damage assignments generated."
347
- # Convert final annotated image to RGB for Gradio display
348
  final_output_image_rgb = cv2.cvtColor(annotated_image_bgr, cv2.COLOR_BGR2RGB)
349
-
350
  return final_output_image_rgb, assignment_text
351
 
352
-
353
  # --- Main Gradio Function ---
354
  def predict_pipeline(image_np_input, damage_thresh, part_thresh):
355
- """
356
- Main pipeline: Classify -> Segment -> Assign -> Visualize
357
- """
358
- if image_np_input is None:
359
- return "Please upload an image.", {}, None, "N/A"
360
-
361
- logger.info(f"--- New Request (Damage Thr: {damage_thresh:.2f}, Part Thr: {part_thresh:.2f}) ---")
362
- start_time = time.time()
363
- # Gradio provides RGB numpy, convert to BGR for OpenCV/YOLO internal, PIL for CLIP
364
- image_np_bgr = cv2.cvtColor(image_np_input, cv2.COLOR_RGB2BGR)
365
- image_pil = Image.fromarray(image_np_input) # Input numpy is RGB
366
-
367
- final_output_image = None
368
- assignment_text = "Processing..."
369
- classification_result = "Error"
370
- probabilities = {}
371
-
372
- # --- Stage 1: CLIP Classification ---
373
- try:
374
- classification_result, probabilities = classify_image_clip(image_pil)
375
- except Exception as e:
376
- logger.error(f"Error in CLIP stage: {e}", exc_info=True); assignment_text = f"CLIP Error: {e}";
377
- # Prepare original image for display in case of error
378
- final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
379
-
380
- # --- Stage 2 & 3: Segmentation and Assignment (if 'Car') ---
381
  if classification_result == "Car":
382
- logger.info("Image classified as Car. Running segmentation and assignment...")
383
- try:
384
- final_output_image, assignment_text = process_car_image(image_np_bgr, damage_thresh, part_thresh)
385
- except Exception as e:
386
- logger.error(f"Error in segmentation/assignment stage: {e}", exc_info=True); assignment_text = f"Segmentation Error: {e}";
387
- final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
388
-
389
- elif classification_result == "Not Car":
390
- logger.info("Image classified as Not Car."); final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB); assignment_text = "Image classified as Not Car."
391
- # Handle CLIP error case (already logged)
392
- elif final_output_image is None: # Ensure image is set if CLIP error occurred before seg stage
393
- logger.error("CLIP classification failed."); final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB); assignment_text = "Error during classification."
394
-
395
-
396
- # --- Cleanup ---
397
- gc.collect()
398
- if torch.cuda.is_available():
399
- torch.cuda.empty_cache()
400
-
401
- end_time = time.time()
402
- logger.info(f"Total processing time: {end_time - start_time:.2f} seconds.")
403
- # Return all results
404
  return classification_result, probabilities, final_output_image, assignment_text
405
 
406
-
407
  # --- Gradio Interface ---
408
  logger.info("Setting up Gradio interface...")
409
-
410
- title = "🚗 Car Damage Detection" # Updated Title
411
- description = """
412
- 1. **Upload** an image of a vehicle.
413
- 2. **Classification:** Determines if the image contains a car (using CLIP).
414
- 3. **Segmentation:** If it's a car, detects car parts and damages (using YOLOv8 for both). Adjust confidence thresholds using sliders.
415
- 4. **Assignment:** Assigns detected damages to the corresponding car part based on mask overlap (threshold >40%).
416
- 5. **Output:** Shows the image with overlaid masks (Green=Part, Red=Damage) and lists the damage assignments.
417
- """
418
- examples = [] # Add example image paths if uploaded (e.g., ["./example1.jpg"])
419
-
420
- # Define Inputs and Outputs including sliders
421
- input_image = gr.Image(type="numpy", label="Upload Car Image")
422
- damage_threshold_slider = gr.Slider(minimum=0.05, maximum=0.95, step=0.05, value=DEFAULT_DAMAGE_PRED_THRESHOLD, label="Damage Confidence Threshold")
423
- part_threshold_slider = gr.Slider(minimum=0.05, maximum=0.95, step=0.05, value=DEFAULT_PART_PRED_THRESHOLD, label="Part Confidence Threshold")
424
-
425
- output_classification = gr.Textbox(label="1. Classification Result")
426
- output_probabilities = gr.Label(label="Classification Probabilities")
427
- output_image_display = gr.Image(type="numpy", label="3. Segmentation Visualization")
428
- output_assignment = gr.Textbox(label="2. Damage Assignments", lines=5, interactive=False)
429
-
430
-
431
- # Launch the interface with updated inputs/title
432
- iface = gr.Interface(
433
- fn=predict_pipeline,
434
- inputs=[input_image, damage_threshold_slider, part_threshold_slider], # Sliders added
435
- outputs=[output_classification, output_probabilities, output_image_display, output_assignment],
436
- title=title, # Updated title
437
- description=description,
438
- examples=examples,
439
- allow_flagging="never"
440
- )
441
-
442
- if __name__ == "__main__":
443
- logger.info("Launching Gradio app...")
444
- iface.launch() # Set share=True for public link if needed
 
1
+ # app.py
 
2
  import gradio as gr
3
  import torch
4
  import clip
 
22
  logger = logging.getLogger(__name__)
23
 
24
  # --- Constants ---
 
25
  DAMAGE_CLASSES = ['Cracked', 'Scratch', 'Flaking', 'Broken part', 'Corrosion', 'Dent', 'Paint chip', 'Missing part']
26
  NUM_DAMAGE_CLASSES = len(DAMAGE_CLASSES)
 
 
27
  CAR_PART_CLASSES = [
28
  "Quarter-panel", "Front-wheel", "Back-window", "Trunk", "Front-door",
29
  "Rocker-panel", "Grille", "Windshield", "Front-window", "Back-door",
30
  "Headlight", "Back-wheel", "Back-windshield", "Hood", "Fender",
31
+ "Tail-light", "License-plate", "Front-bumper", "Back-bumper", "Mirror", "Roof"
 
32
  ]
33
  NUM_CAR_PART_CLASSES = len(CAR_PART_CLASSES)
34
 
 
 
35
  CLIP_TEXT_FEATURES_PATH = "./clip_text_features.pt"
36
+ DAMAGE_MODEL_WEIGHTS_PATH = "./best.pt"
37
+ PART_MODEL_WEIGHTS_PATH = "./partdetection_yolobest.pt"
 
 
38
  DEFAULT_DAMAGE_PRED_THRESHOLD = 0.4
39
  DEFAULT_PART_PRED_THRESHOLD = 0.3
40
 
41
  # --- Device Setup ---
42
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
43
+ logger.info(f"Using device: {DEVICE}")
 
 
 
 
44
 
45
+ # --- MODEL LOADING ---
46
  print("--- Initializing Models ---")
47
+ clip_model, clip_preprocess, clip_text_features = None, None, None
48
+ damage_model, part_model = None, None
49
+ clip_load_error_msg, damage_load_error_msg, part_load_error_msg = None, None, None
 
 
 
 
 
50
 
 
 
51
  try:
52
  logger.info("Loading CLIP model (ViT-B/16)...")
53
  clip_model, clip_preprocess = clip.load("ViT-B/16", device=DEVICE, jit=False)
54
  clip_model.eval()
55
+ if not os.path.exists(CLIP_TEXT_FEATURES_PATH): raise FileNotFoundError(f"CLIP text features not found: {CLIP_TEXT_FEATURES_PATH}.")
 
 
 
 
56
  clip_text_features = torch.load(CLIP_TEXT_FEATURES_PATH, map_location=DEVICE)
57
+ logger.info(f"CLIP loaded (Text Features dtype: {clip_text_features.dtype}).")
58
+ except Exception as e: clip_load_error_msg = f"CLIP load error: {e}"; logger.error(clip_load_error_msg, exc_info=True)
 
 
 
 
59
 
 
60
  try:
61
+ logger.info(f"Loading Damage YOLOv8 model from {DAMAGE_MODEL_WEIGHTS_PATH}...")
62
+ if not os.path.exists(DAMAGE_MODEL_WEIGHTS_PATH): raise FileNotFoundError(f"Damage model weights not found: {DAMAGE_MODEL_WEIGHTS_PATH}.")
 
63
  damage_model = YOLO(DAMAGE_MODEL_WEIGHTS_PATH)
64
  damage_model.to(DEVICE)
65
+ logger.info(f"Damage model task: {damage_model.task}")
66
+ if damage_model.task != 'segment':
67
+ damage_load_error_msg = f"CRITICAL ERROR: Damage model task is {damage_model.task}, not 'segment'. This model won't produce masks!"
68
+ logger.error(damage_load_error_msg)
69
+ damage_model = None # Invalidate model
70
+ else:
71
+ loaded_damage_names = list(damage_model.names.values())
72
+ if loaded_damage_names != DAMAGE_CLASSES:
73
+ logger.warning(f"Mismatch: Defined DAMAGE_CLASSES vs names in {DAMAGE_MODEL_WEIGHTS_PATH}"); DAMAGE_CLASSES = loaded_damage_names; logger.warning(f"Updated DAMAGE_CLASSES to: {DAMAGE_CLASSES}")
74
+ logger.info("Damage YOLOv8 model loaded.")
75
+ except Exception as e: damage_load_error_msg = f"Damage YOLO load error: {e}"; logger.error(damage_load_error_msg, exc_info=True); damage_model = None
76
 
 
77
  try:
78
+ logger.info(f"Loading Part YOLOv8 model from {PART_MODEL_WEIGHTS_PATH}...")
79
+ if not os.path.exists(PART_MODEL_WEIGHTS_PATH): raise FileNotFoundError(f"Part model weights not found: {PART_MODEL_WEIGHTS_PATH}.")
 
80
  part_model = YOLO(PART_MODEL_WEIGHTS_PATH)
81
  part_model.to(DEVICE)
82
+ logger.info(f"Part model task: {part_model.task}")
83
+ if part_model.task != 'segment':
84
+ part_load_error_msg = f"CRITICAL ERROR: Part model task is {part_model.task}, not 'segment'. This model won't produce masks!"
85
+ logger.error(part_load_error_msg)
86
+ part_model = None # Invalidate model
87
+ else:
88
+ loaded_part_names = list(part_model.names.values())
89
+ if loaded_part_names != CAR_PART_CLASSES:
90
+ logger.warning(f"Mismatch: Defined CAR_PART_CLASSES vs names in {PART_MODEL_WEIGHTS_PATH}"); CAR_PART_CLASSES = loaded_part_names; logger.warning(f"Updated CAR_PART_CLASSES to: {CAR_PART_CLASSES}")
91
+ logger.info("Part YOLOv8 model loaded.")
92
+ except Exception as e: part_load_error_msg = f"Part YOLO load error: {e}"; logger.error(part_load_error_msg, exc_info=True); part_model = None
93
+
94
+ print("--- Model loading process finished. ---");
95
+ if clip_load_error_msg: print(f"CLIP STATUS: {clip_load_error_msg}"); else: print("CLIP STATUS: Loaded OK.")
96
+ if damage_load_error_msg: print(f"DAMAGE MODEL STATUS: {damage_load_error_msg}"); else: print("DAMAGE MODEL STATUS: Loaded OK.")
97
+ if part_load_error_msg: print(f"PART MODEL STATUS: {part_load_error_msg}"); else: print("PART MODEL STATUS: Loaded OK.")
 
 
 
 
98
 
99
  # --- Prediction Functions ---
 
100
  def classify_image_clip(image_pil):
101
+ if clip_model is None: return "Error: CLIP Model Not Loaded", {"Error": 1.0}
 
 
 
 
 
102
  try:
103
  if image_pil.mode != "RGB": image_pil = image_pil.convert("RGB")
 
104
  image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
105
  with torch.no_grad():
106
  image_features = clip_model.encode_image(image_input)
107
  image_features /= image_features.norm(dim=-1, keepdim=True)
108
  text_features_matched = clip_text_features
109
  if image_features.dtype != clip_text_features.dtype:
 
110
  text_features_matched = clip_text_features.to(image_features.dtype)
111
+ similarity = (image_features @ text_features_matched.T) * clip_model.logit_scale.exp()
 
112
  probs = similarity.softmax(dim=-1).squeeze().cpu()
113
+ return ("Car" if probs[0] > probs[1] else "Not Car"), {"Car": f"{probs[0]:.3f}", "Not Car": f"{probs[1]:.3f}"}
114
+ except Exception as e: logger.error(f"CLIP Error: {e}", exc_info=True); return "Error: CLIP", {"Error": 1.0}
 
 
 
 
 
 
 
 
 
115
 
 
116
  def process_car_image(image_np_bgr, damage_threshold, part_threshold):
117
+ if damage_model is None: return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), f"Error: Damage model failed to load ({damage_load_error_msg})"
118
+ if part_model is None: return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), f"Error: Part model failed to load ({part_load_error_msg})"
119
+ if damage_model.task != 'segment': return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), "Error: Damage model is not a segmentation model."
120
+ if part_model.task != 'segment': return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), "Error: Part model is not a segmentation model."
 
 
 
121
 
122
  final_assignments = []
123
+ annotated_image_bgr = image_np_bgr.copy()
124
  img_h, img_w = image_np_bgr.shape[:2]
125
  logger.info("Starting combined YOLO processing...")
126
+ im_tensor_gpu_for_annotator = None
 
127
 
128
  try:
129
+ # --- Prepare Image Tensor for Annotator ---
130
+ logger.info("Preparing image tensor for annotator...")
131
  try:
132
+ if image_np_bgr.dtype != np.uint8:
133
+ logger.warning(f"Converting input image from {image_np_bgr.dtype} to uint8 for tensor creation.")
134
+ image_np_uint8 = image_np_bgr.astype(np.uint8)
135
+ else:
136
+ image_np_uint8 = image_np_bgr
137
+ # Create tensor in HWC format on the correct device
138
+ im_tensor_gpu_for_annotator = torch.from_numpy(image_np_uint8).to(DEVICE)
139
+ logger.info(f"Image tensor for annotator: shape={im_tensor_gpu_for_annotator.shape}, dtype={im_tensor_gpu_for_annotator.dtype}, device={im_tensor_gpu_for_annotator.device}")
140
  except Exception as e_tensor:
141
+ logger.error(f"Could not create image tensor: {e_tensor}. Mask visualization will fail.", exc_info=True)
142
+ im_tensor_gpu_for_annotator = None # Set to None if conversion fails
143
 
144
+ # --- 1. Predict Damages ---
145
+ logger.info(f"Running Damage Segmentation (Threshold: {damage_threshold})...")
146
  damage_results = damage_model.predict(image_np_bgr, verbose=False, device=DEVICE, conf=damage_threshold)
147
  damage_result = damage_results[0]
148
+ logger.info(f"Found {len(damage_result.boxes)} potential damages.")
149
  damage_masks_raw = damage_result.masks.data if damage_result.masks is not None else torch.empty((0,0,0), device=DEVICE)
150
+ if damage_result.masks is None: logger.warning("No damage masks in result! Check if damage model is segmentation type.")
151
+ else: logger.info(f"Damage masks available: shape={damage_masks_raw.shape if damage_masks_raw.numel() > 0 else 'Empty'}")
152
  damage_classes_ids_cpu = damage_result.boxes.cls.cpu().numpy().astype(int) if damage_result.boxes is not None else np.array([])
153
  damage_boxes_xyxy_cpu = damage_result.boxes.xyxy.cpu() if damage_result.boxes is not None else torch.empty((0,4))
154
 
155
+ # --- 2. Predict Parts ---
156
+ logger.info(f"Running Part Segmentation (Threshold: {part_threshold})...")
157
  part_results = part_model.predict(image_np_bgr, verbose=False, device=DEVICE, conf=part_threshold)
158
  part_result = part_results[0]
159
+ logger.info(f"Found {len(part_result.boxes)} potential parts.")
160
  part_masks_raw = part_result.masks.data if part_result.masks is not None else torch.empty((0,0,0), device=DEVICE)
161
+ if part_result.masks is None: logger.warning("No part masks in result! Check if part model is segmentation type.")
162
+ else: logger.info(f"Part masks available: shape={part_masks_raw.shape if part_masks_raw.numel() > 0 else 'Empty'}")
163
  part_classes_ids_cpu = part_result.boxes.cls.cpu().numpy().astype(int) if part_result.boxes is not None else np.array([])
164
  part_boxes_xyxy_cpu = part_result.boxes.xyxy.cpu() if part_result.boxes is not None else torch.empty((0,4))
165
 
 
 
 
166
  # --- 3. Resize Masks ---
167
  def resize_masks(masks_tensor, target_h, target_w):
168
+ # ... (resize logic remains the same - uses CPU numpy) ...
169
+ masks_np_bool = masks_tensor.cpu().numpy().astype(bool); if masks_np_bool.shape[0] == 0 or (masks_np_bool.shape[1] == target_h and masks_np_bool.shape[2] == target_w): return masks_np_bool; resized_masks_list = []; for i in range(masks_np_bool.shape[0]): mask = masks_np_bool[i]; mask_resized = cv2.resize(mask.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST); resized_masks_list.append(mask_resized.astype(bool)); return np.array(resized_masks_list)
 
 
 
 
 
 
 
 
 
 
 
 
170
  damage_masks_np = resize_masks(damage_masks_raw, img_h, img_w)
171
  part_masks_np = resize_masks(part_masks_raw, img_h, img_w)
 
 
 
172
 
173
  # --- 4. Calculate Overlap ---
174
+ logger.info("Calculating overlap...")
175
+ # ... (Overlap calculation logic remains the same - uses CPU numpy) ...
176
+ if damage_masks_np.shape[0] > 0 and part_masks_np.shape[0] > 0: overlap_threshold = 0.4;
177
+ for i in range(len(damage_masks_np)): damage_mask = damage_masks_np[i]; damage_class_id = damage_classes_ids_cpu[i]; try: damage_name = DAMAGE_CLASSES[damage_class_id]; except IndexError: logger.warning(f"Invalid damage ID {damage_class_id}"); continue; damage_area = np.sum(damage_mask); if damage_area < 10: continue; max_overlap = 0; assigned_part_name = "Unknown / Outside Parts";
178
+ for j in range(len(part_masks_np)): part_mask = part_masks_np[j]; part_class_id = part_classes_ids_cpu[j]; try: part_name = CAR_PART_CLASSES[part_class_id]; except IndexError: logger.warning(f"Invalid part ID {part_class_id}"); continue; intersection = np.logical_and(damage_mask, part_mask); overlap_ratio = np.sum(intersection) / damage_area if damage_area > 0 else 0; if overlap_ratio > max_overlap: max_overlap = overlap_ratio; if max_overlap >= overlap_threshold: assigned_part_name = part_name;
179
+ assignment_desc = f"{damage_name} in {assigned_part_name}"; if assigned_part_name == "Unknown / Outside Parts": assignment_desc += f" (Overlap < {overlap_threshold*100:.0f}%)"; final_assignments.append(assignment_desc)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  elif damage_masks_np.shape[0] > 0: final_assignments.append(f"{len(damage_masks_np)} damages found, but no parts detected/matched above threshold {part_threshold}.")
181
  elif part_masks_np.shape[0] > 0: final_assignments.append(f"No damages detected above threshold {damage_threshold}.")
182
  else: final_assignments.append(f"No damages or parts detected above thresholds.")
183
+ logger.info(f" Assignment results: {final_assignments}")
 
184
 
185
  # --- 5. Visualization using YOLO Annotator ---
186
+ logger.info("Visualizing results...")
187
+ annotator = Annotator(annotated_image_bgr, line_width=2, example=CAR_PART_CLASSES)
 
188
 
189
+ # Draw PART masks
190
+ if part_masks_raw.numel() > 0 and im_tensor_gpu_for_annotator is not None:
191
  try:
192
+ logger.info("Attempting to draw part masks...")
193
  colors_part = [(0, random.randint(100, 200), 0) for _ in part_classes_ids_cpu]
194
+ mask_data_part = part_masks_raw
195
+ if mask_data_part.device != im_tensor_gpu_for_annotator.device: mask_data_part = mask_data_part.to(im_tensor_gpu_for_annotator.device)
196
+ annotator.masks(mask_data_part, colors=colors_part, im_gpu=im_tensor_gpu_for_annotator, alpha=0.3)
197
+ logger.info("Successfully drew part masks.")
198
  for box, cls_id in zip(part_boxes_xyxy_cpu, part_classes_ids_cpu):
199
+ try: label = f"{CAR_PART_CLASSES[cls_id]}"; annotator.box_label(box, label=label, color=(0, 200, 0))
200
+ except IndexError: logger.warning(f"Invalid part ID {cls_id} during drawing")
201
+ except Exception as e_part_vis: logger.error(f"Error drawing part masks/boxes: {e_part_vis}", exc_info=True); traceback.print_exc()
202
+ elif part_masks_raw.numel() > 0: logger.warning("Part masks exist but image tensor for annotator is None. Skipping part mask drawing.")
203
+
204
 
205
+ # Draw DAMAGE masks
206
+ if damage_masks_raw.numel() > 0 and im_tensor_gpu_for_annotator is not None:
207
  try:
208
+ logger.info("Attempting to draw damage masks...")
209
  colors_dmg = [(random.randint(100, 200), 0, 0) for _ in damage_classes_ids_cpu]
210
+ mask_data_dmg = damage_masks_raw
211
+ logger.info(f"Damage mask data: shape={mask_data_dmg.shape}, dtype={mask_data_dmg.dtype}, device={mask_data_dmg.device}")
212
+ if mask_data_dmg.device != im_tensor_gpu_for_annotator.device:
213
+ logger.warning(f"Moving damage masks to match image tensor device ({im_tensor_gpu_for_annotator.device})")
214
+ mask_data_dmg = mask_data_dmg.to(im_tensor_gpu_for_annotator.device)
215
+ annotator.masks(mask_data_dmg, colors=colors_dmg, im_gpu=im_tensor_gpu_for_annotator, alpha=0.4)
216
+ logger.info("Successfully drew damage masks.")
217
  for box, cls_id in zip(damage_boxes_xyxy_cpu, damage_classes_ids_cpu):
218
  try: label = f"{DAMAGE_CLASSES[cls_id]}"; annotator.box_label(box, label=label, color=(200, 0, 0))
219
+ except IndexError: logger.warning(f"Invalid damage ID {cls_id} during drawing")
220
+ except Exception as e_dmg_vis: logger.error(f"Error drawing damage masks/boxes: {e_dmg_vis}", exc_info=True); traceback.print_exc()
221
+ elif damage_masks_raw.numel() > 0: logger.warning("Damage masks exist but image tensor for annotator is None. Skipping damage mask drawing.")
 
222
 
223
+
224
+ annotated_image_bgr = annotator.result()
 
225
 
226
  except Exception as e:
227
+ logger.error(f"Error during combined processing: {e}", exc_info=True); traceback.print_exc()
 
228
  final_assignments.append("Error during segmentation/processing.")
229
 
 
230
  assignment_text = "\n".join(final_assignments) if final_assignments else "No damage assignments generated."
 
231
  final_output_image_rgb = cv2.cvtColor(annotated_image_bgr, cv2.COLOR_BGR2RGB)
 
232
  return final_output_image_rgb, assignment_text
233
 
 
234
  # --- Main Gradio Function ---
235
  def predict_pipeline(image_np_input, damage_thresh, part_thresh):
236
+ if image_np_input is None: return "Please upload an image.", {}, None, "N/A";
237
+ logger.info(f"--- New Request (Damage Thr: {damage_thresh:.2f}, Part Thr: {part_thresh:.2f}) ---"); start_time = time.time();
238
+ image_np_bgr = cv2.cvtColor(image_np_input, cv2.COLOR_RGB2BGR); image_pil = Image.fromarray(image_np_input);
239
+ final_output_image, assignment_text, classification_result, probabilities = None, "Processing...", "Error", {};
240
+ try: classification_result, probabilities = classify_image_clip(image_pil)
241
+ except Exception as e: logger.error(f"CLIP Error: {e}", exc_info=True); assignment_text = f"CLIP Error: {e}"; final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  if classification_result == "Car":
243
+ try: final_output_image, assignment_text = process_car_image(image_np_bgr, damage_thresh, part_thresh)
244
+ except Exception as e: logger.error(f"Seg/Assign Error: {e}", exc_info=True); assignment_text = f"Seg Error: {e}"; final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB);
245
+ elif classification_result == "Not Car": final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB); assignment_text = "Image classified as Not Car.";
246
+ elif final_output_image is None: final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB); assignment_text = "Error during classification.";
247
+ gc.collect();
248
+ if torch.cuda.is_available(): torch.cuda.empty_cache();
249
+ logger.info(f"Total processing time: {time.time() - start_time:.2f}s.");
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  return classification_result, probabilities, final_output_image, assignment_text
251
 
 
252
  # --- Gradio Interface ---
253
  logger.info("Setting up Gradio interface...")
254
+ title = "🚗 Car Damage Detection"
255
+ description = "1. Upload... 2. Classify... 3. Segment... 4. Assign... 5. Output..." # Shortened
256
+ input_image = gr.Image(type="numpy", label="Upload Car Image");
257
+ damage_threshold_slider = gr.Slider(minimum=0.05, maximum=0.95, step=0.05, value=DEFAULT_DAMAGE_PRED_THRESHOLD, label="Damage Confidence Threshold");
258
+ part_threshold_slider = gr.Slider(minimum=0.05, maximum=0.95, step=0.05, value=DEFAULT_PART_PRED_THRESHOLD, label="Part Confidence Threshold");
259
+ output_classification = gr.Textbox(label="1. Classification Result");
260
+ output_probabilities = gr.Label(label="Classification Probabilities");
261
+ output_image_display = gr.Image(type="numpy", label="3. Segmentation Visualization");
262
+ output_assignment = gr.Textbox(label="2. Damage Assignments", lines=5, interactive=False);
263
+ iface = gr.Interface(fn=predict_pipeline, inputs=[input_image, damage_threshold_slider, part_threshold_slider], outputs=[output_classification, output_probabilities, output_image_display, output_assignment], title=title, description=description, allow_flagging="never" );
264
+
265
+ if __name__ == "__main__": logger.info("Launching Gradio app..."); iface.launch()