|
|
""" |
|
|
Module with auxiliary functions for the FastAPI defect detection demo app. |
|
|
|
|
|
Provided functions: |
|
|
- preprocess_image: Preprocess input image for model inference. |
|
|
- inference: Perform model inference and return class ID, class name, and segmentation mask. |
|
|
- encode_mask_to_base64: Encode segmentation mask to base64 string. |
|
|
- save_image: Save uploaded image bytes to a file. |
|
|
- draw_prediction: Save image with overlayed prediction mask and class name. |
|
|
|
|
|
Design notes |
|
|
|
|
|
- This application is a demonstration / portfolio app. For simplicity and safety during demo runs, inference is performed synchronously using a single global TFLite Interpreter instance protected by a threading.Lock to ensure thread-safety. |
|
|
- The code intentionally makes a number of fixed assumptions about the model and runtime. If the model or deployment requirements change, the corresponding preprocessing, postprocessing and runtime setup should be updated and tested. |
|
|
|
|
|
Assumptions: |
|
|
|
|
|
File system and assets |
|
|
Font used for drawing labels: ./fonts/OpenSans-Bold.ttf |
|
|
Static files served from: ./static |
|
|
Directories (./static/uploads, ./static/results, ./static/samples) are expected to be present/created by the deployment (Dockerfile or startup); added an exist_ok mkdir as safeguard. |
|
|
|
|
|
Upload / input constraints |
|
|
Uploaded images are expected to be valid PNG images (this matches the local MVTec AD dataset used for development). |
|
|
Maximum accepted upload size: 5 MB. |
|
|
|
|
|
Runtime / model |
|
|
Uses tflite-runtime Interpreter for model inference (Interpreter from tflite_runtime.interpreter). |
|
|
TFLite model file path: ./final_model.tflite |
|
|
Single Interpreter instance is created at startup and reused for all requests (protected by a threading.Lock). |
|
|
|
|
|
Model I/O (these are the exact assumptions used by the code) |
|
|
Expected input tensor: shape (1, 512, 512, 3), dtype float32, pixel value range [0, 255] (model handles internally normalization to [0, 1]). |
|
|
Expected output[0]: segmentation mask of shape (1, 512, 512, 1), dtype float32, values in [0, 1] (probability map). |
|
|
Expected output[1]: class probabilities of shape (1, 6), dtype float32 (softmax-like probabilities). |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import base64
import io
import logging
import os
import threading
import time
import uuid

import numpy as np
from PIL import Image, ImageDraw, ImageFont

FONT_PATH = './fonts/OpenSans-Bold.ttf'

INPUT_IMAGE_SIZE = (512, 512)

# Overlay appearance: MASK_COLOR is cyan (RGBA); its alpha component is a placeholder,
# the actual per-pixel opacity (0..MAX_ALPHA) is applied from the mask via putalpha().
MAX_ALPHA = 100
MASK_COLOR = (0, 255, 255, 0)

CLASS_MAP = {
    0: 'good',
    1: 'crack',
    2: 'faulty_imprint',
    3: 'poke',
    4: 'scratch',
    5: 'squeeze'
}


def preprocess_image(image_bytes) -> np.ndarray:
    """
    Preprocess the input image for model inference.

    Args:
        image_bytes: Raw bytes of the input image.

    Returns:
        Preprocessed image as a numpy array of shape (1, INPUT_IMAGE_SIZE[0], INPUT_IMAGE_SIZE[1], 3) and dtype float32.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    image = image.resize(INPUT_IMAGE_SIZE)
    img_array = np.array(image, dtype=np.float32)
    img_array = np.expand_dims(img_array, axis=0)
    return img_array
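

# The inference() helper below expects a shared "inference context" built once by the
# FastAPI app at startup (see the design notes in the module docstring). A minimal
# sketch of how such a context could be assembled is shown here as a comment; the name
# build_inference_ctx and the startup wiring are assumptions for illustration, the real
# setup lives in the app module.
#
#     from tflite_runtime.interpreter import Interpreter
#
#     def build_inference_ctx(model_path='./final_model.tflite'):
#         interpreter = Interpreter(model_path=model_path)
#         interpreter.allocate_tensors()
#         input_details = interpreter.get_input_details()
#         # Optional sanity check against the documented model I/O assumptions.
#         assert tuple(input_details[0]['shape']) == (1, 512, 512, 3)
#         return {
#             'interpreter': interpreter,
#             'interpreter_lock': threading.Lock(),
#             'input_details': input_details,
#             'output_details': interpreter.get_output_details(),
#         }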
|
|
|
|
|
|
|
|
def inference(img, inference_ctx) -> tuple[int, str, np.ndarray]:
    """
    Perform model inference on the preprocessed image.

    Args:
        img: Preprocessed image as a numpy array.
        inference_ctx: Dictionary containing the threading lock, the interpreter, and its input/output details.

    Returns:
        Tuple containing:
        - class_id: Predicted class ID (int).
        - class_name: Predicted class name (str).
        - mask: Predicted segmentation mask as a numpy array.
    """
    with inference_ctx['interpreter_lock']:
        inference_ctx['interpreter'].set_tensor(inference_ctx['input_details'][0]['index'], img)
        inference_ctx['interpreter'].invoke()

        pred_mask = inference_ctx['interpreter'].get_tensor(inference_ctx['output_details'][0]['index'])
        pred_label_probs = inference_ctx['interpreter'].get_tensor(inference_ctx['output_details'][1]['index'])

    pred_label = np.argmax(pred_label_probs, axis=1)
    class_id = int(pred_label[0])
    class_name = CLASS_MAP.get(class_id, 'unknown')
    mask = pred_mask.squeeze()

    return class_id, class_name, mask


def encode_mask_to_base64(mask_array) -> str:
    """
    Encode the segmentation mask to a base64 string.

    Args:
        mask_array: Segmentation mask as a numpy array.

    Returns:
        Base64-encoded string of the mask image.
    """
    mask = (mask_array * 255).astype(np.uint8)
    mask_img = Image.fromarray(mask, mode='L')
    buffer = io.BytesIO()
    mask_img.save(buffer, format='PNG')
    mask64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return mask64
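
# Illustrative note: a client that receives this string can recover the mask image with
# the standard library, e.g. (sketch only, variable names are for illustration):
#
#     mask_bytes = base64.b64decode(mask64)
#     mask_img = Image.open(io.BytesIO(mask_bytes))  # 8-bit grayscale PNG, values 0-255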
|
|
|
|
|
|
|
|
def save_image(image_bytes) -> tuple[str, str]:
    """
    Save the uploaded image bytes to a file.

    Args:
        image_bytes: Raw bytes of the input image.

    Returns:
        Tuple containing:
        - filename: Name of the saved file (str).
        - path: Path to the saved file (str).
    """
    filename = f'{uuid.uuid4().hex}.png'
    os.makedirs('static/uploads', exist_ok=True)
    path = f'static/uploads/{filename}'
    with open(path, 'wb') as f:
        f.write(image_bytes)
    return filename, path


def draw_prediction(image_path, mask_array, class_name) -> tuple[str, str]:
    """
    Save image with overlaid prediction mask and class name.

    Args:
        image_path: Path to the original image file.
        mask_array: Segmentation mask as a numpy array.
        class_name: Predicted class name (str).

    Returns:
        Tuple containing:
        - filename: Name of the saved file (str).
        - path: Path to the saved file (str).
    """
    orig_img = Image.open(image_path).convert('RGB')
    mask = (mask_array * 255).astype(np.uint8)
    mask_img = Image.fromarray(mask, mode='L')
    if mask_img.size != orig_img.size:
        mask_img = mask_img.resize(orig_img.size, resample=Image.Resampling.BILINEAR)

    alpha_arr = (np.array(mask_img, dtype=np.float32) / 255.0 * float(MAX_ALPHA)).astype(np.uint8)
    alpha_img = Image.fromarray(alpha_arr, mode='L')
    overlay = Image.new('RGBA', orig_img.size, MASK_COLOR)
    overlay.putalpha(alpha_img)
    overlay_img = Image.alpha_composite(orig_img.convert('RGBA'), overlay).convert('RGB')

    draw = ImageDraw.Draw(overlay_img)
    try:
        font = ImageFont.truetype(FONT_PATH, 35)
    except OSError:
        # Fall back to PIL's built-in font if the bundled font file is unavailable.
        font = ImageFont.load_default()
    draw.text((40, 40), class_name, fill='red', font=font)

    filename = f'{uuid.uuid4().hex}.png'
    os.makedirs('static/results', exist_ok=True)
    path = f'static/results/{filename}'
    overlay_img.save(path)

    return filename, path


def delete_files_later(files, delay=10) -> None:
    """
    Delete files after a specified delay.

    Args:
        files: List of file paths to delete.
        delay: Time in seconds to wait before deleting files (default is 10).
    """
    def _del_files():
        time.sleep(delay)
        for f in files:
            try:
                os.remove(f)
            except OSError:
                logging.exception('Error deleting file %s', f)

    t = threading.Thread(target=_del_files, daemon=True)
    t.start()
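

# End-to-end usage sketch (illustrative only): a FastAPI endpoint in the app module is
# expected to tie these helpers together roughly as follows. Names such as app, predict,
# the UploadFile wiring and inference_ctx are assumptions for illustration, not part of
# this module.
#
#     @app.post('/predict')
#     async def predict(file: UploadFile):
#         image_bytes = await file.read()
#         upload_name, upload_path = save_image(image_bytes)
#         img = preprocess_image(image_bytes)
#         class_id, class_name, mask = inference(img, inference_ctx)
#         result_name, result_path = draw_prediction(upload_path, mask, class_name)
#         delete_files_later([upload_path, result_path])
#         return {
#             'class_id': class_id,
#             'class_name': class_name,
#             'mask': encode_mask_to_base64(mask),
#             'result_url': f'/static/results/{result_name}',
#         }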
|
|
|
|
|
|