hasnatz committed
Commit b0e3fdc · verified · 1 Parent(s): cde2883

Create image_inference.py

Files changed (1)
  1. image_inference.py +189 -0
image_inference.py ADDED
@@ -0,0 +1,189 @@
import io
import requests
import onnxruntime as ort
import numpy as np
import os
from PIL import Image, ImageDraw, ImageFont
import matplotlib
import matplotlib.font_manager  # explicit submodule import; `import matplotlib` alone does not guarantee access to font_manager

# ---------------------------
# Font helper
# ---------------------------
def get_font(size=20):
    """Return a TrueType font resolved from matplotlib's default sans-serif family."""
    font_name = matplotlib.rcParams['font.sans-serif'][0]
    font_path = matplotlib.font_manager.findfont(font_name)
    return ImageFont.truetype(font_path, size)

# ---------------------------
# Colors and classes
# ---------------------------
COLOR_PALETTE = [
    (220, 20, 60),    # Crimson Red
    (0, 128, 0),      # Green
    (0, 0, 255),      # Blue
    (255, 140, 0),    # Dark Orange
    (255, 215, 0),    # Gold
    (138, 43, 226),   # Blue Violet
    (0, 206, 209),    # Dark Turquoise
    (255, 105, 180),  # Hot Pink
    (70, 130, 180),   # Steel Blue
    (46, 139, 87),    # Sea Green
    (210, 105, 30),   # Chocolate
    (123, 104, 238),  # Medium Slate Blue
    (199, 21, 133),   # Medium Violet Red
]

classes = [
    'None', 'Boots', 'C-worker', 'Cone', 'Construction-hat', 'Crane',
    'Excavator', 'Gloves', 'Goggles', 'Ladder', 'Mask', 'Truck', 'Vest'
]

CLASS_COLORS = {cls: COLOR_PALETTE[i % len(COLOR_PALETTE)] for i, cls in enumerate(classes)}

# ---------------------------
# Image loading
# ---------------------------
def open_image(path):
    """Load an image from a local path or URL."""
    if path.startswith('http://') or path.startswith('https://'):
        response = requests.get(path, timeout=30)
        response.raise_for_status()  # fail early on HTTP errors instead of passing bad bytes to PIL
        img = Image.open(io.BytesIO(response.content))
    else:
        if os.path.exists(path):
            img = Image.open(path)
        else:
            raise FileNotFoundError(f"The file {path} does not exist.")
    return img

# ---------------------------
# Utilities
# ---------------------------
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def box_cxcywh_to_xyxy_numpy(x):
    """Convert [cx, cy, w, h] box format to [x1, y1, x2, y2]."""
    x_c, y_c, w, h = np.split(x, 4, axis=-1)
    b = np.concatenate([
        x_c - 0.5 * np.clip(w, a_min=0.0, a_max=None),
        y_c - 0.5 * np.clip(h, a_min=0.0, a_max=None),
        x_c + 0.5 * np.clip(w, a_min=0.0, a_max=None),
        y_c + 0.5 * np.clip(h, a_min=0.0, a_max=None)
    ], axis=-1)
    return b

# ---------------------------
# RTDETR ONNX Inference
# ---------------------------
class RTDETR_ONNX:
    MEANS = [0.485, 0.456, 0.406]
    STDS = [0.229, 0.224, 0.225]

    def __init__(self, onnx_model_path):
        self.ort_session = ort.InferenceSession(onnx_model_path)
        input_info = self.ort_session.get_inputs()[0]
        self.input_height, self.input_width = input_info.shape[2:]

    def _preprocess_image(self, image):
        """Preprocess the input image for inference."""
        # Force 3 channels: RGBA or grayscale inputs would break the normalization below
        image = image.convert("RGB").resize((self.input_width, self.input_height))
        image = np.array(image).astype(np.float32) / 255.0
        image = ((image - self.MEANS) / self.STDS).astype(np.float32)
        image = np.transpose(image, (2, 0, 1))  # HWC → CHW
        image = np.expand_dims(image, axis=0)   # Add batch dimension
        return image

    def _post_process(self, outputs, origin_height, origin_width, confidence_threshold, max_number_boxes):
        """Post-process raw outputs into scores, labels, and boxes."""
        pred_boxes, pred_logits = outputs
        prob = sigmoid(pred_logits)

        # Flatten (queries × classes) scores and keep the top-k entries
        flat_prob = prob[0].flatten()
        topk_indexes = np.argsort(flat_prob)[-max_number_boxes:][::-1]
        scores = flat_prob[topk_indexes]
        topk_boxes = topk_indexes // pred_logits.shape[2]  # query index
        labels = topk_indexes % pred_logits.shape[2]       # class index

        # Gather the boxes belonging to the top-k scores
        boxes = box_cxcywh_to_xyxy_numpy(pred_boxes[0])
        boxes = np.take_along_axis(
            boxes,
            np.expand_dims(topk_boxes, axis=-1).repeat(4, axis=-1),
            axis=0
        )

        # Rescale normalized boxes to the original image size
        target_sizes = np.array([[origin_height, origin_width]])
        img_h, img_w = target_sizes[:, 0], target_sizes[:, 1]
        scale_fct = np.stack([img_w, img_h, img_w, img_h], axis=1)
        boxes = boxes * scale_fct[0, :]

        # Filter by confidence
        keep = scores > confidence_threshold
        scores = scores[keep]
        labels = labels[keep]
        boxes = boxes[keep]

        return scores, labels, boxes

    def annotate_detections(self, image, boxes, labels, scores=None):
        """Draw bounding boxes and class labels, return PIL.Image."""
        draw = ImageDraw.Draw(image)
        font = get_font()

        for i, box in enumerate(boxes.astype(int)):
            cls_id = labels[i]
            cls_name = classes[cls_id] if cls_id < len(classes) else str(cls_id)
            color = CLASS_COLORS.get(cls_name, (0, 255, 0))

            # Draw bounding box
            draw.rectangle(box.tolist(), outline=color, width=3)

            # Label text
            label = f"{cls_name}"
            if scores is not None:
                label += f" {scores[i]:.2f}"

            # Get text size
            tw, th = draw.textbbox((0, 0), label, font=font)[2:]
            tx, ty = box[0], max(0, box[1] - th - 4)

            # Background rectangle behind the label
            padding = 4
            draw.rectangle([tx, ty, tx + tw + 2 * padding, ty + th + 2 * padding], fill=color)

            # Put text
            draw.text((tx + 2, ty + 2), label, fill="white", font=font)

        return image

    def run_inference(self, image, confidence_threshold=0.2, max_number_boxes=100):
        """Run inference and return an annotated PIL image.

        Accepts a PIL.Image directly.
        """
        if not isinstance(image, Image.Image):
            raise TypeError("Input must be a PIL.Image")

        origin_width, origin_height = image.size

        # Preprocess
        input_image = self._preprocess_image(image)

        # Run model
        input_name = self.ort_session.get_inputs()[0].name
        outputs = self.ort_session.run(None, {input_name: input_image})

        # Post-process
        scores, labels, boxes = self._post_process(
            outputs, origin_height, origin_width,
            confidence_threshold, max_number_boxes
        )

        # Annotate a copy of the original image and return it
        return self.annotate_detections(image.copy(), boxes, labels, scores)
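
A minimal usage sketch (not part of the committed file); the model path and image URL below are hypothetical placeholders:

    # Illustrative only; "rtdetr.onnx" and the image URL are placeholders
    model = RTDETR_ONNX("rtdetr.onnx")
    image = open_image("https://example.com/construction_site.jpg")
    annotated = model.run_inference(image, confidence_threshold=0.3)
    annotated.save("annotated.jpg")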