anfastech committed
Commit fb4e8a7 · 1 Parent(s): 60b1e24

#2 fix stutter detector

diagnosis/ai_engine/detect_stuttering.py CHANGED
@@ -0,0 +1,346 @@
# diagnosis/ai_engine/detect_stuttering.py
import logging
import os
import time
from typing import Dict, List, Tuple

import librosa
import torch
import torchaudio
from torch.nn import CTCLoss
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

logger = logging.getLogger(__name__)


class StutterDetector:
    """
    Stutter detection using Wav2Vec2 models.
    Adapted from: https://github.com/wittyicon29/Stutter_Detection
    """

    def __init__(self):
        """Initialize models - load once and reuse."""
        logger.info("🔄 Initializing StutterDetector models...")

        try:
            # Log where model weights will be loaded from
            hf_cache = os.environ.get('HF_HOME') or os.environ.get('TRANSFORMERS_CACHE')
            if hf_cache:
                logger.info(f"📂 Custom Hugging Face cache: {hf_cache}")
            else:
                home = os.path.expanduser('~')
                default_cache = os.path.join(home, '.cache', 'huggingface')
                logger.info(f"📂 Default Hugging Face cache: {default_cache}")

            # Load base model for transcription
            logger.info("📥 Loading base model: facebook/wav2vec2-base-960h")
            self.base_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
            self.base_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
            logger.info("✅ Base model loaded successfully")

            # Load large model for detailed analysis
            logger.info("📥 Loading large model: facebook/wav2vec2-large-960h-lv60-self")
            self.large_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
            self.large_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
            logger.info("✅ Large model loaded successfully")

            # Load XLSR model for target transcript generation
            logger.info("📥 Loading XLSR model: jonatasgrosman/wav2vec2-large-xlsr-53-english")
            self.xlsr_model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
            self.xlsr_processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
            logger.info("✅ XLSR model loaded successfully")

            logger.info("✅ All models loaded successfully")

        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}")
            raise

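    # Note: the three checkpoints above amount to a few GB of weights, so
    # this class is meant to be constructed once per process and reused
    # (see get_stutter_detector() at the bottom of this file).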
    def analyze_audio(self, audio_file_path: str, proper_transcript: str = "") -> Dict:
        """
        Complete analysis pipeline.

        Args:
            audio_file_path: Path to the audio file
            proper_transcript: Optional expected transcript (if available)

        Returns:
            Dictionary with complete analysis results
        """
        start_time = time.time()

        try:
            logger.info(f"🎯 Starting analysis for: {audio_file_path}")

            # Step 1: Generate target transcript if not provided
            if not proper_transcript:
                proper_transcript = self.generate_target_transcript(audio_file_path)
                logger.info(f"📝 Generated target transcript: {proper_transcript}")

            proper_transcript = proper_transcript.upper()

            # Step 2: Transcribe and detect stuttering
            transcription_result = self.transcribe_and_detect(audio_file_path, proper_transcript)

            # Step 3: Calculate CTC loss and find stutter timestamps
            ctc_loss, stutter_timestamps = self.calculate_stutter_timestamps(
                audio_file_path,
                proper_transcript
            )

            # Step 4: Aggregate results
            analysis_duration = time.time() - start_time

            result = {
                'actual_transcript': transcription_result['transcription'],
                'target_transcript': proper_transcript,
                'mismatched_chars': transcription_result['stuttered_chars'],
                'mismatch_percentage': transcription_result['mismatch_percentage'],
                'ctc_loss_score': ctc_loss,
                'stutter_timestamps': stutter_timestamps,
                'total_stutter_duration': self._calculate_total_duration(stutter_timestamps),
                'stutter_frequency': self._calculate_frequency(stutter_timestamps, audio_file_path),
                'severity': self._determine_severity(transcription_result['mismatch_percentage']),
                'confidence_score': self._calculate_confidence(transcription_result, ctc_loss),
                'analysis_duration_seconds': round(analysis_duration, 2),
                'model_version': 'wav2vec2-base-960h',
            }

            logger.info(f"✅ Analysis complete in {analysis_duration:.2f}s")
            return result

        except Exception as e:
            logger.error(f"❌ Analysis failed: {e}")
            raise

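    # Illustrative shape of the result dict (values below are hypothetical,
    # chosen to be mutually consistent with the formulas in this class):
    #
    #   {
    #       'actual_transcript': 'HEH HEH HELLO',
    #       'target_transcript': 'HELLO',
    #       'mismatched_chars': ['H HEH HE'],
    #       'mismatch_percentage': 100,
    #       'ctc_loss_score': 3.42,
    #       'stutter_timestamps': [(0.12, 0.58)],
    #       'total_stutter_duration': 0.46,
    #       'stutter_frequency': 24.0,          # one event in a 2.5 s clip
    #       'severity': 'severe',
    #       'confidence_score': 0.33,
    #       'analysis_duration_seconds': 4.87,
    #       'model_version': 'wav2vec2-base-960h',
    #   }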
    def generate_target_transcript(self, audio_file: str) -> str:
        """Generate the expected transcript using the XLSR model."""
        try:
            waveform, sample_rate = torchaudio.load(audio_file)

            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                waveform = resampler(waveform)

            input_values = self.xlsr_processor(
                waveform[0], sampling_rate=16000, return_tensors="pt"
            ).input_values

            with torch.no_grad():
                logits = self.xlsr_model(input_values).logits

            predicted_ids = torch.argmax(logits, dim=-1)
            predicted_sentences = self.xlsr_processor.batch_decode(predicted_ids)

            return predicted_sentences[0]

        except Exception as e:
            logger.error(f"Target transcript generation failed: {e}")
            raise

    def transcribe_and_detect(self, audio_file: str, proper_transcript: str) -> Dict:
        """Transcribe audio and detect stuttering patterns."""
        try:
            # Load audio resampled to the 16 kHz rate the models expect
            input_audio, _ = librosa.load(audio_file, sr=16000)

            # Tokenize
            input_values = self.base_processor(
                input_audio, sampling_rate=16000, return_tensors="pt"
            ).input_values

            # Get predictions
            with torch.no_grad():
                logits = self.base_model(input_values).logits

            # Decode
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.base_processor.batch_decode(predicted_ids)[0]

            # Find stuttered sequences
            stuttered_chars = self.find_sequences_not_in_common(transcription, proper_transcript)

            # Calculate mismatch percentage
            total_mismatched = sum(len(segment) for segment in stuttered_chars)
            mismatch_percentage = (total_mismatched / len(proper_transcript)) * 100 if len(proper_transcript) > 0 else 0
            mismatch_percentage = min(round(mismatch_percentage), 100)

            return {
                'transcription': transcription,
                'stuttered_chars': stuttered_chars,
                'mismatch_percentage': mismatch_percentage
            }

        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            raise

    def calculate_stutter_timestamps(self, audio_file: str, proper_transcript: str) -> Tuple[float, List[Tuple[float, float]]]:
        """Calculate CTC loss and find stutter timestamps."""
        try:
            # Load waveform
            waveform, sample_rate = torchaudio.load(audio_file)

            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                waveform = resampler(waveform)

            # Process with base model for CTC loss
            input_values = self.base_processor(
                waveform[0], sampling_rate=16000, return_tensors="pt"
            ).input_values

            with torch.no_grad():
                logits = self.base_model(input_values).logits

            # Calculate CTC loss against the target transcript
            tokens = self.base_processor.tokenizer(proper_transcript, return_tensors="pt", padding=True, truncation=True)
            target_ids = tokens.input_ids

            log_probs = torch.log_softmax(logits, dim=-1)
            input_lengths = torch.tensor([log_probs.shape[1]], dtype=torch.long)
            target_lengths = torch.tensor([target_ids.shape[1]], dtype=torch.long)

            ctc_loss = CTCLoss(blank=self.base_model.config.pad_token_id)
            loss = ctc_loss(log_probs.transpose(0, 1), targets=target_ids,
                            input_lengths=input_lengths, target_lengths=target_lengths)

            # Find stutter timestamps using the large model
            # (librosa already resamples to 16 kHz via sr=16000)
            input_audio, sample_rate = librosa.load(audio_file, sr=16000)

            input_values = self.large_processor(
                input_audio, sampling_rate=16000, return_tensors='pt'
            ).input_values

            with torch.no_grad():
                logits = self.large_model(input_values).logits

            predicted_ids = torch.argmax(logits, dim=-1)
            blank_token_id = self.large_model.config.pad_token_id

            # Extract timestamp ranges: a prolonged sound shows up as the
            # same non-blank token repeated over many consecutive CTC frames.
            stuttering_seconds = []
            frame_shift = 0.02  # wav2vec2 emits one output frame per ~20 ms
            audio_duration = len(input_audio) / sample_rate
            run_token = blank_token_id
            run_start = 0

            # A trailing blank sentinel ensures the final run is also closed
            for frame_idx, token_id in enumerate(predicted_ids[0].tolist() + [blank_token_id]):
                if token_id != run_token:
                    if run_token != blank_token_id:
                        start_second = min(run_start * frame_shift, audio_duration)
                        end_second = min(frame_idx * frame_shift, audio_duration)
                        # Treat runs longer than 0.4s as prolongations
                        if end_second - start_second > 0.4:
                            stuttering_seconds.append((round(start_second, 2), round(end_second, 2)))
                    run_token = token_id
                    run_start = frame_idx

            return round(loss.item(), 2), stuttering_seconds

        except Exception as e:
            logger.error(f"Timestamp calculation failed: {e}")
            return 0.0, []

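    # Frame-to-time arithmetic behind the prolongation check above
    # (illustrative numbers): at 16 kHz the model hops 320 samples per
    # output frame, i.e. 0.02 s. A run starting at frame 25 and ending at
    # frame 55 spans 25 * 0.02 = 0.50 s to 55 * 0.02 = 1.10 s, a 0.60 s
    # prolongation, which exceeds the 0.4 s threshold and is flagged.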
    def find_max_common_characters(self, transcription1: str, transcript2: str) -> str:
        """Longest Common Subsequence (LCS) of the two transcripts."""
        m, n = len(transcription1), len(transcript2)
        lcs_matrix = [[0] * (n + 1) for _ in range(m + 1)]

        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if transcription1[i - 1] == transcript2[j - 1]:
                    lcs_matrix[i][j] = lcs_matrix[i - 1][j - 1] + 1
                else:
                    lcs_matrix[i][j] = max(lcs_matrix[i - 1][j], lcs_matrix[i][j - 1])

        # Backtrack to recover the LCS itself
        lcs_characters = []
        i, j = m, n
        while i > 0 and j > 0:
            if transcription1[i - 1] == transcript2[j - 1]:
                lcs_characters.append(transcription1[i - 1])
                i -= 1
                j -= 1
            elif lcs_matrix[i - 1][j] > lcs_matrix[i][j - 1]:
                i -= 1
            else:
                j -= 1

        lcs_characters.reverse()
        return ''.join(lcs_characters)

    def find_sequences_not_in_common(self, transcription1: str, proper_transcript: str) -> List[str]:
        """Find stuttered character sequences (text not in the LCS)."""
        common_characters = self.find_max_common_characters(transcription1, proper_transcript)
        sequences = []
        sequence = ""
        i, j = 0, 0

        while i < len(transcription1) and j < len(common_characters):
            if transcription1[i] == common_characters[j]:
                if sequence:
                    sequences.append(sequence)
                    sequence = ""
                i += 1
                j += 1
            else:
                sequence += transcription1[i]
                i += 1

        # Capture any trailing characters after the last common match
        sequence += transcription1[i:]
        if sequence:
            sequences.append(sequence)

        return sequences

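    # Worked example (traced by hand): with transcription1 = "HEH HEH HELLO"
    # and proper_transcript = "HELLO", the LCS is "HELLO". The scan matches
    # "H" and "E", accumulates the mismatched run "H HEH HE", then matches
    # the remaining "LLO", so the method returns ['H HEH HE'].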
    def _calculate_total_duration(self, timestamps: List[Tuple[float, float]]) -> float:
        """Calculate total stuttering duration in seconds."""
        return sum(end - start for start, end in timestamps)

    def _calculate_frequency(self, timestamps: List[Tuple[float, float]], audio_file: str) -> float:
        """Calculate stutter events per minute."""
        try:
            audio_duration = librosa.get_duration(path=audio_file)
            if audio_duration > 0:
                return (len(timestamps) / audio_duration) * 60
            return 0.0
        except Exception:
            return 0.0

    def _determine_severity(self, mismatch_percentage: float) -> str:
        """Map mismatch percentage to a severity level."""
        if mismatch_percentage < 10:
            return 'none'
        elif mismatch_percentage < 25:
            return 'mild'
        elif mismatch_percentage < 50:
            return 'moderate'
        else:
            return 'severe'

    def _calculate_confidence(self, transcription_result: Dict, ctc_loss: float) -> float:
        """Calculate a confidence score for the analysis."""
        # Lower mismatch and lower CTC loss => higher confidence
        mismatch_factor = 1 - (transcription_result['mismatch_percentage'] / 100)
        loss_factor = max(0, 1 - (ctc_loss / 10))  # normalize loss into [0, 1]
        confidence = (mismatch_factor + loss_factor) / 2
        return round(min(max(confidence, 0.0), 1.0), 2)

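    # Worked example of the confidence arithmetic: a 20% mismatch and a CTC
    # loss of 4.0 give mismatch_factor = 0.8 and loss_factor = 0.6, so
    # confidence = (0.8 + 0.6) / 2 = 0.7.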

# diagnosis/ai_engine/model_loader.py
"""Singleton pattern for model loading."""
_detector_instance = None


def get_stutter_detector():
    """Get or create the singleton StutterDetector instance."""
    global _detector_instance
    if _detector_instance is None:
        _detector_instance = StutterDetector()
    return _detector_instance
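

# A minimal usage sketch (hypothetical file path; assumes the checkpoints
# above can be downloaded on first run):
#
#   from diagnosis.ai_engine.detect_stuttering import get_stutter_detector
#
#   detector = get_stutter_detector()          # loads the models once
#   report = detector.analyze_audio("samples/reading_task.wav")
#   print(report['severity'], report['stutter_timestamps'])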