File size: 4,452 Bytes
60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 60b1e24 0c39d91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
# app.py
import logging
import os
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
# Configure logging FIRST, before anything else in this module logs.
# Records go to stdout at INFO level with timestamp, logger name and level.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Add project root (the directory containing this file) to the front of
# sys.path so the `diagnosis` package can be imported.
sys.path.insert(0, str(Path(__file__).parent))

# Import detector. A failed import is logged and re-raised: the service
# cannot do anything useful without it.
try:
    from diagnosis.ai_engine.detect_stuttering import get_stutter_detector
except ImportError as e:
    logger.error(f"❌ Failed to import StutterDetector: {e}")
    raise
else:
    logger.info("✅ Successfully imported StutterDetector")
# Initialize FastAPI
app = FastAPI(
    title="Stutter Detector API",
    version="1.0.0",
    description="Speech analysis using Wav2Vec2 models for stutter detection",
)

# Add CORS middleware: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is the
# most permissive configuration possible — confirm this is intended before
# the API is reachable from browsers that send credentials.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Global detector instance; stays None until the startup hook assigns it.
detector = None
@app.on_event("startup")
async def startup_event():
    """Load the AI models once when the application starts.

    Stores the object returned by ``get_stutter_detector()`` in the
    module-level ``detector`` so request handlers can use it.

    Raises:
        Exception: whatever ``get_stutter_detector()`` raised; it is logged
            with its traceback and re-raised rather than swallowed.
    """
    global detector
    try:
        logger.info("🚀 Startup event: Loading AI models...")
        detector = get_stutter_detector()
        logger.info("✅ Models loaded successfully!")
    except Exception as e:
        logger.error(f"❌ Failed to load models: {e}", exc_info=True)
        raise
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns:
        dict with:
        - "status": always "healthy" while the process can answer requests.
        - "models_loaded": True once the module-level ``detector`` is set.
        - "timestamp": current time in UTC as an ISO 8601 string.
    """
    return {
        "status": "healthy",
        "models_loaded": detector is not None,
        # Computed in-process: no shell is spawned per request and the value
        # does not depend on a platform- or locale-specific `date` binary.
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
def _safe_upload_suffix(filename):
    """Return a harmless file extension (e.g. ".wav") from *filename*, or ""."""
    suffix = Path(filename or "").suffix
    extension = suffix[1:]
    if extension.isascii() and extension.isalnum() and len(extension) <= 10:
        return suffix
    return ""


@app.post("/analyze")
async def analyze_audio(
    audio: UploadFile = File(...),
    transcript: str = ""
):
    """
    Analyze audio file for stuttering

    Parameters:
    - audio: WAV or MP3 audio file
    - transcript: Optional expected transcript

    Returns: Complete stutter analysis results

    Raises:
    - HTTPException 503 if the models have not been loaded yet
    - HTTPException 500 if saving or analyzing the upload fails
    """
    # Identity check, consistent with /health: only "not loaded yet" counts
    # as unavailable, whatever the detector's own truthiness is.
    if detector is None:
        raise HTTPException(status_code=503, detail="Models not loaded yet. Try again in a moment.")

    temp_file = None
    try:
        logger.info(f"📥 Processing: {audio.filename}")

        # Create temp directory if needed
        temp_dir = "/tmp/stutter_analysis"
        os.makedirs(temp_dir, exist_ok=True)

        # Save uploaded file under a server-generated unique name. The
        # client-supplied filename is untrusted: it may be missing, contain
        # "../" or be absolute, and two concurrent uploads may share it.
        # Only its extension is kept, in case the detector uses it to
        # recognise the audio format.
        content = await audio.read()
        with tempfile.NamedTemporaryFile(
            mode="wb",
            dir=temp_dir,
            prefix="upload_",
            suffix=_safe_upload_suffix(audio.filename),
            delete=False,
        ) as f:
            temp_file = f.name
            f.write(content)
        logger.info(f"📁 Saved to: {temp_file} ({len(content) / 1024 / 1024:.2f} MB)")

        # Analyze
        # NOTE(review): this call runs synchronously inside the coroutine; if
        # it is CPU-heavy it blocks the event loop for its whole duration.
        # Consider a worker thread once the detector's thread-safety is
        # confirmed.
        logger.info(f"🔄 Analyzing audio with transcript: '{transcript[:50]}...'")
        result = detector.analyze_audio(temp_file, transcript)
        logger.info(f"✅ Analysis complete: severity={result['severity']}, mismatch={result['mismatch_percentage']}%")
        return result
    except HTTPException:
        # Deliberate HTTP errors pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"❌ Error during analysis: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") from e
    finally:
        # Cleanup: best effort, never masks the real outcome of the request.
        if temp_file is not None:
            try:
                os.remove(temp_file)
                logger.info(f"🧹 Cleaned up: {temp_file}")
            except FileNotFoundError:
                # Already gone; nothing to clean up.
                pass
            except OSError as cleanup_err:
                logger.warning(f"Could not clean up {temp_file}: {cleanup_err}")
@app.get("/")
async def root():
    """Return static service metadata: name, version, endpoints and models.

    Returns:
        dict describing the API; all values are fixed strings.
    """
    return {
        "name": "SLAQ Stutter Detector API",
        "version": "1.0.0",
        "status": "running",
        "endpoints": {
            "health": "GET /health",
            # `transcript` is declared as a plain `str` parameter on the
            # /analyze handler, so it is read from the query string, not
            # from the multipart body.
            "analyze": "POST /analyze (multipart: audio file; optional transcript query parameter)",
            "docs": "GET /docs (interactive API docs)",
        },
        "models": {
            "base": "facebook/wav2vec2-base-960h",
            "large": "facebook/wav2vec2-large-960h-lv60-self",
            "xlsr": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
        },
    }
if __name__ == "__main__":
    # Run the app with uvicorn when this file is executed directly.
    # Listens on all interfaces, port 7860.
    import uvicorn

    logger.info("🚀 Starting SLAQ Stutter Detector API...")
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
    )