import os
import io
import re
import tempfile

# Heavy imports (torch, numpy, soundfile, transformers) are deliberately
# deferred into the handlers below to keep server startup fast.
from flask import Flask, request, jsonify, send_file, render_template
from flask_cors import CORS
from gtts import gTTS
from gtts.tts import gTTSError

# Lazily-loaded Facebook MMS-TTS model/tokenizer; populated by load_mms().
mms_model = None
mms_tokenizer = None

# Writable cache directory for Hugging Face models (None -> library default).
CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE")


def load_mms():
    """Load the facebook/mms-tts-amh VITS model and tokenizer once, lazily.

    Idempotent: returns immediately when both globals are already set.
    Propagates whatever `transformers` raises if the download/load fails,
    so the caller can translate it into an HTTP error.
    """
    global mms_model, mms_tokenizer
    if mms_model and mms_tokenizer:
        return
    print("Loading Facebook MMS-TTS model for Amharic...")
    print(f"Using cache directory: {CACHE_DIR}")
    from transformers import VitsModel, AutoTokenizer
    mms_model_id = "facebook/mms-tts-amh"
    # Explicitly pass cache_dir so the weights land in a writable location.
    mms_model = VitsModel.from_pretrained(mms_model_id, cache_dir=CACHE_DIR)
    mms_tokenizer = AutoTokenizer.from_pretrained(mms_model_id, cache_dir=CACHE_DIR)
    print("MMS-TTS model loaded successfully.")


app = Flask(__name__, static_folder='static', template_folder='templates')
CORS(app)


@app.route('/')
def index():
    """Serve the single-page UI."""
    return render_template('index.html')


# Health check
@app.route('/health')
def health():
    """Liveness probe; also reports whether the MMS model is loaded."""
    return jsonify({
        "ok": True,
        "mms_loaded": bool(mms_model and mms_tokenizer)
    })


@app.route('/api/tts', methods=['POST'])
def text_to_speech():
    """Synthesize speech from a JSON body: {text, model='gtts', speed=1.0}.

    Returns MP3 audio for 'gtts', WAV audio for 'mms', or a JSON error
    object with status 400/403/500/501/502 on failure.
    """
    # silent=True: a malformed/missing JSON body yields None instead of an
    # automatic HTML 400 page, so we can answer with our own JSON error.
    data = request.get_json(silent=True)
    if not data or 'text' not in data or not data['text'].strip():
        return jsonify({"error": "Text is required."}), 400

    text = data.get('text')
    model = data.get('model', 'gtts')
    # FIX: a non-numeric 'speed' previously raised inside float() and
    # surfaced as a generic 500; validate explicitly and answer 400.
    try:
        speed = float(data.get('speed', 1.0))
    except (TypeError, ValueError):
        return jsonify({"error": "Invalid 'speed' value; must be a number."}), 400

    print(f"--- Received TTS Request for model: {model} ---")
    try:
        if model == 'gtts':
            try:
                print("Attempting gTTS synthesis with default endpoint (tld='com')...")
                tts = gTTS(text=text, lang='am', slow=(speed < 1.0), lang_check=False)
                # FIX: stream straight into memory via write_to_fp instead of
                # round-tripping through a NamedTemporaryFile(delete=False) on
                # disk — no temp-file artifact to leak, less I/O.
                audio_fp = io.BytesIO()
                tts.write_to_fp(audio_fp)
                if audio_fp.getbuffer().nbytes == 0:
                    raise RuntimeError("gTTS produced empty audio stream")
                audio_fp.seek(0)
                print("Successfully generated audio with gTTS.")
                return send_file(audio_fp, mimetype='audio/mpeg')
            except gTTSError as ge:
                msg = ("gTTS failed using the default endpoint (Google TTS). "
                       "Please try again later or use the MMS model.")
                print(f"gTTS gTTSError: {ge}")
                return jsonify({"error": msg, "details": str(ge)}), 502
            except Exception as ge:
                msg = "gTTS failed unexpectedly on the default endpoint."
                print(f"gTTS unexpected error: {ge}")
                return jsonify({"error": msg, "details": str(ge)}), 502

        elif model == 'mms':
            try:
                load_mms()
            except Exception as e:
                print(f"Failed to load MMS: {e}")
                return jsonify({"error": "MMS-TTS model is not available on the server.",
                                "details": str(e)}), 500

            print("Generating audio with MMS-TTS...")
            # Heavy imports only used here
            import torch
            import soundfile as sf

            # The transformers tokenizer will automatically use uroman if it's
            # installed. No explicit call is needed.
            if re.search(r"[^A-Za-z0-9\s\.,\?!;:'\"\-]", text):
                print("Text contains non-Roman characters. Relying on tokenizer's automatic romanization.")

            inputs = mms_tokenizer(text, return_tensors="pt")
            try:
                input_len = inputs["input_ids"].shape[-1]
            except Exception:
                input_len = 0
            if input_len == 0:
                # Without uroman, Ethiopic script can tokenize to nothing;
                # fail loudly rather than emitting silent/empty audio.
                msg = ("MMS-TTS received text that tokenized to length 0. "
                       "Install 'uroman' (Python >= 3.10) or provide romanized Latin text.")
                print(msg)
                return jsonify({"error": msg}), 400

            with torch.no_grad():
                output = mms_model(**inputs).waveform
            sampling_rate = mms_model.config.sampling_rate
            speech_waveform = output.cpu().numpy().squeeze()

            audio_fp = io.BytesIO()
            sf.write(audio_fp, speech_waveform, sampling_rate, format='WAV')
            audio_fp.seek(0)
            print("Successfully generated audio with MMS-TTS.")
            return send_file(audio_fp, mimetype='audio/wav')

        elif model in ('openai', 'azure'):
            return jsonify({"error": "The keys for this model have expired. Please use other models."}), 403
        else:
            return jsonify({"error": f"The model '{model}' is not implemented yet."}), 501
    except Exception as e:
        print(f"An error occurred: {e}")
        return jsonify({"error": f"An unexpected error occurred during TTS generation: {str(e)}"}), 500


if __name__ == '__main__':
    port = int(os.getenv('PORT', 7860))
    app.run(debug=False, port=port, host='0.0.0.0')