File size: 4,452 Bytes
60b1e24
0c39d91
 
 
 
60b1e24
 
 
 
0c39d91
 
 
 
 
 
60b1e24
 
0c39d91
 
 
 
 
 
 
 
 
 
 
60b1e24
 
 
0c39d91
60b1e24
 
 
0c39d91
60b1e24
 
 
 
 
 
 
 
0c39d91
60b1e24
 
 
 
0c39d91
60b1e24
0c39d91
 
 
 
 
 
 
60b1e24
 
 
 
 
 
0c39d91
 
60b1e24
 
 
 
 
0c39d91
60b1e24
 
0c39d91
60b1e24
 
0c39d91
60b1e24
 
0c39d91
60b1e24
0c39d91
60b1e24
0c39d91
 
60b1e24
0c39d91
60b1e24
0c39d91
 
 
60b1e24
0c39d91
 
 
 
 
 
 
 
60b1e24
0c39d91
 
 
 
 
60b1e24
 
0c39d91
 
60b1e24
0c39d91
 
 
 
 
 
 
 
 
 
 
60b1e24
 
 
 
 
0c39d91
 
 
60b1e24
0c39d91
 
 
 
 
 
 
 
60b1e24
 
 
 
 
0c39d91
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# app.py
import logging
import os
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware

# Set up root logging before anything else so that messages emitted while
# the remaining module-level code runs are formatted and sent to stdout.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every handler in this file.
logger = logging.getLogger(__name__)

# Put the directory that contains this file at the front of the import
# search path so the `diagnosis` package below can be resolved.
_PROJECT_ROOT = Path(__file__).parent
sys.path.insert(0, str(_PROJECT_ROOT))

# Import the detector factory; an import failure is logged and then
# re-raised so the process does not start without it.
try:
    from diagnosis.ai_engine.detect_stuttering import get_stutter_detector
except ImportError as import_error:
    logger.error(f"❌ Failed to import StutterDetector: {import_error}")
    raise
else:
    logger.info("βœ… Successfully imported StutterDetector")

# Initialize FastAPI
app = FastAPI(
    title="Stutter Detector API",
    description="Speech analysis using Wav2Vec2 models for stutter detection",
    version="1.0.0"
)

# Add CORS middleware.
# Any origin may call the API, but credentialed requests (cookies,
# Authorization headers) are NOT enabled: combining a wildcard origin with
# allow_credentials=True is disallowed by the CORS spec and by the FastAPI
# docs, and would let any website issue credentialed requests on a user's
# behalf. Nothing in this file uses cookies or auth, so credentials are
# simply turned off. If they are ever needed, list explicit origins instead
# of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared detector instance; stays None until the startup hook has run.
detector = None

# NOTE: FastAPI documents `on_event` as deprecated in favour of a `lifespan`
# handler passed to the FastAPI constructor.
@app.on_event("startup")
async def startup_event():
    """Create the detector once at application startup.

    Stores the object returned by ``get_stutter_detector()`` in the
    module-level ``detector``. Any exception is logged with its traceback
    and re-raised.
    """
    global detector
    try:
        logger.info("πŸš€ Startup event: Loading AI models...")
        detector = get_stutter_detector()
        logger.info("βœ… Models loaded successfully!")
    except Exception as load_error:
        # logger.exception == logger.error(..., exc_info=True)
        logger.exception(f"❌ Failed to load models: {load_error}")
        raise

@app.get("/health")
async def health_check():
    """Report liveness and whether the detector has been loaded.

    Returns:
        A dict with three keys:

        - ``status``: always ``"healthy"``.
        - ``models_loaded``: ``True`` once the module-level ``detector``
          is no longer ``None``.
        - ``timestamp``: current UTC time as an ISO 8601 string.
    """
    return {
        "status": "healthy",
        "models_loaded": detector is not None,
        # Computed in-process rather than by shelling out to `date`: no
        # subprocess or unclosed pipe per probe, no blocking of the event
        # loop, and a portable, timezone-explicit format.
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

@app.post("/analyze")
async def analyze_audio(
    audio: UploadFile = File(...),
    transcript: str = ""
):
    """
    Analyze an uploaded audio file for stuttering.

    The upload is written to a uniquely named temporary file under
    ``/tmp/stutter_analysis``, passed to ``detector.analyze_audio`` together
    with ``transcript``, and the temporary file is removed afterwards
    regardless of the outcome.

    Parameters:
    - audio: uploaded audio file (multipart form field).
    - transcript: optional expected transcript; defaults to "".
      NOTE(review): declared without ``Form()``, so FastAPI reads it from
      the query string rather than the multipart body — confirm this is what
      clients send.

    Returns: whatever ``detector.analyze_audio`` returns. The result must
    expose ``'severity'`` and ``'mismatch_percentage'`` keys, which are read
    for logging.

    Raises:
    - HTTPException 503: the detector has not been loaded yet.
    - HTTPException 500: any other failure while saving or analysing.

    NOTE(review): the whole upload is read into memory and the detector is
    called synchronously inside this coroutine; if analysis is slow it will
    block the event loop. Moving it to a worker thread needs confirmation
    that the detector is thread-safe.
    """
    temp_file = None
    try:
        # Explicit None check: do not depend on the detector's truthiness.
        if detector is None:
            raise HTTPException(status_code=503, detail="Models not loaded yet. Try again in a moment.")

        logger.info(f"πŸ“₯ Processing: {audio.filename}")

        # Create temp directory if needed
        temp_dir = "/tmp/stutter_analysis"
        os.makedirs(temp_dir, exist_ok=True)

        # The client-supplied filename is untrusted: it may be missing,
        # contain "../" or be absolute. Never use it as a path; keep only a
        # short alphanumeric extension so the file type is still visible to
        # whatever reads the file.
        suffix = Path(audio.filename or "").suffix
        if not suffix[1:].isalnum() or len(suffix) > 16:
            suffix = ""

        content = await audio.read()

        # mkstemp yields a unique name, so concurrent uploads that share a
        # filename cannot overwrite (or delete) each other's data.
        fd, temp_file = tempfile.mkstemp(dir=temp_dir, suffix=suffix)
        with os.fdopen(fd, "wb") as f:
            f.write(content)

        logger.info(f"πŸ“‚ Saved to: {temp_file} ({len(content) / 1024 / 1024:.2f} MB)")

        # Analyze
        logger.info(f"πŸ”„ Analyzing audio with transcript: '{transcript[:50]}...'")
        result = detector.analyze_audio(temp_file, transcript)

        logger.info(f"βœ… Analysis complete: severity={result['severity']}, mismatch={result['mismatch_percentage']}%")
        return result

    except HTTPException:
        # Already a well-formed HTTP error (e.g. the 503 above); pass through.
        raise
    except Exception as e:
        logger.error(f"❌ Error during analysis: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") from e

    finally:
        # Best-effort cleanup: a failure here is logged, never raised, so it
        # cannot mask the real response or error.
        if temp_file and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
                logger.info(f"🧹 Cleaned up: {temp_file}")
            except OSError as e:
                logger.warning(f"Could not clean up {temp_file}: {e}")

@app.get("/")
async def root():
    """Return a static description of the service.

    The payload lists the service name, version and status, the available
    routes, and the model identifiers this endpoint advertises.
    """
    endpoints = {
        "health": "GET /health",
        "analyze": "POST /analyze (multipart: audio file + optional transcript field)",
        "docs": "GET /docs (interactive API docs)",
    }
    models = {
        "base": "facebook/wav2vec2-base-960h",
        "large": "facebook/wav2vec2-large-960h-lv60-self",
        "xlsr": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
    }
    return dict(
        name="SLAQ Stutter Detector API",
        version="1.0.0",
        status="running",
        endpoints=endpoints,
        models=models,
    )

if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn on all interfaces,
    # port 7860. uvicorn is imported here so it is only needed when the
    # file is run directly.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860, "log_level": "info"}
    logger.info("πŸš€ Starting SLAQ Stutter Detector API...")
    uvicorn.run(app, **server_options)