"""
FastAPI app for defect detection using a TFLite model.
Endpoints:
- GET / : Render an HTML form (no results).
- POST /predict/ : REST API. Predict defect on an uploaded image; returns JSON.
- POST /upload/ : Upload image, run prediction, and return an HTML page with visualization and results.
- POST /random-sample/ : Run prediction on a random sample image and return an HTML page with visualization and results.
Design notes:
- This application is a demonstration / portfolio app. For simplicity and safety during demo runs, inference is performed synchronously on a single global TFLite Interpreter instance, protected by a threading.Lock for thread-safety.
- The code intentionally makes a number of fixed assumptions about the model and runtime. If the model or deployment requirements change, the corresponding preprocessing, postprocessing, and runtime setup should be updated and tested.
Assumptions:
- File system and assets:
  - Font used for drawing labels: ./fonts/OpenSans-Bold.ttf
  - Static files are served from: ./static
  - Directories (./static/uploads, ./static/results, ./static/samples) are expected to be created by the deployment (Dockerfile or startup); an os.makedirs(..., exist_ok=True) safeguard is also run at import time.
- Upload / input constraints:
  - Uploaded images are expected to be valid PNG images (this matches the local MVTec AD dataset used for development).
  - Maximum accepted upload size: 5 MB.
- Runtime / model:
  - Uses the tflite-runtime Interpreter for model inference (Interpreter from tflite_runtime.interpreter).
  - TFLite model file path: ./final_model.tflite
  - A single Interpreter instance is created at startup and reused for all requests (protected by a threading.Lock).
- Model I/O (these are the exact assumptions used by the code):
  - Expected input tensor: shape (1, 512, 512, 3), dtype float32, pixel value range [0, 255] (the model normalizes to [0, 1] internally).
  - Expected output[0]: segmentation mask of shape (1, 512, 512, 1), dtype float32, values in [0, 1] (probability map).
  - Expected output[1]: class probabilities of shape (1, 6), dtype float32 (softmax-like probabilities).
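Example client call against POST /predict/ (illustrative sketch; assumes the
app is served at http://localhost:8000 and 'sample.png' is a valid PNG under
5 MB):

    import base64, io
    import requests
    from PIL import Image

    with open('sample.png', 'rb') as f:
        resp = requests.post(
            'http://localhost:8000/predict/',
            files={'file': ('sample.png', f, 'image/png')},
        )
    payload = resp.json()  # keys: class_id, class_name, mask64_PNG_L
    # Decode the base64-encoded PNG mask back into a single-channel PIL image
    mask_img = Image.open(io.BytesIO(base64.b64decode(payload['mask64_PNG_L'])))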
"""
# IMPORTS
# Standard library imports
import io
import logging
import os
import random
import threading
import time
# Third-party imports
from fastapi import FastAPI, File, UploadFile, Request, BackgroundTasks
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from PIL import Image, UnidentifiedImageError
from tflite_runtime.interpreter import Interpreter
# from ai_edge_litert.interpreter import Interpreter
# Auxiliary imports (Dockerfile sets CWD to /app)
from aux import preprocess_image, inference, save_image, draw_prediction, encode_mask_to_base64, delete_files_later
# START TIME LOGGING ('time' is already imported above)
app_start = time.perf_counter()
# CONFIGURATION AND CONSTANTS
# Path to TFLite model file
MODEL_PATH = './final_model.tflite'
# Number of threads for TFLite interpreter
NUM_THREADS = 4
# Jinja2 templates directory
TEMPLATES = Jinja2Templates(directory='templates')
# Max file size for uploads (5 MB)
MAX_FILE_SIZE = 5 * 1024 * 1024 # 5 MB
# Max characters from the BASE64 mask to include in the HTML display
MAX_BASE64_DISPLAY = 10
# MAIN APPLICATION
# Set up logging to show INFO level and above messages
logging.basicConfig(level=logging.INFO)
# Initialize FastAPI app
app = FastAPI()
# Mount static files directory for serving images and other assets.
# Directory creation is normally handled by the Dockerfile; the exist_ok
# makedirs calls below are a safeguard so the mount and later writes succeed.
for subdir in ('static', 'static/uploads', 'static/results', 'static/samples'):
    os.makedirs(subdir, exist_ok=True)
app.mount('/static', StaticFiles(directory='static'), name='static')
# Load model, set up interpreter and get input/output details
try:
    interpreter = Interpreter(model_path=MODEL_PATH, num_threads=NUM_THREADS)
except Exception:
    logging.warning(f'num_threads={NUM_THREADS} not supported, falling back to single-threaded interpreter.')
    interpreter = Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
logging.info('TF Lite input details: %s \n', input_details)
logging.info('TF Lite output details: %s \n', output_details)
# Create a threading lock for the interpreter to ensure thread-safety
interpreter_lock = threading.Lock()
# Inference context to be passed to inference function
inference_ctx = {
'interpreter_lock': interpreter_lock,
'interpreter': interpreter,
'input_details': input_details,
'output_details': output_details,
}
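# For reference, the lock-protected call inside aux.inference() is expected to
# follow the standard TFLite invocation pattern (illustrative sketch, not the
# actual implementation in aux.py):
#
#     with inference_ctx['interpreter_lock']:
#         interp = inference_ctx['interpreter']
#         interp.set_tensor(inference_ctx['input_details'][0]['index'], img)
#         interp.invoke()
#         mask = interp.get_tensor(inference_ctx['output_details'][0]['index'])
#         probs = interp.get_tensor(inference_ctx['output_details'][1]['index'])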
# Startup time measurement
@app.on_event('startup')
async def report_startup_time():
startup_time = (time.perf_counter() - app_start) * 1000 # in milliseconds
logging.info(f'App startup time: {startup_time:.2f} ms \n')
# Root endpoint to render the HTML form
@app.get('/', response_class=HTMLResponse)
async def root(request: Request):
# Render the HTML form with empty image URLs and no result
return TEMPLATES.TemplateResponse(
'index.html',
{
'request': request,
'result': None,
            'preproc_img_url': None,
            'pred_img_url': None,
}
)
# Endpoint to handle image prediction (API)
@app.post('/predict/')
async def predict(file: UploadFile = File(...)):
try:
# Check if the uploaded file is a PNG image
if file.content_type != 'image/png':
return JSONResponse(status_code=400, content={'error': 'Only PNG images are supported.'})
# Read the image
image_bytes = await file.read()
# Check if the file size exceeds the maximum limit
if len(image_bytes) > MAX_FILE_SIZE:
return JSONResponse(status_code=400, content={'error': 'File size exceeds the maximum limit of 5 MB.'})
# Check if the image is a valid PNG (not just a file with .png extension)
try:
img_check = Image.open(io.BytesIO(image_bytes))
if img_check.format != 'PNG':
raise ValueError('Not a PNG')
except (UnidentifiedImageError, ValueError):
return JSONResponse(status_code=400, content={'error': 'Invalid image file.'})
# Preprocess the image
img = preprocess_image(image_bytes)
# Run inference on the preprocessed image
class_id, class_name, mask = inference(img, inference_ctx)
# Encode mask to base64
mask64 = encode_mask_to_base64(mask)
# Return the prediction results as JSON
return {
'class_id': class_id,
'class_name': class_name,
'mask64_PNG_L': mask64,
}
except Exception as e:
logging.exception(f'Error during prediction: {e}')
return JSONResponse(status_code=500, content={'error': 'Model inference failed.'})
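# Example request against the /predict/ endpoint (hypothetical; assumes the
# app listens on localhost:8000):
#
#     curl -X POST http://localhost:8000/predict/ \
#          -F 'file=@sample.png;type=image/png'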
# Endpoint to handle image upload and prediction with visualization
@app.post('/upload/', response_class=HTMLResponse)
async def upload(
    request: Request,
    background_tasks: BackgroundTasks,  # injected by FastAPI
    file: UploadFile = File(...),
):
try:
# Check if the uploaded file is a PNG image
if file.content_type != 'image/png':
result = {'error': 'Only PNG images are supported.'}
return TEMPLATES.TemplateResponse('index.html', {'request': request, 'result': result})
# Read the uploaded image
image_bytes = await file.read()
# Check if the file size exceeds the maximum limit
if len(image_bytes) > MAX_FILE_SIZE:
return TEMPLATES.TemplateResponse('index.html', {'request': request, 'result': {'error': 'File too large (max 5MB).'}})
# Check if the image is a valid PNG (not just a file with .png extension)
try:
img_check = Image.open(io.BytesIO(image_bytes))
if img_check.format != 'PNG':
raise ValueError('Not a PNG')
except (UnidentifiedImageError, ValueError):
return TEMPLATES.TemplateResponse('index.html', {'request': request, 'result': {'error': 'Invalid image file.'}})
# Save the preprocessed image
preproc_filename, preproc_path = save_image(image_bytes)
# Preprocess the image
img = preprocess_image(image_bytes)
# Run inference on the preprocessed image
class_id, class_name, mask = inference(img, inference_ctx)
# Overlay mask and draw class name on the preprocessed image for display
pred_filename, pred_path = draw_prediction(preproc_path, mask, class_name)
# Encode mask to base64
mask64 = encode_mask_to_base64(mask)
# Prepare the result to be displayed in the HTML template
result = {
'class_id': class_id,
'class_name': class_name,
            'mask64_PNG_L': mask64[:MAX_BASE64_DISPLAY] + '...',  # Truncated for HTML display
}
# Schedule deletion of both images after 10 seconds
if background_tasks is not None:
background_tasks.add_task(delete_files_later, [preproc_path, pred_path], delay=10)
# Render the HTML template with the result and image URLs
return TEMPLATES.TemplateResponse(
'index.html',
{
'request': request,
'result': result,
'preproc_img_url': f'/static/uploads/{preproc_filename}',
'pred_img_url': f'/static/results/{pred_filename}',
}
)
except Exception as e:
logging.exception(f'Error during prediction: {e}')
return TEMPLATES.TemplateResponse('index.html', {'request': request, 'result': {'error': 'Model inference failed.'}})
# Endpoint to handle random image (from samples) prediction with visualization
@app.post('/random-sample/', response_class=HTMLResponse)
async def random_sample(request: Request, background_tasks: BackgroundTasks):
try:
        # List the PNG files available in the samples directory
        samples_dir = 'static/samples'
        sample_files = [f for f in os.listdir(samples_dir) if f.lower().endswith('.png')]
if not sample_files:
result = {'error': 'No sample images available.'}
return TEMPLATES.TemplateResponse('index.html', {'request': request, 'result': result})
# Randomly select a sample image and read it
chosen_file = random.choice(sample_files)
with open(os.path.join(samples_dir, chosen_file), 'rb') as f:
image_bytes = f.read()
# Save preprocessed image
preproc_filename, preproc_path = save_image(image_bytes)
# Preprocess the image
img = preprocess_image(image_bytes)
# Run inference on the preprocessed image
class_id, class_name, mask = inference(img, inference_ctx)
# Overlay mask and draw class name on the preprocessed image for display
pred_filename, pred_path = draw_prediction(preproc_path, mask, class_name)
# Encode mask to base64
mask64 = encode_mask_to_base64(mask)
# Prepare the result to be displayed in the HTML template
result = {
'class_id': class_id,
'class_name': class_name,
            'mask64_PNG_L': mask64[:MAX_BASE64_DISPLAY] + '...',  # Truncated for HTML display
}
# Schedule deletion of both images after 10 seconds
if background_tasks is not None:
background_tasks.add_task(delete_files_later, [preproc_path, pred_path], delay=10)
# Render the HTML template with the result and image URLs
return TEMPLATES.TemplateResponse(
'index.html',
{
'request': request,
'result': result,
'preproc_img_url': f'/static/uploads/{preproc_filename}',
'pred_img_url': f'/static/results/{pred_filename}',
}
)
except Exception as e:
logging.exception(f'Error during prediction: {e}')
return TEMPLATES.TemplateResponse('index.html', {'request': request, 'result': {'error': 'Model inference failed.'}}) |