import logging
import os
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
# Root logging configuration: INFO and above, written to stdout with a
# timestamp, logger name and level in front of every message.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# Module-level logger used by everything below.
logger = logging.getLogger(__name__)
# Put this file's directory first on the import path before importing the
# `diagnosis` package (presumably located next to this file — confirm).
sys.path.insert(0, str(Path(__file__).parent))

# NOTE(review): the "β"/"π" prefixes in this file's log messages look like
# mis-decoded emoji. Only the literal that had been split across two lines
# (and therefore no longer parsed) is repaired here.
try:
    from diagnosis.ai_engine.detect_stuttering import get_stutter_detector
except ImportError as e:
    # Log the failure, then re-raise so the module import fails as before.
    logger.error(f"β Failed to import StutterDetector: {e}")
    raise
else:
    logger.info("✅ Successfully imported StutterDetector")
# The ASGI application object; the metadata below is passed straight to FastAPI.
app = FastAPI(
    version="1.0.0",
    title="Stutter Detector API",
    description="Speech analysis using Wav2Vec2 models for stutter detection",
)

# CORS: every origin, method and header is allowed, with credentials enabled.
# NOTE(review): FastAPI's CORS documentation says wildcard values should not
# be combined with allow_credentials=True — confirm whether credentials are
# actually needed, or list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Holds the loaded detector; stays None until the startup hook assigns it.
detector = None
@app.on_event("startup")
async def startup_event():
    """Load the stutter-detection models when the application starts.

    Stores the object returned by ``get_stutter_detector()`` in the
    module-level ``detector`` so the request handlers can use it.

    Raises:
        Exception: Any error raised while loading is logged with its
            traceback and then re-raised.
    """
    global detector
    try:
        logger.info("π Startup event: Loading AI models...")
        detector = get_stutter_detector()
        # NOTE(review): the "β"/"π" prefixes in these log messages look like
        # mis-decoded emoji. Only this literal, which had been split across
        # two lines and no longer parsed, is repaired.
        logger.info("✅ Models loaded successfully!")
    except Exception as e:
        logger.error(f"β Failed to load models: {e}", exc_info=True)
        raise
@app.get("/health")
async def health_check():
    """Report service liveness and whether the models have been loaded.

    Returns:
        dict: ``status`` (always ``"healthy"``), ``models_loaded`` (``True``
        once the module-level ``detector`` is set) and ``timestamp`` (the
        current UTC time as an ISO 8601 string).
    """
    return {
        "status": "healthy",
        "models_loaded": detector is not None,
        # Built in-process. The previous os.popen("date") spawned a shell on
        # every probe, never closed the pipe, and needed a `date` binary.
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
@app.post("/analyze")
async def analyze_audio(
    audio: UploadFile = File(...),
    transcript: str = ""
):
    """Analyze an uploaded audio file for stuttering.

    The upload is written to a uniquely named temporary file under
    ``/tmp/stutter_analysis``, passed to ``detector.analyze_audio`` together
    with ``transcript``, and removed again whether or not analysis succeeds.

    Parameters:
        audio: Uploaded audio file (WAV or MP3).
        transcript: Optional expected transcript.
            NOTE(review): declared as a plain ``str``, which FastAPI reads
            from the query string, while the ``/`` endpoint describes it as
            a multipart field — confirm which one is intended.

    Returns:
        The analysis result produced by the detector; the code reads its
        ``severity`` and ``mismatch_percentage`` keys for logging.

    Raises:
        HTTPException: 503 when the models are not loaded yet; 500 when
            anything else fails during saving or analysis.
    """
    temp_file = None
    try:
        if detector is None:
            raise HTTPException(status_code=503, detail="Models not loaded yet. Try again in a moment.")

        logger.info(f"π₯ Processing: {audio.filename}")

        temp_dir = "/tmp/stutter_analysis"
        os.makedirs(temp_dir, exist_ok=True)

        # The client-supplied filename must never become part of the path:
        # a name such as "../../x" or an absolute path would let a request
        # write (and, in the cleanup below, delete) files outside temp_dir,
        # and two uploads with the same name would overwrite each other.
        # Only a plain alphanumeric extension is kept, since the detector
        # may use it to recognise the audio format.
        suffix = Path(audio.filename or "").suffix
        if not suffix[1:].isalnum():
            suffix = ""

        content = await audio.read()

        # delete=False: the file has to outlive this `with` so the detector
        # can open it by path; the `finally` block removes it.
        with tempfile.NamedTemporaryFile(
            mode="wb", suffix=suffix, dir=temp_dir, delete=False
        ) as f:
            temp_file = f.name
            f.write(content)

        logger.info(f"π Saved to: {temp_file} ({len(content) / 1024 / 1024:.2f} MB)")

        logger.info(f"π Analyzing audio with transcript: '{transcript[:50]}...'")
        # NOTE(review): this looks like a blocking call made directly inside
        # an async handler — confirm whether it should run in a worker thread.
        result = detector.analyze_audio(temp_file, transcript)

        # NOTE(review): the "β"/"π" prefixes in these log messages look like
        # mis-decoded emoji. Only this literal, which had been split across
        # two lines and no longer parsed, is repaired.
        logger.info(f"✅ Analysis complete: severity={result['severity']}, mismatch={result['mismatch_percentage']}%")
        return result

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 503 above) pass through untouched.
        raise
    except Exception as e:
        logger.error(f"β Error during analysis: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") from e

    finally:
        # Best-effort cleanup: a failure to delete is logged, never raised.
        if temp_file and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
                logger.info(f"π§Ή Cleaned up: {temp_file}")
            except Exception as e:
                logger.warning(f"Could not clean up {temp_file}: {e}")
@app.get("/")
async def root():
    """Return a static description of the API.

    Returns:
        dict: Service name, version and status, plus a map of the available
        endpoints and the identifiers of the models listed for this service.
    """
    endpoint_summary = {
        "health": "GET /health",
        "analyze": "POST /analyze (multipart: audio file + optional transcript field)",
        "docs": "GET /docs (interactive API docs)",
    }
    model_ids = {
        "base": "facebook/wav2vec2-base-960h",
        "large": "facebook/wav2vec2-large-960h-lv60-self",
        "xlsr": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
    }
    return {
        "name": "SLAQ Stutter Detector API",
        "version": "1.0.0",
        "status": "running",
        "endpoints": endpoint_summary,
        "models": model_ids,
    }
if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn on all interfaces,
    # port 7860. uvicorn is imported here so it is only needed when the
    # file is run directly.
    import uvicorn

    logger.info("π Starting SLAQ Stutter Detector API...")
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")