cuhgrel committed
Commit 8491e9b · verified · 1 Parent(s): 6f7c9d8

updated the app.py to also load the facebook/mms/tgl model (#2)

- updated the app.py to also load the facebook/mms/tgl model (71bd426e5372e763be1b5a6ee4b203bba0cb226f)

Files changed (1):
  1. app.py (+68, -74)
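For context before the diff: facebook/mms-tts-tgl is a VITS-based text-to-speech checkpoint from Meta's MMS project, loaded here through the transformers library rather than NeMo. Below is a minimal standalone sketch of that checkpoint on its own, following the Hub model card's forward-pass API (model(**inputs).waveform); note that the commit itself instead calls model.generate(**inputs) via AutoModelForTextToWaveform. The filename and example text are illustrative, and the sketch assumes torch, soundfile, and a transformers version with VITS support are installed.

# Standalone sketch of the new dependency (assumptions noted above).
import torch
import soundfile as sf
from transformers import AutoTokenizer, VitsModel

tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
model = VitsModel.from_pretrained("facebook/mms-tts-tgl")

inputs = tokenizer("Magandang araw!", return_tensors="pt")  # illustrative text
with torch.no_grad():
    # VITS is end-to-end: the forward pass returns the raw waveform directly,
    # with no separate spectrogram-generator / vocoder stages as in NeMo.
    waveform = model(**inputs).waveform  # shape: (batch, num_samples)

# MMS-TTS checkpoints emit 16 kHz audio; the rate is stored on the model config.
sf.write("tgl_sample.wav", waveform.squeeze().cpu().numpy(), samplerate=model.config.sampling_rate)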
app.py CHANGED
@@ -4,9 +4,14 @@ import io
 import logging
 from fastapi import FastAPI, HTTPException, status
 from pydantic import BaseModel
+from fastapi.responses import StreamingResponse
+
+# --- Library Imports ---
+# For NeMo models
 from nemo.collections.tts.models import FastPitchModel, HifiGanModel
-# Omegaconf is no longer needed here since we aren't creating overrides
-# from omegaconf import OmegaConf, open_dict
+from nemo.collections.tts.torch.tts_tokenizers import BaseCharsTokenizer
+# For Transformers MMS-TTS model
+from transformers import AutoTokenizer, AutoModelForTextToWaveform
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -14,8 +19,8 @@ logger = logging.getLogger(__name__)
 
 # --- 1. Initialize FastAPI App ---
 app = FastAPI(
-    title="NVIDIA NeMo TTS API",
-    description="A backend service to convert text to speech in English and Bikol.",
+    title="Multilingual TTS API",
+    description="A backend service to convert text to speech in English, Bikol, and Tagalog.",
 )
 
 # --- 2. Load Models on Startup ---
@@ -23,44 +28,50 @@ models = {}
 
 @app.on_event("startup")
 def load_models():
-    """Load all NeMo models into memory when the application starts."""
+    """Load all models into memory when the application starts."""
     logger.info("Loading models...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     try:
-        # Load the shared HiFi-GAN Vocoder
+        # --- NeMo Models ---
         logger.info("Loading HiFi-GAN vocoder...")
         models['hifigan'] = HifiGanModel.restore_from("models/hifigan_en.nemo").to(device)
         models['hifigan'].eval()
-        logger.info("HiFi-GAN loaded successfully")
 
-        # Load the English Spectrogram Generator
         logger.info("Loading English FastPitch model...")
         models['en'] = FastPitchModel.restore_from("models/fastpitch_en.nemo").to(device)
         models['en'].eval()
-        logger.info("English model loaded successfully")
 
-        # Load the CORRECTED Bikol Spectrogram Generator
         logger.info("Loading Bikol FastPitch model...")
-        # This is the only line needed now. Replace the filename with your new .nemo file.
         models['bikol'] = FastPitchModel.restore_from("models/fastpitch_bikol_corrected.nemo").to(device)
+
+        logger.info("Overriding Bikol model tokenizer...")
+        BIKOL_CHARS = [
+            ' ', '!', ',', '-', '.', '?', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
+            'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w',
+            'y', 'z', 'à', 'á', 'â', 'é', 'ì', 'í', 'î', 'ñ', 'ò', 'ó', 'ô', 'ú', '’'
+        ]
+        models['bikol'].tokenizer = BaseCharsTokenizer(chars=BIKOL_CHARS)
         models['bikol'].eval()
-        logger.info("Bikol model loaded successfully")
+
+        # --- Transformers MMS-TTS Model ---
+        logger.info("Loading Tagalog (tgl) MMS-TTS model from Hub...")
+        models['tgl_tokenizer'] = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
+        models['tgl_model'] = AutoModelForTextToWaveform.from_pretrained("facebook/mms-tts-tgl").to(device)
 
     except Exception as e:
         logger.error(f"FATAL: Could not load models. Error: {e}")
         import traceback
         traceback.print_exc()
-        # You might want the app to fail completely if models don't load
-        # raise e
+        raise e
 
-    logger.info("Model loading complete. Available models: " + ", ".join(models.keys()))
+    logger.info("Model loading complete.")
 
 
 # --- 3. Define API Request and Response Models ---
 class TTSRequest(BaseModel):
     text: str
-    language: str # Should be 'en' or 'bikol'
+    language: str # Should be 'en', 'bikol', or 'tgl'
 
 # --- 4. Define the TTS API Endpoint ---
 @app.post("/synthesize/")
@@ -68,69 +79,50 @@ def synthesize_speech(request: TTSRequest):
     """
     Generates speech from text using the selected language model.
     """
-    if not models or 'hifigan' not in models:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Models are not loaded yet. Please try again in a moment."
-        )
+    if not models:
+        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Models are not loaded.")
 
-    # ... (the validation code remains the same) ...
-    if request.language not in ['en', 'bikol']:
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid language specified. Use 'en' or 'bikol'.")
-    if request.language not in models:
-        available = [k for k in models.keys() if k != 'hifigan']
-        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"The '{request.language}' model is not available. Available languages: {', '.join(available)}")
+    lang = request.language.lower()
 
+    # Validate the requested language
+    valid_langs = ['en', 'bikol', 'tgl']
+    if lang not in valid_langs:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid language. Use one of {valid_langs}")
+
     try:
-        spectrogram_generator = models[request.language]
-        vocoder = models['hifigan']
+        logger.info(f"--- STARTING SYNTHESIS for '{lang}' ---")
 
-        logger.info(f"--- STARTING SYNTHESIS FOR '{request.text}' ---")
-
-        audio = None # Define audio here to ensure it exists
-        with torch.no_grad():
-            # --- DEBUG STEP 1: Check the parsed tokens ---
-            parsed = spectrogram_generator.parse(request.text)
-            logger.info(f"1. Parsed tokens shape: {parsed.shape}")
-            logger.info(f"   Parsed tokens content: {parsed}")
-
-            # --- DEBUG STEP 2: Check the generated spectrogram ---
-            spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)
-            if spectrogram is not None:
-                logger.info(f"2. Spectrogram generated with shape: {spectrogram.shape}")
-                logger.info(f"   Spectrogram stats: min={spectrogram.min()}, max={spectrogram.max()}, mean={spectrogram.mean()}")
-            else:
-                logger.error("2. FAILED: Spectrogram is None!")
-
-            # --- DEBUG STEP 3: Check the generated audio waveform ---
-            if spectrogram is not None:
+        # --- Logic for NeMo Models (English, Bikol) ---
+        if lang in ['en', 'bikol']:
+            sample_rate = 22050
+            spectrogram_generator = models[lang]
+            vocoder = models['hifigan']
+
+            with torch.no_grad():
+                parsed = spectrogram_generator.parse(request.text)
+                spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)
                 audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)
-                if audio is not None:
-                    logger.info(f"3. Audio generated with shape: {audio.shape}")
-                    logger.info(f"   Audio stats: min={audio.min()}, max={audio.max()}, mean={audio.mean()}")
-                else:
-                    logger.error("3. FAILED: Audio is None!")
-
-        # If audio generation failed, we can't proceed
-        if audio is None:
-            logger.error("Synthesis failed, audio tensor is None.")
-            raise HTTPException(status_code=500, detail="Audio generation failed internally, resulting in None.")
+
+            audio_numpy = audio.to('cpu').detach().numpy().squeeze()
+
+        # --- Logic for Transformers Model (Tagalog) ---
+        elif lang == 'tgl':
+            sample_rate = 16000 # MMS-TTS default sample rate is 16kHz
+            tokenizer = models['tgl_tokenizer']
+            model = models['tgl_model']
+
+            with torch.no_grad():
+                inputs = tokenizer(request.text, return_tensors="pt").to(device)
+                output = model.generate(**inputs)
+
+            audio_numpy = output.cpu().numpy().squeeze()
 
         # --- Prepare and return audio file ---
-        audio_numpy = audio.to('cpu').detach().numpy()
-
-        logger.info(f"4. Successfully converted to NumPy array.")
-
-        if len(audio_numpy.shape) > 1:
-            audio_numpy = audio_numpy.squeeze()
-
         buffer = io.BytesIO()
-        sf.write(buffer, audio_numpy, samplerate=22050, format='WAV')
+        sf.write(buffer, audio_numpy, samplerate=sample_rate, format='WAV')
         buffer.seek(0)
 
         logger.info(f"--- SYNTHESIS COMPLETE ---")
-
-        from fastapi.responses import StreamingResponse
         return StreamingResponse(buffer, media_type="audio/wav")
 
     except Exception as e:
@@ -142,11 +134,12 @@ def synthesize_speech(request: TTSRequest):
 # --- 5. Add a Root Endpoint for Health Check ---
 @app.get("/")
 def read_root():
-    available_models = [k for k in models.keys() if k != 'hifigan']
+    # Filter out tokenizer and non-spectrogram models for a cleaner list
+    available_languages = [k for k in models.keys() if '_model' not in k and k != 'hifigan']
     return {
-        "status": "NeMo TTS Backend is running",
-        "available_languages": available_models,
-        "device": "cuda" if torch.cuda.is_available() else "cpu"
+        "status": "Multilingual TTS Backend is running",
+        "available_languages": available_languages,
+        "device": device
     }
 
 # --- 6. Add Model Status Endpoint ---
@@ -155,7 +148,8 @@ def get_status():
     """Get the status of all loaded models."""
     return {
         "models_loaded": list(models.keys()),
-        "device": "cuda" if torch.cuda.is_available() else "cpu",
+        "device": device,
         "english_available": 'en' in models,
-        "bikol_available": 'bikol' in models
+        "bikol_available": 'bikol' in models,
+        "tagalog_available": 'tgl_model' in models,
     }
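With the commit applied, the new language is reachable through the existing endpoint. A usage sketch, assuming the app is served locally (e.g. uvicorn app:app --port 8000) and the requests package is installed; the URL and output filename are illustrative:

# Illustrative client call to the /synthesize/ endpoint shown in the diff.
import requests

resp = requests.post(
    "http://localhost:8000/synthesize/",  # assumed local deployment
    json={"text": "Magandang araw!", "language": "tgl"},  # 'en' and 'bikol' also accepted
    timeout=120,  # inference can be slow on CPU
)
resp.raise_for_status()

# The endpoint streams WAV bytes: 16 kHz for 'tgl', 22.05 kHz for the NeMo voices.
with open("output.wav", "wb") as f:
    f.write(resp.content)

The root endpoint (GET /) reports the available languages and the device in use, which is a quick way to confirm all three models loaded.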