# Author: Kev-HL
# Minimal clone for deployment, see README for full project
# Commit: babf969
"""
Module with auxiliary functions for the FastAPI defect detection demo app.
Provided functions:
- preprocess_image: Preprocess input image for model inference.
- inference: Perform model inference and return class ID, class name, and segmentation mask.
- encode_mask_to_base64: Encode segmentation mask to base64 string.
- save_image: Save uploaded image bytes to a file.
- draw_prediction: Save image with overlayed prediction mask and class name.
Design notes
- This application is a demonstration / portfolio app. For simplicity and safety during demo runs, inference is performed synchronously using a single global TFLite Interpreter instance protected by a threading.Lock to ensure thread-safety.
- The code intentionally makes a number of fixed assumptions about the model and runtime. If the model or deployment requirements change, the corresponding preprocessing, postprocessing and runtime setup should be updated and tested.
Assumptions:
File system and assets
Font used for drawing labels: ./fonts/OpenSans-Bold.ttf
Static files served from: ./static
Directories (./static/uploads, ./static/results, ./static/samples) are expected to be present/created by the deployment (Dockerfile or startup); added an exist_ok mkdir as safeguard.
Upload / input constraints
Uploaded images are expected to be valid PNG images (this matches the local MVTec AD dataset used for development).
Maximum accepted upload size: 5 MB.
Runtime / model
Uses tflite-runtime Interpreter for model inference (Interpreter from tflite_runtime.interpreter).
TFLite model file path: ./final_model.tflite
Single Interpreter instance is created at startup and reused for all requests (protected by a threading.Lock).
Model I/O (these are the exact assumptions used by the code)
Expected input tensor: shape (1, 512, 512, 3), dtype float32, pixel value range [0, 255] (the model normalizes to [0, 1] internally).
Expected output[0]: segmentation mask of shape (1, 512, 512, 1), dtype float32, values in [0, 1] (probability map).
Expected output[1]: class probabilities of shape (1, 6), dtype float32 (softmax-like probabilities).
"""
# IMPORTS
# Standard library imports
import base64
import io
import logging
import os
import threading
import time
import uuid
# Third-party imports
import numpy as np
from PIL import Image, ImageDraw, ImageFont
# CONFIGURATION AND CONSTANTS
# Font used by draw_prediction to write the class label onto result images.
FONT_PATH = './fonts/OpenSans-Bold.ttf'
# Model input size as (width, height); uploads are resized to this before inference.
INPUT_IMAGE_SIZE = (512, 512)
# Maximum per-pixel opacity of the mask overlay, on the 0-255 alpha scale.
MAX_ALPHA = 100 # [0-255]
# Overlay color (R, G, B, A) — cyan; the alpha component given here is
# replaced per-pixel via putalpha() in draw_prediction, so its value is unused.
MASK_COLOR = (0, 255, 255, 0) # Cyan RGB (R,G,B,A)
# Maps the model's class-probability index (output[1], see module docstring)
# to a human-readable defect label.
CLASS_MAP = {
    0: 'good',
    1: 'crack',
    2: 'faulty_imprint',
    3: 'poke',
    4: 'scratch',
    5: 'squeeze'
}
# AUXILIARY FUNCTIONS FOR main.py
# Function to preprocess the image
def preprocess_image(image_bytes) -> np.ndarray:
    """
    Preprocess the input image for model inference.

    Decodes the raw bytes, converts to RGB, resizes to the model's expected
    input size and returns a single-image float32 batch.

    Args:
        image_bytes: Raw bytes of the input image.

    Returns:
        Numpy array of shape (1, INPUT_IMAGE_SIZE[0], INPUT_IMAGE_SIZE[1], 3),
        dtype float32, pixel values in [0, 255] (the model normalizes internally).
    """
    decoded = Image.open(io.BytesIO(image_bytes))
    resized = decoded.convert('RGB').resize(INPUT_IMAGE_SIZE)
    # Add the leading batch dimension expected by the interpreter.
    batch = np.asarray(resized, dtype=np.float32)[np.newaxis, ...]
    return batch
# Function to perform inference on a preprocessed image
def inference(img, inference_ctx) -> tuple[int, str, np.ndarray]:
    """
    Perform model inference on the preprocessed image.

    Args:
        img: Preprocessed image as a numpy array (must match the model's
            input tensor shape/dtype, see module docstring).
        inference_ctx: Dict containing 'interpreter_lock', 'interpreter',
            'input_details' and 'output_details'.

    Returns:
        Tuple containing:
            - class_id: Predicted class ID (int), argmax of the class probabilities.
            - class_name: Predicted class name (str), 'unknown' if the ID is
              not present in CLASS_MAP.
            - mask: Predicted segmentation mask as a numpy array (squeezed).
    """
    interpreter = inference_ctx['interpreter']
    input_index = inference_ctx['input_details'][0]['index']
    mask_index = inference_ctx['output_details'][0]['index']
    probs_index = inference_ctx['output_details'][1]['index']
    # The interpreter is a shared global instance: serialize access so
    # concurrent requests cannot interleave set_tensor/invoke/get_tensor.
    with inference_ctx['interpreter_lock']:
        interpreter.set_tensor(input_index, img)
        interpreter.invoke()
        pred_mask = interpreter.get_tensor(mask_index)
        pred_label_probs = interpreter.get_tensor(probs_index)
    # Reduce the probability vector to a class ID and map it to a label.
    class_id = int(np.argmax(pred_label_probs, axis=1)[0])
    class_name = CLASS_MAP.get(class_id, 'unknown')
    mask = pred_mask.squeeze()
    return class_id, class_name, mask
# Function to encode mask to base64
def encode_mask_to_base64(mask_array) -> str:
    """
    Encode the segmentation mask to a base64 string.

    Args:
        mask_array: Segmentation mask as a numpy array, values in [0, 1].

    Returns:
        Base64-encoded string of the mask rendered as an 8-bit grayscale PNG.
    """
    # Scale the [0, 1] probability map to 8-bit grayscale.
    gray = (mask_array * 255).astype(np.uint8)
    gray_img = Image.fromarray(gray, mode='L')
    with io.BytesIO() as buffer:
        gray_img.save(buffer, format='PNG')
        payload = buffer.getvalue()
    return base64.b64encode(payload).decode('utf-8')
# Function to save an image for later use
def save_image(image_bytes) -> tuple[str, str]:
    """
    Save the uploaded image bytes to a file under static/uploads.

    Args:
        image_bytes: Raw bytes of the input image.

    Returns:
        Tuple containing:
            - filename: Name of the saved file (str), a random UUID hex + '.png'.
            - path: Path to the saved file (str).
    """
    # Random UUID name avoids collisions between concurrent uploads.
    filename = f'{uuid.uuid4().hex}.png'
    # Safeguard: the deployment normally creates this directory at startup.
    os.makedirs('static/uploads', exist_ok=True)
    # BUG FIX: the path previously did not interpolate the generated filename,
    # so every upload was written to the same literal path and the returned
    # filename never matched the file on disk.
    path = f'static/uploads/{filename}'
    with open(path, 'wb') as f:
        f.write(image_bytes)
    return filename, path
# Function to save image with overlayed prediction mask and class name
def draw_prediction(image_path, mask_array, class_name) -> tuple[str, str]:
    """
    Save image with overlayed prediction mask and class name.

    The mask is resized to the original image size if needed, blended onto the
    image as a semi-transparent cyan overlay (per-pixel alpha proportional to
    the mask value, capped at MAX_ALPHA), and the class name is drawn in red.

    Args:
        image_path: Path to the original image file.
        mask_array: Segmentation mask as a numpy array, values in [0, 1].
        class_name: Predicted class name (str).

    Returns:
        Tuple containing:
            - filename: Name of the saved file (str).
            - path: Path to the saved file (str).
    """
    # Load the original image and mask
    orig_img = Image.open(image_path).convert('RGB')
    mask = (mask_array * 255).astype(np.uint8)
    mask_img = Image.fromarray(mask, mode='L')
    if mask_img.size != orig_img.size:
        mask_img = mask_img.resize(orig_img.size, resample=Image.Resampling.BILINEAR)
    # Overlay the mask on the original image: per-pixel alpha scales with the
    # mask probability, capped at MAX_ALPHA for partial transparency.
    alpha_arr = (np.array(mask_img, dtype=np.float32) / 255.0 * float(MAX_ALPHA)).astype(np.uint8)
    alpha_img = Image.fromarray(alpha_arr, mode='L')
    overlay = Image.new('RGBA', orig_img.size, MASK_COLOR)
    overlay.putalpha(alpha_img)
    overlay_img = Image.alpha_composite(orig_img.convert('RGBA'), overlay).convert('RGB')
    # Draw the class name on the image
    draw = ImageDraw.Draw(overlay_img)
    try:
        font = ImageFont.truetype(FONT_PATH, 35)
    except OSError:
        # Font file missing or unreadable: fall back to PIL's built-in font
        # rather than failing the request. (Narrowed from a bare except, which
        # also swallowed KeyboardInterrupt/SystemExit.)
        font = ImageFont.load_default()
    draw.text((40, 40), class_name, fill='red', font=font)
    # Save the visualization image (with overlay and label)
    filename = f'{uuid.uuid4().hex}.png'
    os.makedirs('static/results', exist_ok=True)
    # BUG FIX: the path previously did not interpolate the generated filename,
    # so every result was written to the same literal path and the returned
    # filename never matched the file on disk.
    path = f'static/results/{filename}'
    overlay_img.save(path)
    return filename, path
# Function to delete files after a delay
def delete_files_later(files, delay=10) -> None:
    """
    Delete files after a specified delay, in a background daemon thread.

    Best-effort cleanup: a failure to delete one file is logged and does not
    stop the remaining deletions.

    Args:
        files: List of file paths to delete.
        delay: Time in seconds to wait before deleting files (default is 10).
    """
    def _del_files():
        time.sleep(delay)
        for f in files:
            try:
                os.remove(f)
            except OSError:
                # Narrowed from a bare except: only filesystem errors (missing
                # file, permissions, ...) are expected and should be logged;
                # a bare except would also swallow SystemExit/KeyboardInterrupt.
                logging.exception('Error deleting file %s', f)

    # Daemon thread so pending deletions never block interpreter shutdown.
    t = threading.Thread(target=_del_files, daemon=True)
    t.start()