KasunUoM committed on
Commit
793cc00
·
verified ·
1 Parent(s): 9d28de0

Script for inference

Browse files
Files changed (1) hide show
  1. inference_M2.py +44 -0
inference_M2.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, send_file, jsonify
2
+ from TTS.utils.synthesizer import Synthesizer
3
+ from romanizer import sinhala_to_roman
4
+ import io
5
+ from datetime import datetime
6
+ import torch
7
+
8
# Flask application object; the route decorators below register on it.
app = Flask(__name__)

# Checkpoint and config are referenced by bare file name (no directory part).
MODEL_PATH = "Roshan_270000.pth"
CONFIG_PATH = "Roshan_config.json"

# The synthesizer is constructed once at import time and shared by requests.
# CUDA is requested only when torch reports it as available.
use_cuda = torch.cuda.is_available()
synth = Synthesizer(
    tts_checkpoint=MODEL_PATH,
    tts_config_path=CONFIG_PATH,
    use_cuda=use_cuda,
)
18
+
19
@app.route("/tts", methods=["POST"])
def tts():
    """Synthesize speech for Sinhala text and return it as a WAV download.

    Expects a POST body of the form ``{"text": "<Sinhala text>"}``.

    Returns:
        On success, the WAV audio as an attachment named ``output.wav``.
        Otherwise a JSON ``{"error": "No text provided"}`` payload with
        status 400 when the body is not a JSON object, or when ``text`` is
        missing, not a string, or blank after stripping whitespace.
    """
    # silent=True makes a missing/malformed JSON body come back as None
    # instead of raising, so every bad request gets the same JSON 400 below.
    data = request.get_json(silent=True)

    # Guard both levels: the payload may be a non-dict JSON value (null, list,
    # number) and "text" may be a non-string; neither should cause a 500.
    text_field = data.get("text") if isinstance(data, dict) else None
    sinhala_text = text_field.strip() if isinstance(text_field, str) else ""
    if not sinhala_text:
        return jsonify({"error": "No text provided"}), 400

    # Romanize Sinhala text before handing it to the synthesizer.
    roman_text = sinhala_to_roman(sinhala_text)

    # Generate audio and serialize it into an in-memory buffer.
    wav = synth.tts(roman_text)
    out = io.BytesIO()
    synth.save_wav(wav, out)
    out.seek(0)  # rewind so send_file streams from the first byte

    # Local copy with a timestamped name.
    # NOTE(review): the timestamp has one-second resolution, so two requests
    # in the same second write to the same file name — confirm this is OK.
    filename = f"tts_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
    synth.save_wav(wav, filename)

    # Return WAV directly as a file download.
    return send_file(
        out,
        mimetype="audio/wav",
        as_attachment=True,
        download_name="output.wav",
    )
42
+
43
if __name__ == "__main__":
    # Only when executed as a script: serve the app on port 8000, bound to
    # 0.0.0.0 so it is reachable from other hosts, not just localhost.
    app.run(port=8000, host="0.0.0.0")