"""
Module with auxiliary functions for the FastAPI defect detection demo app.
Provided functions:
- preprocess_image: Preprocess input image for model inference.
- inference: Perform model inference and return class ID, class name, and segmentation mask.
- encode_mask_to_base64: Encode segmentation mask to base64 string.
- save_image: Save uploaded image bytes to a file.
- draw_prediction: Save image with overlayed prediction mask and class name.
Design notes:
- This application is a demonstration / portfolio app. For simplicity and safety during demo runs, inference is performed synchronously using a single global TFLite Interpreter instance protected by a threading.Lock to ensure thread-safety.
- The code intentionally makes a number of fixed assumptions about the model and runtime. If the model or deployment requirements change, the corresponding preprocessing, postprocessing and runtime setup should be updated and tested.
Assumptions:
- File system and assets:
    - Font used for drawing labels: ./fonts/OpenSans-Bold.ttf
    - Static files served from: ./static
    - Directories (./static/uploads, ./static/results, ./static/samples) are expected to be
      present/created by the deployment (Dockerfile or startup); an os.makedirs(..., exist_ok=True)
      call is kept in this module as a safeguard.
- Upload / input constraints:
    - Uploaded images are expected to be valid PNG images (matching the local MVTec AD dataset
      used during development).
    - Maximum accepted upload size: 5 MB.
- Runtime / model:
    - Uses the tflite-runtime Interpreter for model inference (Interpreter from tflite_runtime.interpreter).
    - TFLite model file path: ./final_model.tflite
    - A single Interpreter instance is created at startup and reused for all requests (protected by a threading.Lock).
- Model I/O (the exact assumptions used by the code):
    - Expected input tensor: shape (1, 512, 512, 3), dtype float32, pixel values in [0, 255]
      (the model normalizes to [0, 1] internally).
    - Expected output[0]: segmentation mask of shape (1, 512, 512, 1), dtype float32, values in [0, 1] (probability map).
    - Expected output[1]: class probabilities of shape (1, 6), dtype float32 (softmax-like probabilities).
"""
# IMPORTS
# Standard library imports
import base64
import io
import logging
import os
import threading
import time
import uuid
# Third-party imports
import numpy as np
from PIL import Image, ImageDraw, ImageFont
# CONFIGURATION AND CONSTANTS
# Font path for drawing text on images
FONT_PATH = './fonts/OpenSans-Bold.ttf'
# Input image size for the model
INPUT_IMAGE_SIZE = (512, 512)
# Transparency level and color for mask overlay
MAX_ALPHA = 100 # [0-255]
MASK_COLOR = (0, 255, 255, 0)  # Cyan as an RGBA tuple; the alpha here is a placeholder, replaced per-pixel via putalpha() in draw_prediction
# Dictionary mapping class IDs to class names
CLASS_MAP = {
0: 'good',
1: 'crack',
2: 'faulty_imprint',
3: 'poke',
4: 'scratch',
5: 'squeeze'
}
# AUXILIARY FUNCTIONS FOR main.py
# Function to preprocess the image
def preprocess_image(image_bytes) -> np.ndarray:
"""
Preprocess the input image for model inference.
Args:
image_bytes: Raw bytes of the input image.
Returns:
Preprocessed image as a numpy array of shape (1, INPUT_IMAGE_SIZE[0], INPUT_IMAGE_SIZE[1], 3) and dtype float32.
"""
image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
image = image.resize(INPUT_IMAGE_SIZE)
img_array = np.array(image, dtype=np.float32)
img_array = np.expand_dims(img_array, axis=0)
return img_array
# Function to perform inference on a preprocessed image
def inference(img, inference_ctx) -> tuple[int, str, np.ndarray]:
"""
Perform model inference on the preprocessed image.
Args:
img: Preprocessed image as a numpy array.
        inference_ctx: Dictionary containing the shared TFLite interpreter, its input/output
            details, and the threading lock that guards it (an illustrative sketch of how this
            dictionary is built follows the function).
Returns:
Tuple containing:
- class_id: Predicted class ID (int).
- class_name: Predicted class name (str).
- mask: Predicted segmentation mask as a numpy array.
"""
    # Serialize access to the shared interpreter (the TFLite interpreter is not thread-safe)
with inference_ctx['interpreter_lock']:
# Set the input tensor and invoke the interpreter
inference_ctx['interpreter'].set_tensor(inference_ctx['input_details'][0]['index'], img)
inference_ctx['interpreter'].invoke()
# Get the prediction results
pred_mask = inference_ctx['interpreter'].get_tensor(inference_ctx['output_details'][0]['index'])
pred_label_probs = inference_ctx['interpreter'].get_tensor(inference_ctx['output_details'][1]['index'])
# Format the prediction results and get the class name
pred_label = np.argmax(pred_label_probs, axis=1)
class_id = int(pred_label[0])
class_name = CLASS_MAP.get(class_id, 'unknown')
mask = pred_mask.squeeze()
return class_id, class_name, mask
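# Illustrative sketch (not part of the app's actual API): the inference_ctx dictionary consumed
# by inference() is assumed to be assembled once at startup in main.py, roughly as below. The
# helper name build_inference_ctx and its model_path default are assumptions for illustration;
# only the dictionary keys ('interpreter', 'interpreter_lock', 'input_details', 'output_details')
# are what inference() relies on.
def build_inference_ctx(model_path: str = './final_model.tflite') -> dict:
    """Illustrative sketch: create the shared TFLite interpreter and its inference context."""
    from tflite_runtime.interpreter import Interpreter  # runtime assumption noted in the module docstring
    interpreter = Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    return {
        'interpreter': interpreter,
        'interpreter_lock': threading.Lock(),  # serializes access across requests
        'input_details': interpreter.get_input_details(),
        'output_details': interpreter.get_output_details(),
    }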
# Function to encode mask to base64
def encode_mask_to_base64(mask_array) -> str:
"""
Encode the segmentation mask to a base64 string.
Args:
mask_array: Segmentation mask as a numpy array.
Returns:
Base64-encoded string of the mask image.
"""
mask = (mask_array * 255).astype(np.uint8)
mask_img = Image.fromarray(mask, mode='L')
buffer = io.BytesIO()
mask_img.save(buffer, format='PNG')
mask64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
return mask64
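# Usage note (an assumption about the frontend, not confirmed by this module): the returned
# string is suitable for embedding directly in a data URI, e.g.
#     mask_src = f"data:image/png;base64,{encode_mask_to_base64(mask)}"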
# Function to save an image for later use
def save_image(image_bytes) -> tuple[str, str]:
"""
Save the uploaded image bytes to a file.
Args:
image_bytes: Raw bytes of the input image.
Returns:
Tuple containing:
- filename: Name of the saved file (str).
- path: Path to the saved file (str).
"""
filename = f'{uuid.uuid4().hex}.png'
os.makedirs('static/uploads', exist_ok=True)
path = f'static/uploads/{filename}'
with open(path, 'wb') as f:
f.write(image_bytes)
return filename, path
# Function to save image with overlayed prediction mask and class name
def draw_prediction(image_path, mask_array, class_name) -> tuple[str, str]:
"""
Save image with overlayed prediction mask and class name.
Args:
image_path: Path to the original image file.
mask_array: Segmentation mask as a numpy array.
class_name: Predicted class name (str).
Returns:
Tuple containing:
- filename: Name of the saved file (str).
- path: Path to the saved file (str).
"""
# Load the original image and mask
orig_img = Image.open(image_path).convert('RGB')
mask = (mask_array * 255).astype(np.uint8)
mask_img = Image.fromarray(mask, mode='L')
if mask_img.size != orig_img.size:
mask_img = mask_img.resize(orig_img.size, resample=Image.Resampling.BILINEAR)
# Overlay the mask on the original image with some transparency
alpha_arr = (np.array(mask_img, dtype=np.float32) / 255.0 * float(MAX_ALPHA)).astype(np.uint8)
alpha_img = Image.fromarray(alpha_arr, mode='L')
overlay = Image.new('RGBA', orig_img.size, MASK_COLOR)
overlay.putalpha(alpha_img)
overlay_img = Image.alpha_composite(orig_img.convert('RGBA'), overlay).convert('RGB')
# Draw the class name on the image
draw = ImageDraw.Draw(overlay_img)
try:
font = ImageFont.truetype(FONT_PATH, 35)
    except OSError:
        # Fall back to PIL's built-in default font if the bundled font cannot be loaded
        font = ImageFont.load_default()
draw.text((40, 40), class_name, fill='red', font=font)
    # Save the visualization image (with mask overlay and class label)
filename = f'{uuid.uuid4().hex}.png'
os.makedirs('static/results', exist_ok=True)
path = f'static/results/{filename}'
overlay_img.save(path)
return filename, path
# Function to delete files after a delay
def delete_files_later(files, delay=10) -> None:
"""
Delete files after a specified delay.
Args:
files: List of file paths to delete.
delay: Time in seconds to wait before deleting files (default is 10).
"""
def _del_files():
time.sleep(delay)
for f in files:
            try:
                os.remove(f)
            except OSError:
                logging.exception('Error deleting file %s', f)
t = threading.Thread(target=_del_files, daemon=True)
t.start()
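# Illustrative end-to-end sketch of how main.py is assumed to chain these helpers for a single
# request. The function name, the returned dictionary shape, and the '/static/results/...' URL
# prefix are assumptions for illustration; the real request handling lives in main.py.
def example_predict_flow(image_bytes: bytes, inference_ctx: dict) -> dict:
    """Illustrative sketch: run the full preprocess -> inference -> visualization pipeline."""
    _, upload_path = save_image(image_bytes)
    img = preprocess_image(image_bytes)
    class_id, class_name, mask = inference(img, inference_ctx)
    result_name, result_path = draw_prediction(upload_path, mask, class_name)
    mask_b64 = encode_mask_to_base64(mask)
    # Schedule cleanup of the temporary files once the client has had a chance to fetch them
    delete_files_later([upload_path, result_path])
    return {
        'class_id': class_id,
        'class_name': class_name,
        'mask_base64': mask_b64,
        'result_image_url': f'/static/results/{result_name}',
    }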