updated the app.py to also load the facebook/mms/tgl model (#2)
Browse files · updated the app.py to also load the facebook/mms/tgl model (71bd426e5372e763be1b5a6ee4b203bba0cb226f)
app.py
CHANGED

@@ -4,9 +4,14 @@ import io
 import logging
 from fastapi import FastAPI, HTTPException, status
 from pydantic import BaseModel
+from fastapi.responses import StreamingResponse
+
+# --- Library Imports ---
+# For NeMo models
 from nemo.collections.tts.models import FastPitchModel, HifiGanModel
-
-# …
+from nemo.collections.tts.torch.tts_tokenizers import BaseCharsTokenizer
+# For Transformers MMS-TTS model
+from transformers import AutoTokenizer, AutoModelForTextToWaveform
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
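
For reference, the new transformers imports follow the published MMS-TTS usage pattern. A minimal standalone sanity check of the Tagalog checkpoint, per the facebook/mms-tts-tgl model card (the sample sentence is an assumption):

import torch
from transformers import VitsModel, AutoTokenizer

# MMS-TTS checkpoints are VITS models; the tokenizer handles text normalization.
model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")

inputs = tokenizer("Magandang umaga.", return_tensors="pt")  # hypothetical sample text
with torch.no_grad():
    waveform = model(**inputs).waveform  # (batch, samples) float tensor

print(model.config.sampling_rate)  # 16000 Hz for MMS-TTS checkpoints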

@@ -14,8 +19,8 @@ logger = logging.getLogger(__name__)
 
 # --- 1. Initialize FastAPI App ---
 app = FastAPI(
-    title="…",
-    description="A backend service to convert text to speech in English and Bikol.",
+    title="Multilingual TTS API",
+    description="A backend service to convert text to speech in English, Bikol, and Tagalog.",
 )
 
 # --- 2. Load Models on Startup ---

@@ -23,44 +28,50 @@ models = {}
 
 @app.on_event("startup")
 def load_models():
-    """Load all …"""
+    """Load all models into memory when the application starts."""
     logger.info("Loading models...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     try:
-        # …
+        # --- NeMo Models ---
         logger.info("Loading HiFi-GAN vocoder...")
         models['hifigan'] = HifiGanModel.restore_from("models/hifigan_en.nemo").to(device)
         models['hifigan'].eval()
-        logger.info("HiFi-GAN loaded successfully")
 
-        # Load the English Spectrogram Generator
         logger.info("Loading English FastPitch model...")
         models['en'] = FastPitchModel.restore_from("models/fastpitch_en.nemo").to(device)
         models['en'].eval()
-        logger.info("English model loaded successfully")
 
-        # Load the CORRECTED Bikol Spectrogram Generator
         logger.info("Loading Bikol FastPitch model...")
-        # This is the only line needed now. Replace the filename with your new .nemo file.
         models['bikol'] = FastPitchModel.restore_from("models/fastpitch_bikol_corrected.nemo").to(device)
+
+        logger.info("Overriding Bikol model tokenizer...")
+        BIKOL_CHARS = [
+            ' ', '!', ',', '-', '.', '?', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
+            'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w',
+            'y', 'z', 'à', 'á', 'â', 'é', 'ì', 'í', 'î', 'ñ', 'ò', 'ó', 'ô', 'ú', '’'
+        ]
+        models['bikol'].tokenizer = BaseCharsTokenizer(chars=BIKOL_CHARS)
         models['bikol'].eval()
-
+
+        # --- Transformers MMS-TTS Model ---
+        logger.info("Loading Tagalog (tgl) MMS-TTS model from Hub...")
+        models['tgl_tokenizer'] = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
+        models['tgl_model'] = AutoModelForTextToWaveform.from_pretrained("facebook/mms-tts-tgl").to(device)
 
     except Exception as e:
         logger.error(f"FATAL: Could not load models. Error: {e}")
         import traceback
         traceback.print_exc()
-
-        # raise e
+        raise e
 
-    logger.info("Model loading complete. …")
+    logger.info("Model loading complete.")
 
 
 # --- 3. Define API Request and Response Models ---
 class TTSRequest(BaseModel):
     text: str
-    language: str  # Should be 'en' or 'bikol'
+    language: str  # Should be 'en', 'bikol', or 'tgl'
 
 # --- 4. Define the TTS API Endpoint ---
 @app.post("/synthesize/")
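
The BIKOL_CHARS override pins the Bikol FastPitch tokenizer to an explicit character inventory, so any input character outside that list is a likely failure point. A quick pure-Python pre-flight check, assuming BIKOL_CHARS from the hunk above is in scope (the sample sentence is hypothetical):

# Flag characters the overridden Bikol tokenizer has no entry for.
sample = "marhay na aga saindo!"
unknown = sorted(set(sample) - set(BIKOL_CHARS))
print("characters outside BIKOL_CHARS:", unknown)  # expect []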

@@ -68,69 +79,50 @@ def synthesize_speech(request: TTSRequest):
     """
     Generates speech from text using the selected language model.
     """
-    if not models:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Models are not loaded yet. Please try again in a moment."
-        )
+    if not models:
+        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Models are not loaded.")
 
-
-    if request.language not in ['en', 'bikol']:
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid language specified. Use 'en' or 'bikol'.")
-    if request.language not in models:
-        available = [k for k in models.keys() if k != 'hifigan']
-        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"The '{request.language}' model is not available. Available languages: {', '.join(available)}")
+    lang = request.language.lower()
 
+    # Validate the requested language
+    valid_langs = ['en', 'bikol', 'tgl']
+    if lang not in valid_langs:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid language. Use one of {valid_langs}")
+
     try:
-
-        vocoder = models['hifigan']
+        logger.info(f"--- STARTING SYNTHESIS for '{lang}' ---")
 
-        …
-        # --- DEBUG STEP 2: Check the generated spectrogram ---
-        spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)
-        if spectrogram is not None:
-            logger.info(f"2. Spectrogram generated with shape: {spectrogram.shape}")
-            logger.info(f"   Spectrogram stats: min={spectrogram.min()}, max={spectrogram.max()}, mean={spectrogram.mean()}")
-        else:
-            logger.error("2. FAILED: Spectrogram is None!")
-
-        # --- DEBUG STEP 3: Check the generated audio waveform ---
-        if spectrogram is not None:
+        # --- Logic for NeMo Models (English, Bikol) ---
+        if lang in ['en', 'bikol']:
+            sample_rate = 22050
+            spectrogram_generator = models[lang]
+            vocoder = models['hifigan']
+
+            with torch.no_grad():
+                parsed = spectrogram_generator.parse(request.text)
+                spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)
                 audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)
-        …
+
+            audio_numpy = audio.to('cpu').detach().numpy().squeeze()
+
+        # --- Logic for Transformers Model (Tagalog) ---
+        elif lang == 'tgl':
+            sample_rate = 16000  # MMS-TTS default sample rate is 16kHz
+            tokenizer = models['tgl_tokenizer']
+            model = models['tgl_model']
+
+            with torch.no_grad():
+                inputs = tokenizer(request.text, return_tensors="pt").to(device)
+                output = model.generate(**inputs)
+
+            audio_numpy = output.cpu().numpy().squeeze()
 
         # --- Prepare and return audio file ---
-        audio_numpy = audio.to('cpu').detach().numpy()
-
-        logger.info(f"4. Successfully converted to NumPy array.")
-
-        if len(audio_numpy.shape) > 1:
-            audio_numpy = audio_numpy.squeeze()
-
         buffer = io.BytesIO()
-        sf.write(buffer, audio_numpy, samplerate=22050)
+        sf.write(buffer, audio_numpy, samplerate=sample_rate, format='WAV')
         buffer.seek(0)
 
         logger.info(f"--- SYNTHESIS COMPLETE ---")
-
-        from fastapi.responses import StreamingResponse
         return StreamingResponse(buffer, media_type="audio/wav")
 
     except Exception as e:
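
Note that the 'tgl' branch references device at request time, while load_models() only assigns it as a local; the code assumes device is also resolvable at module scope. With the app running (e.g., uvicorn app:app), the endpoint can be exercised with a plain HTTP client. A sketch, where the base URL and sample text are assumptions and the route and JSON fields come from the diff above:

import requests

resp = requests.post(
    "http://127.0.0.1:8000/synthesize/",
    json={"text": "Magandang umaga.", "language": "tgl"},
)
resp.raise_for_status()

with open("out.wav", "wb") as f:
    f.write(resp.content)  # WAV payload: 16 kHz for 'tgl', 22.05 kHz for 'en'/'bikol'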

@@ -142,11 +134,12 @@ def synthesize_speech(request: TTSRequest):
 # --- 5. Add a Root Endpoint for Health Check ---
 @app.get("/")
 def read_root():
-    …
+    # Filter out tokenizer and non-spectrogram models for a cleaner list
+    available_languages = [k for k in models.keys() if '_model' not in k and k != 'hifigan']
     return {
-        "status": "…",
-        "available_languages": …,
-        "device": …,
+        "status": "Multilingual TTS Backend is running",
+        "available_languages": available_languages,
+        "device": device
     }
 
 # --- 6. Add Model Status Endpoint ---

@@ -155,7 +148,8 @@ def get_status():
     """Get the status of all loaded models."""
     return {
         "models_loaded": list(models.keys()),
-        "device": …,
+        "device": device,
         "english_available": 'en' in models,
-        "bikol_available": 'bikol' in models
+        "bikol_available": 'bikol' in models,
+        "tagalog_available": 'tgl_model' in models,
     }
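
The two read-only endpoints can be checked the same way. A sketch, assuming the same local base URL; the decorator for get_status() sits in context not shown in this diff, so the "/status" path is a guess:

import requests

base = "http://127.0.0.1:8000"
print(requests.get(f"{base}/").json())        # health check: status, languages, device
print(requests.get(f"{base}/status").json())  # assumed route for get_status()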