# anfastech's picture
# Fix: numpy dependency order, add proper error handling
# 0c39d91
# app.py
import logging
import os
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
# Configure logging FIRST, before anything below emits a log record.
# Records at INFO and above go to stdout in a timestamped format.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by every handler in this file
logger = logging.getLogger(__name__)
# Add project root (the directory containing this file) to the front of
# the module search path
sys.path.insert(0, str(Path(__file__).parent))

# Import detector. An ImportError is logged and then re-raised, so the
# module fails to load rather than starting without a detector factory.
try:
    from diagnosis.ai_engine.detect_stuttering import get_stutter_detector
except ImportError as import_err:
    logger.error(f"❌ Failed to import StutterDetector: {import_err}")
    raise
else:
    logger.info("βœ… Successfully imported StutterDetector")
# Initialize FastAPI
app = FastAPI(
    title="Stutter Detector API",
    description="Speech analysis using Wav2Vec2 models for stutter detection",
    version="1.0.0",
)

# CORS policy.
# Allowed origins come from the CORS_ALLOW_ORIGINS environment variable
# (comma-separated); when it is unset or empty every origin is allowed.
# The CORS spec forbids pairing a wildcard origin with credentials, so
# credentialed requests are only enabled when an explicit origin allowlist
# has been configured.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global detector instance; stays None until the startup hook has run
detector = None


@app.on_event("startup")
async def startup_event():
    """Create the shared detector when the application starts.

    Stores the object returned by get_stutter_detector() in the
    module-level ``detector``. Any exception raised while loading is
    logged with its traceback and re-raised.
    """
    global detector
    logger.info("πŸš€ Startup event: Loading AI models...")
    try:
        detector = get_stutter_detector()
    except Exception as load_err:
        logger.error(f"❌ Failed to load models: {load_err}", exc_info=True)
        raise
    logger.info("βœ… Models loaded successfully!")
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns a dict with:
    - status: always "healthy" (the process is up and answering requests)
    - models_loaded: True once the module-level detector has been set
    - timestamp: current UTC time as an ISO 8601 string
    """
    return {
        "status": "healthy",
        "models_loaded": detector is not None,
        # Computed in-process: spawning a `date` shell command on every probe
        # is slow, blocks the event loop, leaves the pipe unclosed and does
        # not work the same way on every platform.
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
@app.post("/analyze")
async def analyze_audio(
    audio: UploadFile = File(...),
    transcript: str = ""
):
    """
    Analyze audio file for stuttering

    The upload is written to a uniquely named scratch file under
    /tmp/stutter_analysis, handed to the module-level detector, and the
    scratch file is removed again whether or not the analysis succeeds.

    Parameters:
    - audio: WAV or MP3 audio file
    - transcript: Optional expected transcript

    Returns: Complete stutter analysis results, as returned by
    detector.analyze_audio()

    Raises:
    - HTTPException(503): the detector has not been loaded yet
    - HTTPException(500): any other failure while saving or analyzing
    """
    temp_file = None
    try:
        if not detector:
            raise HTTPException(status_code=503, detail="Models not loaded yet. Try again in a moment.")

        logger.info(f"πŸ“₯ Processing: {audio.filename}")

        # Create temp directory if needed
        temp_dir = "/tmp/stutter_analysis"
        os.makedirs(temp_dir, exist_ok=True)

        # The client-supplied filename is untrusted: it may be missing, be
        # absolute, contain "../" components, or collide with a concurrent
        # upload. Only its extension is kept (and only when it is short and
        # alphanumeric); the rest of the name is generated by mkstemp.
        suffix = Path(audio.filename or "").suffix
        if not (1 < len(suffix) <= 10 and suffix[1:].isalnum()):
            suffix = ""

        # Read first so no file descriptor is left open if the read fails
        content = await audio.read()

        # Save uploaded file under a unique, server-chosen name
        fd, temp_file = tempfile.mkstemp(prefix="upload_", suffix=suffix, dir=temp_dir)
        with os.fdopen(fd, "wb") as f:
            f.write(content)
        logger.info(f"πŸ“‚ Saved to: {temp_file} ({len(content) / 1024 / 1024:.2f} MB)")

        # Analyze
        logger.info(f"πŸ”„ Analyzing audio with transcript: '{transcript[:50]}...'")
        result = detector.analyze_audio(temp_file, transcript)
        logger.info(f"βœ… Analysis complete: severity={result['severity']}, mismatch={result['mismatch_percentage']}%")
        return result
    except HTTPException:
        # Already carries the intended status code; pass it through untouched
        raise
    except Exception as e:
        logger.error(f"❌ Error during analysis: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
    finally:
        # Cleanup is best-effort: a failed delete is logged, never raised
        if temp_file and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
                logger.info(f"🧹 Cleaned up: {temp_file}")
            except Exception as e:
                logger.warning(f"Could not clean up {temp_file}: {e}")
@app.get("/")
async def root():
    """Return a static description of the API.

    The payload lists the service name, version and status, a summary of
    the available endpoints, and the model identifiers.
    """
    endpoint_summary = {
        "health": "GET /health",
        "analyze": "POST /analyze (multipart: audio file + optional transcript field)",
        "docs": "GET /docs (interactive API docs)",
    }
    model_ids = {
        "base": "facebook/wav2vec2-base-960h",
        "large": "facebook/wav2vec2-large-960h-lv60-self",
        "xlsr": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
    }
    return {
        "name": "SLAQ Stutter Detector API",
        "version": "1.0.0",
        "status": "running",
        "endpoints": endpoint_summary,
        "models": model_ids,
    }
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 7860. uvicorn is imported here, only when run as a script.
    import uvicorn

    logger.info("πŸš€ Starting SLAQ Stutter Detector API...")
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")