|
|
""" |
|
|
FastAPI app for defect detection using a TFLite model. |
|
|
|
|
|
Provided endpoints |
|
|
|
|
|
- GET / : Render an HTML form (no results). |
|
|
- POST /predict/ : REST API. Predict defect on an uploaded image; returns JSON. |
|
|
- POST /upload/ : Upload image, run prediction, and return an HTML page with visualization and results. |
|
|
- POST /random-sample/ : Run prediction on a random sample image and return an HTML page with visualization and results. |
|
|
|
|
|
Design notes |
|
|
|
|
|
- This application is a demonstration / portfolio app. For simplicity and safety during demo runs, inference is performed synchronously using a single global TFLite Interpreter instance protected by a threading.Lock to ensure thread-safety. |
|
|
- The code intentionally makes a number of fixed assumptions about the model and runtime. If the model or deployment requirements change, the corresponding preprocessing, postprocessing and runtime setup should be updated and tested. |
|
|
|
|
|
Assumptions: |
|
|
|
|
|
File system and assets |
|
|
Font used for drawing labels: ./fonts/OpenSans-Bold.ttf |
|
|
Static files served from: ./static |
|
|
Directories (./static/uploads, ./static/results, ./static/samples) are expected to be created by the deployment (Dockerfile or startup); an exist_ok mkdir is performed at startup as a safeguard.
|
|
|
|
|
Upload / input constraints |
|
|
Uploaded images are expected to be valid PNG images (this matches the local MVTec AD dataset used for development). |
|
|
Maximum accepted upload size: 5 MB. |
|
|
|
|
|
Runtime / model |
|
|
Uses tflite-runtime Interpreter for model inference (Interpreter from tflite_runtime.interpreter). |
|
|
TFLite model file path: ./final_model.tflite |
|
|
Single Interpreter instance is created at startup and reused for all requests (protected by a threading.Lock). |
|
|
|
|
|
Model I/O (these are the exact assumptions used by the code) |
|
|
Expected input tensor: shape (1, 512, 512, 3), dtype float32, pixel value range [0, 255] (the model internally normalizes values to [0, 1]).
|
|
Expected output[0]: segmentation mask of shape (1, 512, 512, 1), dtype float32, values in [0, 1] (probability map). |
|
|
Expected output[1]: class probabilities of shape (1, 6), dtype float32 (softmax-like probabilities). |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import io |
|
|
import logging |
|
|
import os |
|
|
import random |
|
|
import time |
|
|
import threading |
|
|
|
|
|
|
|
|
from fastapi import FastAPI, File, UploadFile, Request, BackgroundTasks |
|
|
from fastapi.responses import JSONResponse, HTMLResponse |
|
|
from fastapi.templating import Jinja2Templates |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from PIL import Image, UnidentifiedImageError |
|
|
from tflite_runtime.interpreter import Interpreter |
|
|
|
|
|
|
|
|
|
|
|
from aux import preprocess_image, inference, save_image, draw_prediction, encode_mask_to_base64, delete_files_later |
|
|
|
|
|
|
|
|
# Record the module-import start time so the startup hook can report how long
# initialization (model load, tensor allocation, etc.) took.
# NOTE: `time` is already imported in the import block above; the duplicate
# `import time` that used to live here was removed.
app_start = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Path to the TFLite model file loaded at startup (see module docstring: Runtime / model).
MODEL_PATH = './final_model.tflite'

# Preferred number of CPU threads for the TFLite interpreter; falls back to
# single-threaded if the installed tflite-runtime does not support it.
NUM_THREADS = 4

# Jinja2 templates directory used by the HTML-rendering endpoints.
TEMPLATES = Jinja2Templates(directory='templates')

# Maximum accepted upload size in bytes (5 MB), enforced after reading the body.
MAX_FILE_SIZE = 5 * 1024 * 1024

# Number of base64 characters of the mask shown on the HTML result page
# (the full string is only returned by the JSON /predict/ endpoint).
MAX_BASE64_DISPLAY = 10

# Root-logger configuration for the demo app; INFO level so startup timing
# and interpreter details are visible.
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
|
|
app = FastAPI()

# Safeguard: ensure the whole static directory tree exists before mounting.
# The endpoints write to static/uploads and static/results and read from
# static/samples, so create all of them (the deployment normally provides
# them, but a bare checkout should still start cleanly).
for _static_dir in ('static', 'static/uploads', 'static/results', 'static/samples'):
    os.makedirs(_static_dir, exist_ok=True)

app.mount('/static', StaticFiles(directory='static'), name='static')
|
|
|
|
|
|
|
|
# Create the single global TFLite interpreter. Prefer a multi-threaded
# interpreter; some tflite-runtime builds don't accept `num_threads`, in which
# case we fall back to the default single-threaded constructor.
try:
    interpreter = Interpreter(model_path=MODEL_PATH, num_threads=NUM_THREADS)
except Exception:
    # A bare `except:` here would also swallow KeyboardInterrupt/SystemExit;
    # Exception is the broadest we should catch for a constructor fallback.
    logging.warning('num_threads=%s not supported, falling back to single-threaded interpreter.', NUM_THREADS)
    interpreter = Interpreter(model_path=MODEL_PATH)

# Allocate tensors once at startup; the interpreter is reused for all requests.
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
logging.info('TF Lite input details: %s \n', input_details)
logging.info('TF Lite output details: %s \n', output_details)

# The interpreter is not thread-safe; all inference calls must hold this lock.
interpreter_lock = threading.Lock()
|
|
|
|
|
|
|
|
# Bundle everything the aux.inference() helper needs into one context dict,
# so the endpoints don't have to pass four separate globals around.
inference_ctx = {
    'interpreter_lock': interpreter_lock,
    'interpreter': interpreter,
    'input_details': input_details,
    'output_details': output_details,
}
|
|
|
|
|
|
|
|
@app.on_event('startup')
async def report_startup_time():
    """Log the elapsed time from module import to app startup, in ms."""
    elapsed_ms = 1000 * (time.perf_counter() - app_start)
    logging.info(f'App startup time: {elapsed_ms:.2f} ms \n')
|
|
|
|
|
|
|
|
@app.get('/', response_class=HTMLResponse)
async def root(request: Request):
    """Render the landing page form with an empty results context."""
    context = {
        'request': request,
        'result': None,
        'orig_img_url': None,
        'vis_img_url': None,
    }
    return TEMPLATES.TemplateResponse('index.html', context)
|
|
|
|
|
|
|
|
@app.post('/predict/')
async def predict(file: UploadFile = File(...)):
    """REST endpoint: run defect detection on an uploaded PNG image.

    Returns JSON with the predicted class id, class name, and the
    segmentation mask as a base64-encoded PNG (mode 'L').
    Responds 400 for non-PNG / oversized / unreadable uploads, 500 on
    inference failure.
    """
    try:
        # Cheap header check before reading the body.
        if file.content_type != 'image/png':
            return JSONResponse(status_code=400, content={'error': 'Only PNG images are supported.'})

        image_bytes = await file.read()

        # Enforce the 5 MB upload limit (MAX_FILE_SIZE).
        if len(image_bytes) > MAX_FILE_SIZE:
            return JSONResponse(status_code=400, content={'error': 'File size exceeds the maximum limit of 5 MB.'})

        # Verify the payload really is a PNG — content-type headers can lie.
        # Use a context manager so Pillow releases the image resources.
        try:
            with Image.open(io.BytesIO(image_bytes)) as img_check:
                if img_check.format != 'PNG':
                    raise ValueError('Not a PNG')
        except (UnidentifiedImageError, ValueError):
            return JSONResponse(status_code=400, content={'error': 'Invalid image file.'})

        # Resize/convert to the model's expected input (see module docstring).
        img = preprocess_image(image_bytes)

        # Thread-safe inference via the shared interpreter context.
        class_id, class_name, mask = inference(img, inference_ctx)

        mask64 = encode_mask_to_base64(mask)

        return {
            'class_id': class_id,
            'class_name': class_name,
            'mask64_PNG_L': mask64,
        }
    except Exception as e:
        # Lazy %-formatting avoids building the message unless it is emitted.
        logging.exception('Error during prediction: %s', e)
        return JSONResponse(status_code=500, content={'error': 'Model inference failed.'})
|
|
|
|
|
|
|
|
@app.post('/upload/', response_class=HTMLResponse)
async def upload(
    request: Request,
    file: UploadFile = File(...),
    background_tasks: BackgroundTasks = None
):
    """Upload a PNG, run defect detection, and render an HTML result page.

    On success the page shows the preprocessed upload, the prediction
    visualization, and a truncated base64 mask preview. Validation errors
    re-render the form with an 'error' entry in the result context.
    """
    try:
        # Cheap header check before reading the body.
        if file.content_type != 'image/png':
            result = {'error': 'Only PNG images are supported.'}
            return TEMPLATES.TemplateResponse('index.html', {'request': request, 'result': result})

        image_bytes = await file.read()

        # Enforce the 5 MB upload limit (MAX_FILE_SIZE).
        if len(image_bytes) > MAX_FILE_SIZE:
            return TEMPLATES.TemplateResponse('index.html', {'request': request, 'result': {'error': 'File too large (max 5MB).'}})

        # Verify the payload really is a PNG — content-type headers can lie.
        # Use a context manager so Pillow releases the image resources.
        try:
            with Image.open(io.BytesIO(image_bytes)) as img_check:
                if img_check.format != 'PNG':
                    raise ValueError('Not a PNG')
        except (UnidentifiedImageError, ValueError):
            return TEMPLATES.TemplateResponse('index.html', {'request': request, 'result': {'error': 'Invalid image file.'}})

        # Persist the upload so the result page can display the original image.
        preproc_filename, preproc_path = save_image(image_bytes)

        img = preprocess_image(image_bytes)

        # Thread-safe inference via the shared interpreter context.
        class_id, class_name, mask = inference(img, inference_ctx)

        # Render the mask + class label on top of the saved image.
        pred_filename, pred_path = draw_prediction(preproc_path, mask, class_name)

        mask64 = encode_mask_to_base64(mask)

        result = {
            'class_id': class_id,
            'class_name': class_name,
            # Truncated for display; the full mask is available via /predict/.
            'mask64_PNG_L': mask64[:MAX_BASE64_DISPLAY] + "...",
        }

        # Clean up the generated files after the client has had time to load them.
        if background_tasks is not None:
            background_tasks.add_task(delete_files_later, [preproc_path, pred_path], delay=10)

        return TEMPLATES.TemplateResponse(
            'index.html',
            {
                'request': request,
                'result': result,
                'preproc_img_url': f'/static/uploads/{preproc_filename}',
                'pred_img_url': f'/static/results/{pred_filename}',
            }
        )
    except Exception as e:
        # Lazy %-formatting avoids building the message unless it is emitted.
        logging.exception('Error during prediction: %s', e)
        return TEMPLATES.TemplateResponse('index.html', {'request': request, 'result': {'error': 'Model inference failed.'}})
|
|
|
|
|
|
|
|
@app.post('/random-sample/', response_class=HTMLResponse)
async def random_sample(request: Request, background_tasks: BackgroundTasks = None):
    """Run defect detection on a random bundled sample and render the result page.

    Picks a random PNG from static/samples; a missing directory or an empty
    one is reported as 'No sample images available.' rather than surfacing
    as a generic inference failure.
    """
    try:
        samples_dir = 'static/samples'

        # A missing samples directory is a deployment issue, not an inference
        # failure — report it as "no samples" instead of letting os.listdir
        # raise FileNotFoundError into the generic 500 handler below.
        if not os.path.isdir(samples_dir):
            return TEMPLATES.TemplateResponse('index.html', {'request': request, 'result': {'error': 'No sample images available.'}})

        sample_files = [f for f in os.listdir(samples_dir) if f.lower().endswith('.png')]
        if not sample_files:
            result = {'error': 'No sample images available.'}
            return TEMPLATES.TemplateResponse('index.html', {'request': request, 'result': result})

        chosen_file = random.choice(sample_files)
        with open(os.path.join(samples_dir, chosen_file), 'rb') as f:
            image_bytes = f.read()

        # Persist a copy so the result page can display the original image.
        preproc_filename, preproc_path = save_image(image_bytes)

        img = preprocess_image(image_bytes)

        # Thread-safe inference via the shared interpreter context.
        class_id, class_name, mask = inference(img, inference_ctx)

        # Render the mask + class label on top of the saved image.
        pred_filename, pred_path = draw_prediction(preproc_path, mask, class_name)

        mask64 = encode_mask_to_base64(mask)

        result = {
            'class_id': class_id,
            'class_name': class_name,
            # Truncated for display; the full mask is available via /predict/.
            'mask64_PNG_L': mask64[:MAX_BASE64_DISPLAY] + "...",
        }

        # Clean up the generated files after the client has had time to load them.
        if background_tasks is not None:
            background_tasks.add_task(delete_files_later, [preproc_path, pred_path], delay=10)

        return TEMPLATES.TemplateResponse(
            'index.html',
            {
                'request': request,
                'result': result,
                'preproc_img_url': f'/static/uploads/{preproc_filename}',
                'pred_img_url': f'/static/results/{pred_filename}',
            }
        )
    except Exception as e:
        # Lazy %-formatting avoids building the message unless it is emitted.
        logging.exception('Error during prediction: %s', e)
        return TEMPLATES.TemplateResponse('index.html', {'request': request, 'result': {'error': 'Model inference failed.'}})