cuhgrel committed on
Commit
b0bdaef
·
verified ·
1 Parent(s): 8491e9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -11
app.py CHANGED
@@ -7,10 +7,8 @@ from pydantic import BaseModel
7
  from fastapi.responses import StreamingResponse
8
 
9
  # --- Library Imports ---
10
- # For NeMo models
11
  from nemo.collections.tts.models import FastPitchModel, HifiGanModel
12
  from nemo.collections.tts.torch.tts_tokenizers import BaseCharsTokenizer
13
- # For Transformers MMS-TTS model
14
  from transformers import AutoTokenizer, AutoModelForTextToWaveform
15
 
16
  # Configure logging
@@ -26,12 +24,13 @@ app = FastAPI(
26
  # --- 2. Load Models on Startup ---
27
  models = {}
28
 
 
 
29
  @app.on_event("startup")
30
  def load_models():
31
  """Load all models into memory when the application starts."""
32
  logger.info("Loading models...")
33
- device = "cuda" if torch.cuda.is_available() else "cpu"
34
-
35
  try:
36
  # --- NeMo Models ---
37
  logger.info("Loading HiFi-GAN vocoder...")
@@ -84,7 +83,6 @@ def synthesize_speech(request: TTSRequest):
84
 
85
  lang = request.language.lower()
86
 
87
- # Validate the requested language
88
  valid_langs = ['en', 'bikol', 'tgl']
89
  if lang not in valid_langs:
90
  raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid language. Use one of {valid_langs}")
@@ -92,7 +90,6 @@ def synthesize_speech(request: TTSRequest):
92
  try:
93
  logger.info(f"--- STARTING SYNTHESIS for '{lang}' ---")
94
 
95
- # --- Logic for NeMo Models (English, Bikol) ---
96
  if lang in ['en', 'bikol']:
97
  sample_rate = 22050
98
  spectrogram_generator = models[lang]
@@ -105,9 +102,8 @@ def synthesize_speech(request: TTSRequest):
105
 
106
  audio_numpy = audio.to('cpu').detach().numpy().squeeze()
107
 
108
- # --- Logic for Transformers Model (Tagalog) ---
109
  elif lang == 'tgl':
110
- sample_rate = 16000 # MMS-TTS default sample rate is 16kHz
111
  tokenizer = models['tgl_tokenizer']
112
  model = models['tgl_model']
113
 
@@ -117,7 +113,6 @@ def synthesize_speech(request: TTSRequest):
117
 
118
  audio_numpy = output.cpu().numpy().squeeze()
119
 
120
- # --- Prepare and return audio file ---
121
  buffer = io.BytesIO()
122
  sf.write(buffer, audio_numpy, samplerate=sample_rate, format='WAV')
123
  buffer.seek(0)
@@ -134,8 +129,7 @@ def synthesize_speech(request: TTSRequest):
134
  # --- 5. Add a Root Endpoint for Health Check ---
135
  @app.get("/")
136
  def read_root():
137
- # Filter out tokenizer and non-spectrogram models for a cleaner list
138
- available_languages = [k for k in models.keys() if '_model' not in k and k != 'hifigan']
139
  return {
140
  "status": "Multilingual TTS Backend is running",
141
  "available_languages": available_languages,
 
7
  from fastapi.responses import StreamingResponse
8
 
9
  # --- Library Imports ---
 
10
  from nemo.collections.tts.models import FastPitchModel, HifiGanModel
11
  from nemo.collections.tts.torch.tts_tokenizers import BaseCharsTokenizer
 
12
  from transformers import AutoTokenizer, AutoModelForTextToWaveform
13
 
14
  # Configure logging
 
24
  # --- 2. Load Models on Startup ---
25
  models = {}
26
 
27
+ device = "cuda" if torch.cuda.is_available() else "cpu"
28
+
29
  @app.on_event("startup")
30
  def load_models():
31
  """Load all models into memory when the application starts."""
32
  logger.info("Loading models...")
33
+
 
34
  try:
35
  # --- NeMo Models ---
36
  logger.info("Loading HiFi-GAN vocoder...")
 
83
 
84
  lang = request.language.lower()
85
 
 
86
  valid_langs = ['en', 'bikol', 'tgl']
87
  if lang not in valid_langs:
88
  raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid language. Use one of {valid_langs}")
 
90
  try:
91
  logger.info(f"--- STARTING SYNTHESIS for '{lang}' ---")
92
 
 
93
  if lang in ['en', 'bikol']:
94
  sample_rate = 22050
95
  spectrogram_generator = models[lang]
 
102
 
103
  audio_numpy = audio.to('cpu').detach().numpy().squeeze()
104
 
 
105
  elif lang == 'tgl':
106
+ sample_rate = 16000
107
  tokenizer = models['tgl_tokenizer']
108
  model = models['tgl_model']
109
 
 
113
 
114
  audio_numpy = output.cpu().numpy().squeeze()
115
 
 
116
  buffer = io.BytesIO()
117
  sf.write(buffer, audio_numpy, samplerate=sample_rate, format='WAV')
118
  buffer.seek(0)
 
129
  # --- 5. Add a Root Endpoint for Health Check ---
130
  @app.get("/")
131
  def read_root():
132
+ available_languages = ['en', 'bikol', 'tgl']
 
133
  return {
134
  "status": "Multilingual TTS Backend is running",
135
  "available_languages": available_languages,