ASLP-lab committed
Commit 010341e · verified · 1 Parent(s): 9f7e23b
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +1 -0
  2. app.py +533 -0
  3. bigvgan/__init__.py +0 -0
  4. bigvgan/activations.py +126 -0
  5. bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  6. bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  7. bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  8. bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  9. bigvgan/alias_free_activation/cuda/compat.h +29 -0
  10. bigvgan/alias_free_activation/cuda/load.py +86 -0
  11. bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  12. bigvgan/alias_free_activation/torch/__init__.py +6 -0
  13. bigvgan/alias_free_activation/torch/act.py +30 -0
  14. bigvgan/alias_free_activation/torch/filter.py +101 -0
  15. bigvgan/alias_free_activation/torch/resample.py +58 -0
  16. bigvgan/env.py +18 -0
  17. bigvgan/model.py +545 -0
  18. bigvgan/utils.py +59 -0
  19. diffrhythm2/__init__.py +0 -0
  20. diffrhythm2/backbones/__init__.py +0 -0
  21. diffrhythm2/backbones/dit.py +222 -0
  22. diffrhythm2/backbones/flex_attention.py +237 -0
  23. diffrhythm2/backbones/llama_attention.py +451 -0
  24. diffrhythm2/backbones/llama_nar.py +140 -0
  25. diffrhythm2/cache_utils.py +154 -0
  26. diffrhythm2/cfm.py +425 -0
  27. g2p/__init__.py +0 -0
  28. g2p/g2p/__init__.py +87 -0
  29. g2p/g2p/chinese_model_g2p.py +213 -0
  30. g2p/g2p/cleaners.py +31 -0
  31. g2p/g2p/english.py +202 -0
  32. g2p/g2p/french.py +149 -0
  33. g2p/g2p/german.py +94 -0
  34. g2p/g2p/japanese.py +816 -0
  35. g2p/g2p/korean.py +81 -0
  36. g2p/g2p/mandarin.py +597 -0
  37. g2p/g2p/text_tokenizers.py +84 -0
  38. g2p/g2p/vocab.json +372 -0
  39. g2p/g2p_generation.py +134 -0
  40. g2p/language_segmentation/LangSegment.py +865 -0
  41. g2p/language_segmentation/__init__.py +9 -0
  42. g2p/language_segmentation/utils/__init__.py +0 -0
  43. g2p/language_segmentation/utils/num.py +327 -0
  44. g2p/sources/bpmf_2_pinyin.txt +41 -0
  45. g2p/sources/chinese_lexicon.txt +3 -0
  46. g2p/sources/g2p_chinese_model/config.json +819 -0
  47. g2p/sources/g2p_chinese_model/poly_bert_model.onnx +3 -0
  48. g2p/sources/g2p_chinese_model/polychar.txt +159 -0
  49. g2p/sources/g2p_chinese_model/polydict.json +393 -0
  50. g2p/sources/g2p_chinese_model/polydict_r.json +393 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ g2p/sources/chinese_lexicon.txt filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,533 @@
import gradio as gr

import json
import torch
import torchaudio
import os
import random
import numpy as np
import io
import pydub
import base64
from muq import MuQMuLan
from diffrhythm2.cfm import CFM
from diffrhythm2.backbones.dit import DiT
from bigvgan.model import Generator
from huggingface_hub import hf_hub_download

STRUCT_INFO = {
    "[start]": 500,
    "[end]": 501,
    "[intro]": 502,
    "[verse]": 503,
    "[chorus]": 504,
    "[outro]": 505,
    "[inst]": 506,
    "[solo]": 507,
    "[bridge]": 508,
    "[hook]": 509,
    "[break]": 510,
    "[stop]": 511,
    "[space]": 512
}

class CNENTokenizer():
    def __init__(self):
        curr_path = os.path.abspath(__file__)
        vocab_path = os.path.join(os.path.dirname(curr_path), "g2p/g2p/vocab.json")
        with open(vocab_path, 'r') as file:
            self.phone2id: dict = json.load(file)['vocab']
        self.id2phone = {v: k for (k, v) in self.phone2id.items()}
        from g2p.g2p_generation import chn_eng_g2p
        self.tokenizer = chn_eng_g2p

    def encode(self, text):
        phone, token = self.tokenizer(text)
        token = [x + 1 for x in token]
        return token

    def decode(self, token):
        return "|".join([self.id2phone[x - 1] for x in token])

def prepare_model(repo_id, device, dtype):
    diffrhythm2_ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="model.safetensors",
        local_dir="./ckpt",
        local_files_only=False,
    )
    diffrhythm2_config_path = hf_hub_download(
        repo_id=repo_id,
        filename="model.json",
        local_dir="./ckpt",
        local_files_only=False,
    )
    with open(diffrhythm2_config_path) as f:
        model_config = json.load(f)

    model_config['use_flex_attn'] = False
    diffrhythm2 = CFM(
        transformer=DiT(**model_config),
        num_channels=model_config['mel_dim'],
        block_size=model_config['block_size'],
    )

    total_params = sum(p.numel() for p in diffrhythm2.parameters())

    diffrhythm2 = diffrhythm2.to(device).to(dtype)
    if diffrhythm2_ckpt_path.endswith('.safetensors'):
        from safetensors.torch import load_file
        ckpt = load_file(diffrhythm2_ckpt_path)
    else:
        ckpt = torch.load(diffrhythm2_ckpt_path, map_location='cpu')
    diffrhythm2.load_state_dict(ckpt)
    print(f"Total params: {total_params:,}")

    # load MuLan style encoder
    mulan = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./ckpt").to(device).to(dtype)

    # load lyric frontend
    lrc_tokenizer = CNENTokenizer()

    # load decoder (vocoder)
    decoder_ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="decoder.bin",
        local_dir="./ckpt",
        local_files_only=False,
    )
    decoder_config_path = hf_hub_download(
        repo_id=repo_id,
        filename="decoder.json",
        local_dir="./ckpt",
        local_files_only=False,
    )
    decoder = Generator(decoder_config_path, decoder_ckpt_path)
    decoder = decoder.to(device).to(dtype)

    return diffrhythm2, mulan, lrc_tokenizer, decoder

def parse_lyrics(lyrics: str):
    lyrics_with_time = []
    lyrics = lyrics.split("\n")
    for line in lyrics:
        struct_idx = STRUCT_INFO.get(line, None)
        if struct_idx is not None:
            lyrics_with_time.append([struct_idx, STRUCT_INFO['[stop]']])
        else:
            tokens = lrc_tokenizer.encode(line.strip())
            tokens = tokens + [STRUCT_INFO['[stop]']]
            lyrics_with_time.append(tokens)
    return lyrics_with_time

def get_audio_prompt(model, audio_file, device, dtype):
    prompt_wav, sr = torchaudio.load(audio_file)
    prompt_wav = torchaudio.functional.resample(prompt_wav.to(device).to(dtype), sr, 24000)
    if prompt_wav.shape[1] > 24000 * 10:
        start = random.randint(0, prompt_wav.shape[1] - 24000 * 10)
        prompt_wav = prompt_wav[:, start:start + 24000 * 10]
    prompt_wav = prompt_wav.mean(dim=0, keepdim=True)
    with torch.no_grad():
        style_prompt_embed = model(wavs=prompt_wav)
    return style_prompt_embed.squeeze(0)

def get_text_prompt(model, text, device, dtype):
    with torch.no_grad():
        style_prompt_embed = model(texts=[text])
    return style_prompt_embed.squeeze(0)

def make_fake_stereo(audio, sampling_rate):
    # Duplicate the mono track, then attenuate and slightly delay the right channel.
    left_channel = audio
    right_channel = audio.clone()
    right_channel = right_channel * 0.8
    delay_samples = int(0.01 * sampling_rate)
    right_channel = torch.roll(right_channel, delay_samples)
    right_channel[:, :delay_samples] = 0
    stereo_audio = torch.cat([left_channel, right_channel], dim=0)

    return stereo_audio

def inference(
    model,
    decoder,
    text,
    style_prompt,
    duration,
    cfg_strength=1.0,
    sample_steps=32,
    fake_stereo=True,
    odeint_method='euler',
    file_type="wav"
):
    with torch.inference_mode():
        latent = model.sample_block_cache(
            text=text.unsqueeze(0),
            duration=int(duration * 5),
            style_prompt=style_prompt.unsqueeze(0),
            steps=sample_steps,
            cfg_strength=cfg_strength,
            odeint_method=odeint_method
        )
        latent = latent.transpose(1, 2)
        audio = decoder.decode_audio(latent, overlap=5, chunk_size=20)

        num_channels = 1
        audio = audio.float().cpu().squeeze()[None, :]
        if fake_stereo:
            audio = make_fake_stereo(audio, decoder.h.sampling_rate)
            num_channels = 2

        if file_type == 'wav':
            return (decoder.h.sampling_rate, audio.numpy().T)  # [time, channel], as Gradio expects
        else:
            buffer = io.BytesIO()
            torchaudio.save(buffer, audio, decoder.h.sampling_rate, format=file_type)
            return buffer.getvalue()

def inference_stream(
    model,
    decoder,
    text,
    style_prompt,
    duration,
    cfg_strength=1.0,
    sample_steps=32,
    fake_stereo=True,
    odeint_method='euler',
    file_type="wav"
):
    with torch.inference_mode():
        for audio in model.sample_cache_stream(
            decoder=decoder,
            text=text.unsqueeze(0),
            duration=int(duration * 5),
            style_prompt=style_prompt.unsqueeze(0),
            steps=sample_steps,
            cfg_strength=cfg_strength,
            chunk_size=20,
            overlap=5,
            odeint_method=odeint_method
        ):
            # Keep this a tensor here: make_fake_stereo uses torch ops (clone/roll/cat).
            audio = audio.float().cpu().squeeze()[None, :]
            if fake_stereo:
                audio = make_fake_stereo(audio, decoder.h.sampling_rate)
            yield (decoder.h.sampling_rate, audio.numpy().T)  # [time, channel]

lrc_tokenizer = None
MAX_SEED = np.iinfo(np.int32).max
device = 'cuda'
dtype = torch.float16
diffrhythm2, mulan, lrc_tokenizer, decoder = prepare_model("ASLP-Lab/DiffRhythm2", device, dtype)

# import spaces
# @spaces.GPU
def infer_music(
    lrc,
    current_prompt_type,
    audio_prompt=None,
    text_prompt=None,
    seed=42,
    randomize_seed=False,
    steps=16,
    cfg_strength=1.0,
    file_type='wav',
    odeint_method='euler',
    device='cuda'
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    torch.manual_seed(seed)
    print(seed, current_prompt_type)
    try:
        lrc_prompt = parse_lyrics(lrc)
        lrc_prompt = torch.tensor(sum(lrc_prompt, []), dtype=torch.long, device=device)
        if current_prompt_type == "audio":
            style_prompt = get_audio_prompt(mulan, audio_prompt, device, dtype)
        else:
            style_prompt = get_text_prompt(mulan, text_prompt, device, dtype)
    except Exception as e:
        raise gr.Error(f"Error: {str(e)}")
    style_prompt = style_prompt.to(dtype)
    generate_song = inference(
        model=diffrhythm2,
        decoder=decoder,
        text=lrc_prompt,
        style_prompt=style_prompt,
        sample_steps=steps,
        cfg_strength=cfg_strength,
        odeint_method=odeint_method,
        duration=240,
        file_type=file_type
    )
    return generate_song
    # Streaming alternative:
    # for block in inference_stream(
    #     model=diffrhythm2,
    #     decoder=decoder,
    #     text=lrc_prompt,
    #     style_prompt=style_prompt,
    #     sample_steps=steps,
    #     cfg_strength=cfg_strength,
    #     odeint_method=odeint_method,
    #     duration=240,
    #     file_type=file_type
    # ):
    #     yield block

css = """
/* Pin the textarea height and force a scrollbar */
.lyrics-scroll-box textarea {
    height: 405px !important;      /* fixed height */
    max-height: 500px !important;  /* maximum height */
    overflow-y: auto !important;   /* vertical scrolling */
    white-space: pre-wrap;         /* preserve line breaks */
    line-height: 1.5;              /* comfortable line height */
}

.gr-examples {
    background: transparent !important;
    border: 1px solid #e0e0e0 !important;
    border-radius: 8px;
    margin: 1rem 0 !important;
    padding: 1rem !important;
}
"""

def image_to_base64(path):
    with open(path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

with gr.Blocks(css=css) as demo:
    gr.HTML(f"""
        <div style="flex: 1; text-align: center;">
            <div style="font-size: 2em; font-weight: bold; text-align: center; margin-bottom: 5px">
                Di♪♪Rhythm 2 (谛韵)
            </div>
            <div style="display:flex; justify-content: center; column-gap:4px;">
                <a href="https://arxiv.org/pdf/2510.22950">
                    <img src='https://img.shields.io/badge/Arxiv-Paper-blue'>
                </a>
                <a href="https://github.com/ASLP-lab/DiffRhythm2">
                    <img src='https://img.shields.io/badge/GitHub-Repo-green'>
                </a>
                <a href="https://aslp-lab.github.io/DiffRhythm2.github.io/">
                    <img src='https://img.shields.io/badge/Project-Page-brown'>
                </a>
            </div>
        </div>
    """)

    with gr.Tabs() as tabs:

        # page 1
        with gr.Tab("Music Generate", id=0):
            with gr.Row():
                with gr.Column():
                    lrc = gr.Textbox(
                        label="Lyrics",
                        placeholder="Input the full lyrics",
                        lines=12,
                        max_lines=50,
                        elem_classes="lyrics-scroll-box",
                        value="""[start]
[intro]
[verse]
Thought I heard your voice yesterday
When I turned around to say
That I loved you baby
I realize it was juss my mind
Played tricks on me
And it seems colder lately at night
And I try to sleep with the lights on
Every time the phone rings
I pray to God it's you
And I just can't believe
That we're through
[chorus]
I miss you
There's no other way to say it
And I can't deny it
I miss you
It's so easy to see
I miss you and me
[verse]
Is it turning over this time
Have we really changed our minds about each other's love
All the feelings that we used to share
I refuse to believe
That you don't care
[chorus]
I miss you
There's no other way to say it
And I and I can't deny it
I miss you
[verse]
It's so easy to see
I've got to gather myself as together
I've been through worst kinds of weather
If it's over now
[outro]"""
                    )
                    current_prompt_type = gr.State(value="text")
                    with gr.Tabs() as inside_tabs:
                        with gr.Tab("Text Prompt"):
                            text_prompt = gr.Textbox(
                                label="Text Prompt",
                                value="Pop, Piano, Bass, Drums, Happy",
                                placeholder="Enter the text prompt, e.g. emotional piano pop",
                            )
                        with gr.Tab("Audio Prompt"):
                            audio_prompt = gr.Audio(label="Audio Prompt", type="filepath")

                    def update_prompt_type(evt: gr.SelectData):
                        return "text" if evt.index == 0 else "audio"

                    inside_tabs.select(
                        fn=update_prompt_type,
                        outputs=current_prompt_type
                    )

                with gr.Column():
                    with gr.Accordion("Best Practices Guide", open=True):
                        gr.Markdown("""
                        1. **Lyrics Format Requirements**
                           - Each line must be either a structure tag (e.g. `[verse]`) or plain lyric content
                           - Example of valid format:
                           ```
                           [intro]
                           [verse]
                           Thought I heard your voice yesterday
                           When I turned around to say
                           ```

                        2. **Audio Prompt Requirements**
                           - Reference audio should be ≥ 1 second; audio longer than 10 seconds is randomly clipped to a 10-second segment
                           - For optimal results, select the 10-second clip carefully
                           - Shorter clips may lead to incoherent generation

                        3. **Supported Languages**
                           - Chinese and English
                        """)
                    lyrics_btn = gr.Button("Generate", variant="primary")
                    # audio_output = gr.Gallery(label="Audio Results")
                    audio_output = gr.Audio(label="Audio Result", elem_id="audio_output")
                    with gr.Accordion("Advanced Settings", open=False):
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=0,
                        )
                        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                        steps = gr.Slider(
                            minimum=10,
                            maximum=100,
                            value=16,
                            step=1,
                            label="Diffusion Steps",
                            interactive=True,
                            elem_id="step_slider"
                        )
                        cfg_strength = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=1.0,
                            step=0.5,
                            label="CFG Strength",
                            interactive=True,
                            elem_id="step_slider"
                        )

                        odeint_method = gr.Radio(["euler", "midpoint", "rk4", "implicit_adams"], label="ODE Solver", value="euler")
                        file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="mp3")

                    # gr.Examples(
                    #     examples=[
                    #         ["src/prompt/classic_cn.wav"],
                    #         ["src/prompt/classic_en.wav"],
                    #         ["src/prompt/country_cn.wav"],
                    #         ["src/prompt/country_en.wav"],
                    #         ["src/prompt/jazz_cn.wav"],
                    #         ["src/prompt/jazz_en.wav"],
                    #         ["src/prompt/pop_cn.wav"],
                    #         ["src/prompt/pop_en.wav"],
                    #         ["src/prompt/rap_cn.wav"],
                    #         ["src/prompt/rap_en.wav"],
                    #         ["src/prompt/rock_cn.wav"],
                    #         ["src/prompt/rock_en.wav"]
                    #     ],
                    #     inputs=[audio_prompt],
                    #     label="Audio Examples",
                    #     examples_per_page=12,
                    #     elem_id="audio-examples-container"
                    # )

                    # gr.Examples(
                    #     examples=[
                    #         ["Pop Emotional Piano"],
                    #         ["流行 情感 钢琴"],
                    #         ["Indie folk ballad, coming-of-age themes, acoustic guitar picking with harmonica interludes"],
                    #         ["独立民谣, 成长主题, 原声吉他弹奏与口琴间奏"]
                    #     ],
                    #     inputs=[text_prompt],
                    #     label="Text Examples",
                    #     examples_per_page=4,
                    #     elem_id="text-examples-container"
                    # )

                    # gr.Examples(
                    #     examples=[
                    #         ["""[00:10.00]Moonlight spills through broken blinds\n[00:13.20]Your shadow dances on the dashboard shrine\n[00:16.85]Neon ghosts in gasoline rain\n[00:20.40]I hear your laughter down the midnight train\n[00:24.15]Static whispers through frayed wires\n[00:27.65]Guitar strings hum our cathedral choirs\n[00:31.30]Flicker screens show reruns of June\n[00:34.90]I'm drowning in this mercury lagoon\n[00:38.55]Electric veins pulse through concrete skies\n[00:42.10]Your name echoes in the hollow where my heartbeat lies\n[00:45.75]We're satellites trapped in parallel light\n[00:49.25]Burning through the atmosphere of endless night\n[01:00.00]Dusty vinyl spins reverse\n[01:03.45]Our polaroid timeline bleeds through the verse\n[01:07.10]Telescope aimed at dead stars\n[01:10.65]Still tracing constellations through prison bars\n[01:14.30]Electric veins pulse through concrete skies\n[01:17.85]Your name echoes in the hollow where my heartbeat lies\n[01:21.50]We're satellites trapped in parallel light\n[01:25.05]Burning through the atmosphere of endless night\n[02:10.00]Clockwork gears grind moonbeams to rust\n[02:13.50]Our fingerprint smudged by interstellar dust\n[02:17.15]Velvet thunder rolls through my veins\n[02:20.70]Chasing phantom trains through solar plane\n[02:24.35]Electric veins pulse through concrete skies\n[02:27.90]Your name echoes in the hollow where my heartbeat lies"""],
                    #         ["""[00:05.00]Stardust whispers in your eyes\n[00:09.30]Moonlight paints our silhouettes\n[00:13.75]Tides bring secrets from the deep\n[00:18.20]Where forever's breath is kept\n[00:22.90]We dance through constellations' maze\n[00:27.15]Footprints melt in cosmic waves\n[00:31.65]Horizons hum our silent vow\n[00:36.10]Time unravels here and now\n[00:40.85]Eternal embers in the night oh oh oh\n[00:45.25]Healing scars with liquid light\n[00:49.70]Galaxies write our refrain\n[00:54.15]Love reborn in endless rain\n[01:15.30]Paper boats of memories\n[01:19.75]Float through veins of ancient trees\n[01:24.20]Your laughter spins aurora threads\n[01:28.65]Weaving dawn through featherbed"""],
                    #         ["""[00:04.27]只因你太美 baby\n[00:08.95]只因你实在是太美 baby\n[00:13.99]只因你太美 baby\n[00:18.89]迎面走来的你让我如此蠢蠢欲动\n[00:20.88]这种感觉我从未有\n[00:21.79]Cause I got a crush on you who you\n[00:25.74]你是我的我是你的谁\n[00:28.09]再多一眼看一眼就会爆炸\n[00:30.31]再近一点靠近点快被融化\n[00:32.49]想要把你占为己有 baby\n[00:34.60]不管走到哪里\n[00:35.44]都会想起的人是你 you you\n[00:38.12]我应该拿你怎样\n[00:39.61]Uh 所有人都在看着你\n[00:42.36]我的心总是不安\n[00:44.18]Oh 我现在已病入膏肓\n[00:46.63]Eh oh\n[00:47.84]难道真的因你而疯狂吗\n[00:51.57]我本来不是这种人\n[00:53.59]因你变成奇怪的人\n[00:55.77]第一次呀变成这样的我\n[01:01.23]不管我怎么去否认\n[01:03.21]只因你太美 baby\n[01:11.46]只因你实在是太美 baby\n[01:16.75]只因你太美 baby\n[01:21.09]Oh eh oh\n[01:22.82]现在确认地告诉我\n[01:25.26]Oh eh oh\n[01:27.31]你到底属于谁\n[01:29.98]Oh eh oh\n[01:31.70]现在确认地告诉我\n[01:34.45]Oh eh oh\n[01:36.35]你到底属于谁\n[01:37.65]就是现在告诉我\n[01:40.00]跟着那节奏 缓缓 make wave\n"""],
                    #         ["""[00:16.55]倦鸟西归 竹影余晖\n[00:23.58]禅意心扉\n[00:27.32]待清风 拂开一池春水\n[00:30.83]你的手绘 玉色难褪\n[00:37.99]我端详飘散的韵味\n[00:40.65]落款壶底的名讳\n[00:42.92]如吻西施的嘴\n[00:45.14]风雅几回 总相随\n[00:52.32]皆因你珍贵\n[00:57.85]三千弱水 煮一杯\n[01:02.21]我只饮下你的美\n[01:04.92]千年余味 紫砂壶伴我醉\n[01:09.73]酿一世无悔\n[01:12.09]沏壶春水 翠烟飞\n[01:16.62]把盏不尽你的香味\n[01:20.06]邀月相对 愿今生同宿同归\n[01:26.43]只让你陪\n[01:46.12]茗香芳菲 世俗无追\n"""]
                    #     ],
                    #     inputs=[lrc],
                    #     label="Lrc Examples",
                    #     examples_per_page=4,
                    #     elem_id="lrc-examples-container",
                    # )

    # No-op handler on top-level tab switch.
    tabs.select(
        lambda s: None,
        None,
        None
    )

    # TODO: add a max_frames parameter for infer_music
    lyrics_btn.click(
        fn=infer_music,
        inputs=[
            lrc,
            current_prompt_type,
            audio_prompt,
            text_prompt,
            seed,
            randomize_seed,
            steps,
            cfg_strength,
            file_type,
            odeint_method,
        ],
        outputs=audio_output,
    )

# demo.queue().launch(show_api=False, show_error=True)

if __name__ == "__main__":
    demo.launch()
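
For orientation, here is a minimal, self-contained sketch of the lyric-to-token mapping that `parse_lyrics` and `infer_music` implement above: structure tags map to single IDs in `STRUCT_INFO`, every line is terminated with the `[stop]` ID, and the per-line lists are flattened into one stream. The `stub_encode` tokenizer below is hypothetical; the real app uses `CNENTokenizer` over the repo's g2p frontend.

# Sketch only: stub_encode is a hypothetical stand-in for CNENTokenizer.encode.
STRUCT_INFO = {"[start]": 500, "[verse]": 503, "[stop]": 511}

def stub_encode(line):
    # Fake "phoneme IDs"; the real tokenizer returns g2p vocabulary indices.
    return [(ord(c) % 100) + 1 for c in line if not c.isspace()]

def parse_lyrics(lyrics):
    out = []
    for line in lyrics.split("\n"):
        struct_idx = STRUCT_INFO.get(line)
        if struct_idx is not None:
            out.append([struct_idx, STRUCT_INFO["[stop]"]])                  # tag + [stop]
        else:
            out.append(stub_encode(line.strip()) + [STRUCT_INFO["[stop]"]])  # lyric + [stop]
    return out

tokens = sum(parse_lyrics("[start]\n[verse]\nI miss you"), [])  # flattened, as in infer_music
print(tokens)  # structure IDs and stub phoneme IDs, each line ending in 511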
bigvgan/__init__.py ADDED
File without changes
bigvgan/activations.py ADDED
@@ -0,0 +1,126 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.

import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    """
    Implementation of a sine-based periodic activation function.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from the paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default; higher values mean higher frequency.
            alpha will be trained along with the rest of your model.
        """
        super(Snake, self).__init__()
        self.in_features = in_features

        # Initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # Log-scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # Linear-scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2(a * x)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # Line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on the paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default; higher values mean higher frequency.
            beta is initialized to 1 by default; higher values mean higher magnitude.
            alpha will be trained along with the rest of your model.
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # Initialize alpha and beta
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # Log-scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # Linear-scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2(a * x)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # Line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
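
As a quick sanity check of the Snake formula above (a minimal sketch, assuming the package imports as `bigvgan.activations`, per this commit's layout):

import torch
from bigvgan.activations import Snake  # assumes this commit's package layout

act = Snake(in_features=4)   # linear-scale alpha, initialized to ones
x = torch.randn(2, 4, 16)    # [B, C, T]
y = act(x)

# With alpha = 1, Snake reduces to x + sin(x)^2 (up to the 1e-9 stabilizer).
print(torch.allclose(y, x + torch.sin(x) ** 2, atol=1e-6))  # True
print(y.shape)  # torch.Size([2, 4, 16])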
bigvgan/alias_free_activation/cuda/__init__.py ADDED
File without changes
bigvgan/alias_free_activation/cuda/activation1d.py ADDED
@@ -0,0 +1,77 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import torch
import torch.nn as nn
from alias_free_activation.torch.resample import UpSample1d, DownSample1d

# Load the fused CUDA kernel: this enables importing anti_alias_activation_cuda
from alias_free_activation.cuda import load

anti_alias_activation_cuda = load.load()


class FusedAntiAliasActivation(torch.autograd.Function):
    """
    Assumes filter size 12, replication padding on upsampling/downsampling, and log-scale alpha/beta parameters as inputs.
    The hyperparameters are hard-coded in the kernel to maximize speed.
    NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
    """

    @staticmethod
    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
        activation_results = anti_alias_activation_cuda.forward(
            inputs, up_ftr, down_ftr, alpha, beta
        )

        return activation_results

    @staticmethod
    def backward(ctx, output_grads):
        raise NotImplementedError
        return output_grads, None, None


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
        fused: bool = True,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

        self.fused = fused  # Whether to use the fused CUDA kernel or not

    def forward(self, x):
        if not self.fused:
            x = self.upsample(x)
            x = self.act(x)
            x = self.downsample(x)
            return x
        else:
            if self.act.__class__.__name__ == "Snake":
                beta = self.act.alpha.data  # Snake uses the same parameter for alpha and beta
            else:
                beta = self.act.beta.data  # SnakeBeta uses separate alpha and beta parameters
            alpha = self.act.alpha.data
            if not self.act.alpha_logscale:
                # Exp is baked into the CUDA kernel; cancel it out with a log
                alpha = torch.log(alpha)
                beta = torch.log(beta)

            x = FusedAntiAliasActivation.apply(
                x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
            )
            return x
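
Note that importing this module JIT-compiles the fused kernel via `load.load()`, so a CUDA toolchain is required. With `fused=False` the module falls back to the pure-torch path, which is the reference behaviour the fused kernel is hard-coded to reproduce. A hedged sketch (import paths assume the kernel build succeeds and the package root is on `sys.path`, as these modules expect):

import torch
from bigvgan.activations import Snake
# Importing this module triggers the JIT build of the fused kernel (see load.py).
from alias_free_activation.cuda.activation1d import Activation1d

# fused=False runs upsample -> activation -> downsample in pure torch,
# the reference the fused CUDA kernel is checked against.
act = Activation1d(activation=Snake(8, alpha_logscale=True), fused=False).cuda()
x = torch.randn(1, 8, 100, device="cuda")
print(act(x).shape)  # torch.Size([1, 8, 100])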
bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp ADDED
@@ -0,0 +1,23 @@
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/extension.h>

extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
}
bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu ADDED
@@ -0,0 +1,246 @@
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "type_shim.h"
#include <assert.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>

namespace
{
    // Hard-coded hyperparameters
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
    constexpr int BUFFER_SIZE = 32;
    constexpr int FILTER_SIZE = 12;
    constexpr int HALF_FILTER_SIZE = 6;
    constexpr int UPSAMPLE_REPLICATION_PAD = 5;         // 5 on each side, matching torch impl
    constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5;  // matching torch impl
    constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl

    template <typename input_t, typename output_t, typename acc_t>
    __global__ void anti_alias_activation_forward(
        output_t *dst,
        const input_t *src,
        const input_t *up_ftr,
        const input_t *down_ftr,
        const input_t *alpha,
        const input_t *beta,
        int batch_size,
        int channels,
        int seq_len)
    {
        // Up- and downsampling filters
        input_t up_filter[FILTER_SIZE];
        input_t down_filter[FILTER_SIZE];

        // Load data from global memory, including extra indices reserved for replication padding
        input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
        input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};

        // output stores the downsampled result before writing to dst
        output_t output[BUFFER_SIZE];

        // blockDim/threadIdx = (128, 1, 1)
        // gridDim/blockIdx = (seq_blocks, channels, batches)
        int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
        int local_offset = threadIdx.x * BUFFER_SIZE;
        int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;

        // intermediates have double the seq_len
        int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
        int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;

        // Get values needed for replication padding before moving the pointers
        const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
        input_t seq_left_most_value = right_most_pntr[0];
        input_t seq_right_most_value = right_most_pntr[seq_len - 1];

        // Move src and dst pointers
        src += block_offset + local_offset;
        dst += block_offset + local_offset;

        // Alpha and beta values for snake activations. Applies exp by default
        alpha = alpha + blockIdx.y;
        input_t alpha_val = expf(alpha[0]);
        beta = beta + blockIdx.y;
        input_t beta_val = expf(beta[0]);

        #pragma unroll
        for (int it = 0; it < FILTER_SIZE; it += 1)
        {
            up_filter[it] = up_ftr[it];
            down_filter[it] = down_ftr[it];
        }

        // Apply replication padding for upsampling, matching torch impl
        #pragma unroll
        for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
        {
            int element_index = seq_offset + it; // index for element
            if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
            {
                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
            }
            if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
            {
                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
            }
            if ((element_index >= 0) && (element_index < seq_len))
            {
                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
            }
        }

        // Apply the upsampling strided convolution and write to intermediates. DOWNSAMPLE_REPLICATION_PAD_LEFT is reserved for replication padding of the downsampling conv later
        #pragma unroll
        for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
        {
            input_t acc = 0.0;
            int element_index = intermediate_seq_offset + it; // index for intermediate
            #pragma unroll
            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
            {
                if ((element_index + f_idx) >= 0)
                {
                    acc += up_filter[f_idx] * elements[it + f_idx];
                }
            }
            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
        }

        // Apply the activation function. DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT are reserved for replication padding of the downsampling conv later
        double no_div_by_zero = 0.000000001;
        #pragma unroll
        for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
        {
            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
        }

        // Apply replication padding before the downsampling conv from intermediates
        #pragma unroll
        for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
        {
            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
        }
        #pragma unroll
        for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
        {
            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
        }

        // Apply the downsampling strided convolution (assuming stride=2) from intermediates
        #pragma unroll
        for (int it = 0; it < BUFFER_SIZE; it += 1)
        {
            input_t acc = 0.0;
            #pragma unroll
            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
            {
                // Add the constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match the torch implementation
                acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
            }
            output[it] = acc;
        }

        // Write output to dst
        #pragma unroll
        for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
        {
            int element_index = seq_offset + it;
            if (element_index < seq_len)
            {
                dst[it] = output[it];
            }
        }
    }

    template <typename input_t, typename output_t, typename acc_t>
    void dispatch_anti_alias_activation_forward(
        output_t *dst,
        const input_t *src,
        const input_t *up_ftr,
        const input_t *down_ftr,
        const input_t *alpha,
        const input_t *beta,
        int batch_size,
        int channels,
        int seq_len)
    {
        if (seq_len == 0)
        {
            return;
        }
        else
        {
            // Use 128 threads per block to maximize GPU utilization
            constexpr int threads_per_block = 128;
            constexpr int seq_len_per_block = 4096;
            int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
            dim3 blocks(blocks_per_seq_len, channels, batch_size);
            dim3 threads(threads_per_block, 1, 1);

            anti_alias_activation_forward<input_t, output_t, acc_t>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
        }
    }
}

extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
{
    // Input is a 3d tensor with dimensions [batches, channels, seq_len]
    const int batches = input.size(0);
    const int channels = input.size(1);
    const int seq_len = input.size(2);

    // Output
    auto act_options = input.options().requires_grad(false);

    torch::Tensor anti_alias_activation_results =
        torch::empty({batches, channels, seq_len}, act_options);

    void *input_ptr = static_cast<void *>(input.data_ptr());
    void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
    void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
    void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
    void *beta_ptr = static_cast<void *>(beta.data_ptr());
    void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());

    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        input.scalar_type(),
        "dispatch anti alias activation_forward",
        dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
            reinterpret_cast<const scalar_t *>(input_ptr),
            reinterpret_cast<const scalar_t *>(up_filter_ptr),
            reinterpret_cast<const scalar_t *>(down_filter_ptr),
            reinterpret_cast<const scalar_t *>(alpha_ptr),
            reinterpret_cast<const scalar_t *>(beta_ptr),
            batches,
            channels,
            seq_len););
    return anti_alias_activation_results;
}
bigvgan/alias_free_activation/cuda/compat.h ADDED
@@ -0,0 +1,29 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* This code is copied from NVIDIA apex:
 * https://github.com/NVIDIA/apex
 * with minor changes. */

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
bigvgan/alias_free_activation/cuda/load.py ADDED
@@ -0,0 +1,86 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import os
import pathlib
import subprocess

from torch.utils import cpp_extension

"""
Setting this param to a list has the problem of generating different compilation commands (with a different order of architectures) and leading to recompilation of the fused kernels.
Set it to an empty string to avoid recompilation; arch flags are assigned explicitly in extra_cuda_cflags below.
"""
os.environ["TORCH_CUDA_ARCH_LIST"] = ""


def load():
    # Check if CUDA 11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_80,code=sm_80")

    # Build path
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / "build"
    _create_build_dir(buildpath)

    # Helper function to build the kernels.
    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
        return cpp_extension.load(
            name=name,
            sources=sources,
            build_directory=buildpath,
            extra_cflags=[
                "-O3",
            ],
            extra_cuda_cflags=[
                "-O3",
                "-gencode",
                "arch=compute_70,code=sm_70",
                "--use_fast_math",
            ]
            + extra_cuda_flags
            + cc_flag,
            verbose=True,
        )

    extra_cuda_flags = [
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
    ]

    sources = [
        srcpath / "anti_alias_activation.cpp",
        srcpath / "anti_alias_activation_cuda.cu",
    ]
    anti_alias_activation_cuda = _cpp_extention_load_helper(
        "anti_alias_activation_cuda", sources, extra_cuda_flags
    )

    return anti_alias_activation_cuda


def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor


def _create_build_dir(buildpath):
    try:
        os.mkdir(buildpath)
    except OSError:
        if not os.path.isdir(buildpath):
            print(f"Creation of the build directory {buildpath} failed")
bigvgan/alias_free_activation/cuda/type_shim.h ADDED
@@ -0,0 +1,92 @@
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include "compat.h"

#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
    switch (TYPE) \
    { \
    case at::ScalarType::Float: \
    { \
        using scalar_t = float; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::Half: \
    { \
        using scalar_t = at::Half; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::BFloat16: \
    { \
        using scalar_t = at::BFloat16; \
        __VA_ARGS__; \
        break; \
    } \
    default: \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }

#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
    switch (TYPEIN) \
    { \
    case at::ScalarType::Float: \
    { \
        using scalar_t_in = float; \
        switch (TYPEOUT) \
        { \
        case at::ScalarType::Float: \
        { \
            using scalar_t_out = float; \
            __VA_ARGS__; \
            break; \
        } \
        case at::ScalarType::Half: \
        { \
            using scalar_t_out = at::Half; \
            __VA_ARGS__; \
            break; \
        } \
        case at::ScalarType::BFloat16: \
        { \
            using scalar_t_out = at::BFloat16; \
            __VA_ARGS__; \
            break; \
        } \
        default: \
            AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
        } \
        break; \
    } \
    case at::ScalarType::Half: \
    { \
        using scalar_t_in = at::Half; \
        using scalar_t_out = at::Half; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::BFloat16: \
    { \
        using scalar_t_in = at::BFloat16; \
        using scalar_t_out = at::BFloat16; \
        __VA_ARGS__; \
        break; \
    } \
    default: \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
    }
bigvgan/alias_free_activation/torch/__init__.py ADDED
@@ -0,0 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

from .filter import *
from .resample import *
from .act import *
bigvgan/alias_free_activation/torch/act.py ADDED
@@ -0,0 +1,30 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B, C, T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x
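
A usage sketch for the torch Activation1d (assuming this commit's package layout): upsampling by 2, applying the activation, then downsampling by 2 leaves the time resolution unchanged while suppressing the aliasing the nonlinearity would otherwise introduce.

import torch
from bigvgan.activations import SnakeBeta
from bigvgan.alias_free_activation.torch.act import Activation1d  # assumed import path

act = Activation1d(activation=SnakeBeta(in_features=8))
x = torch.randn(1, 8, 100)  # [B, C, T]
print(act(x).shape)         # torch.Size([1, 8, 100]) - same T as the input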
bigvgan/alias_free_activation/torch/filter.py ADDED
@@ -0,0 +1,101 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x).
        __Warning__: unlike julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(
            x == 0,
            torch.tensor(1.0, device=x.device, dtype=x.dtype),
            torch.sin(math.pi * x) / math.pi / x,
        )


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):  # returns filter [1, 1, kernel_size]
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For the Kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5 / cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize the filter to have sum = 1; otherwise there would be a small
        # leakage of the constant component in the input signal.
        filter_ /= filter_.sum()
    filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        """
        kernel_size should be an even number for the StyleGAN3 setup; in this implementation, odd numbers are also possible.
        """
        super().__init__()
        if cutoff < 0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # Input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out
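
The normalization step above gives the filter unity DC gain, so constant inputs pass through unchanged. A quick check (a sketch, assuming the module imports as shown in this commit's layout):

import torch
from bigvgan.alias_free_activation.torch.filter import kaiser_sinc_filter1d

f = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(f.shape)         # torch.Size([1, 1, 12])
print(float(f.sum()))  # ~1.0 - the taps sum to one (unity DC gain)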
bigvgan/alias_free_activation/torch/resample.py ADDED
@@ -0,0 +1,58 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = (
            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        )
        filter = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
        )
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left : -self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        return self.lowpass(x)
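
Shape behaviour of the resamplers (a sketch under the same import assumption): UpSample1d(ratio=2) doubles T via a strided transposed convolution with a Kaiser-windowed sinc filter, and DownSample1d(ratio=2) halves it with the matching low-pass, so chaining them returns the original length.

import torch
from bigvgan.alias_free_activation.torch.resample import UpSample1d, DownSample1d

up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
x = torch.randn(1, 4, 50)  # [B, C, T]
print(up(x).shape)         # torch.Size([1, 4, 100])
print(down(up(x)).shape)   # torch.Size([1, 4, 50])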
bigvgan/env.py ADDED
@@ -0,0 +1,18 @@
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.

import os
import shutil


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def build_env(config, config_name, path):
    t_path = os.path.join(path, config_name)
    if config != t_path:
        os.makedirs(path, exist_ok=True)
        shutil.copyfile(config, os.path.join(path, config_name))
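
AttrDict simply aliases the instance `__dict__` to the dict itself, so hyperparameters read both as items and as attributes (a minimal sketch, assuming this commit's package layout):

from bigvgan.env import AttrDict  # assumed import path

h = AttrDict({"sampling_rate": 24000, "num_mels": 128})
print(h.sampling_rate)  # 24000 - attribute access
print(h["num_mels"])    # 128   - item access
h.hop_size = 256        # attribute writes land in the dict too
print(h["hop_size"])    # 256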
bigvgan/model.py ADDED
@@ -0,0 +1,545 @@
+ # Copyright (c) 2024 NVIDIA CORPORATION.
+ # Licensed under the MIT license.
+
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+ # LICENSE is in incl_licenses directory.
+
+ import os
+ import json
+ from pathlib import Path
+ from typing import Optional, Union, Dict
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import Conv1d, ConvTranspose1d
+ from torch.nn.utils import weight_norm, remove_weight_norm
+ from safetensors.torch import load_file
+
+ from .activations import Snake, SnakeBeta
+ from .utils import init_weights, get_padding
+ from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
+ from .env import AttrDict
+
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
+
+
+ def load_hparams_from_json(path) -> AttrDict:
+     with open(path) as f:
+         data = f.read()
+     return AttrDict(json.loads(data))
+
+
+ class AMPBlock1(torch.nn.Module):
+     """
+     AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
+     AMPBlock1 additionally has self.convs2, which contains extra Conv1d layers with a fixed dilation of 1, each following a layer in self.convs1.
+
+     Args:
+         h (AttrDict): Hyperparameters.
+         channels (int): Number of convolution channels.
+         kernel_size (int): Size of the convolution kernel. Default is 3.
+         dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
+         activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
+     """
+
+     def __init__(
+         self,
+         h: AttrDict,
+         channels: int,
+         kernel_size: int = 3,
+         dilation: tuple = (1, 3, 5),
+         activation: str = None,
+     ):
+         super().__init__()
+
+         self.h = h
+
+         self.convs1 = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         stride=1,
+                         dilation=d,
+                         padding=get_padding(kernel_size, d),
+                     )
+                 )
+                 for d in dilation
+             ]
+         )
+         self.convs1.apply(init_weights)
+
+         self.convs2 = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         stride=1,
+                         dilation=1,
+                         padding=get_padding(kernel_size, 1),
+                     )
+                 )
+                 for _ in range(len(dilation))
+             ]
+         )
+         self.convs2.apply(init_weights)
+
+         self.num_layers = len(self.convs1) + len(
+             self.convs2
+         )  # Total number of conv layers
+
+         # Select which Activation1d; lazy-load the CUDA version to ensure backward compatibility
+         if self.h.get("use_cuda_kernel", False):
+             from alias_free_activation.cuda.activation1d import (
+                 Activation1d as CudaActivation1d,
+             )
+
+             Activation1d = CudaActivation1d
+         else:
+             Activation1d = TorchActivation1d
+
+         # Activation functions
+         if activation == "snake":
+             self.activations = nn.ModuleList(
+                 [
+                     Activation1d(
+                         activation=Snake(channels, alpha_logscale=h.snake_logscale)
+                     )
+                     for _ in range(self.num_layers)
+                 ]
+             )
+         elif activation == "snakebeta":
+             self.activations = nn.ModuleList(
+                 [
+                     Activation1d(
+                         activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale)
+                     )
+                     for _ in range(self.num_layers)
+                 ]
+             )
+         else:
+             raise NotImplementedError(
+                 "activation incorrectly specified. check the config file and look for 'activation'."
+             )
+
+     def forward(self, x):
+         acts1, acts2 = self.activations[::2], self.activations[1::2]
+         for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
+             xt = a1(x)
+             xt = c1(xt)
+             xt = a2(xt)
+             xt = c2(xt)
+             x = xt + x
+
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_weight_norm(l)
+         for l in self.convs2:
+             remove_weight_norm(l)
+
+
+ class AMPBlock2(torch.nn.Module):
+     """
+     AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
+     Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with a fixed dilation of 1.
+
+     Args:
+         h (AttrDict): Hyperparameters.
+         channels (int): Number of convolution channels.
+         kernel_size (int): Size of the convolution kernel. Default is 3.
+         dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
+         activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
+     """
+
+     def __init__(
+         self,
+         h: AttrDict,
+         channels: int,
+         kernel_size: int = 3,
+         dilation: tuple = (1, 3, 5),
+         activation: str = None,
+     ):
+         super().__init__()
+
+         self.h = h
+
+         self.convs = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         stride=1,
+                         dilation=d,
+                         padding=get_padding(kernel_size, d),
+                     )
+                 )
+                 for d in dilation
+             ]
+         )
+         self.convs.apply(init_weights)
+
+         self.num_layers = len(self.convs)  # Total number of conv layers
+
+         # Select which Activation1d; lazy-load the CUDA version to ensure backward compatibility
+         if self.h.get("use_cuda_kernel", False):
+             from alias_free_activation.cuda.activation1d import (
+                 Activation1d as CudaActivation1d,
+             )
+
+             Activation1d = CudaActivation1d
+         else:
+             Activation1d = TorchActivation1d
+
+         # Activation functions
+         if activation == "snake":
+             self.activations = nn.ModuleList(
+                 [
+                     Activation1d(
+                         activation=Snake(channels, alpha_logscale=h.snake_logscale)
+                     )
+                     for _ in range(self.num_layers)
+                 ]
+             )
+         elif activation == "snakebeta":
+             self.activations = nn.ModuleList(
+                 [
+                     Activation1d(
+                         activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale)
+                     )
+                     for _ in range(self.num_layers)
+                 ]
+             )
+         else:
+             raise NotImplementedError(
+                 "activation incorrectly specified. check the config file and look for 'activation'."
+             )
+
+     def forward(self, x):
+         for c, a in zip(self.convs, self.activations):
+             xt = a(x)
+             xt = c(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs:
+             remove_weight_norm(l)
+
+
+ class BigVGAN(
+     torch.nn.Module,
+     PyTorchModelHubMixin,
+     library_name="bigvgan",
+     repo_url="https://github.com/NVIDIA/BigVGAN",
+     docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
+     pipeline_tag="audio-to-audio",
+     license="mit",
+     tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
+ ):
+     def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
+         super().__init__()
+         self.h = h
+         self.h["use_cuda_kernel"] = use_cuda_kernel
+
+         # Select which Activation1d; lazy-load the CUDA version to ensure backward compatibility
+         if self.h.get("use_cuda_kernel", False):
+             from alias_free_activation.cuda.activation1d import (
+                 Activation1d as CudaActivation1d,
+             )
+
+             Activation1d = CudaActivation1d
+         else:
+             Activation1d = TorchActivation1d
+
+         self.num_kernels = len(h.resblock_kernel_sizes)
+         self.num_upsamples = len(h.upsample_rates)
+
+         # Pre-conv
+         self.conv_pre = weight_norm(
+             Conv1d(h.in_channels, h.upsample_initial_channel, 7, 1, padding=3)
+         )
+
+         # Define which AMPBlock to use. BigVGAN uses AMPBlock1 by default
+         if h.resblock == "1":
+             resblock_class = AMPBlock1
+         elif h.resblock == "2":
+             resblock_class = AMPBlock2
+         else:
+             raise ValueError(
+                 f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
+             )
+
+         # Transposed conv-based upsamplers (anti-aliasing is not applied here)
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+             self.ups.append(
+                 nn.ModuleList(
+                     [
+                         weight_norm(
+                             ConvTranspose1d(
+                                 h.upsample_initial_channel // (2**i),
+                                 h.upsample_initial_channel // (2 ** (i + 1)),
+                                 k,
+                                 u,
+                                 padding=(k - u) // 2,
+                             )
+                         )
+                     ]
+                 )
+             )
+
+         # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = h.upsample_initial_channel // (2 ** (i + 1))
+             for j, (k, d) in enumerate(
+                 zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+             ):
+                 self.resblocks.append(
+                     resblock_class(h, ch, k, d, activation=h.activation)
+                 )
+
+         # Post-conv
+         activation_post = (
+             Snake(ch, alpha_logscale=h.snake_logscale)
+             if h.activation == "snake"
+             else (
+                 SnakeBeta(ch, alpha_logscale=h.snake_logscale)
+                 if h.activation == "snakebeta"
+                 else None
+             )
+         )
+         if activation_post is None:
+             raise NotImplementedError(
+                 "activation incorrectly specified. check the config file and look for 'activation'."
+             )
+
+         self.activation_post = Activation1d(activation=activation_post)
+
+         # Whether to use bias for the final conv_post. Defaults to True for backward compatibility
+         self.use_bias_at_final = h.get("use_bias_at_final", True)
+         self.conv_post = weight_norm(
+             Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
+         )
+
+         # Weight initialization
+         for i in range(len(self.ups)):
+             self.ups[i].apply(init_weights)
+         self.conv_post.apply(init_weights)
+
+         # Final tanh activation. Defaults to True for backward compatibility
+         self.use_tanh_at_final = h.get("use_tanh_at_final", True)
+
+     def forward(self, x):
+         # Pre-conv
+         x = self.conv_pre(x)
+
+         for i in range(self.num_upsamples):
+             # Upsampling
+             for i_up in range(len(self.ups[i])):
+                 x = self.ups[i][i_up](x)
+             # AMP blocks
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i * self.num_kernels + j](x)
+                 else:
+                     xs += self.resblocks[i * self.num_kernels + j](x)
+             x = xs / self.num_kernels
+
+         # Post-conv
+         x = self.activation_post(x)
+         x = self.conv_post(x)
+         # Final tanh activation
+         if self.use_tanh_at_final:
+             x = torch.tanh(x)
+         else:
+             x = torch.clamp(x, min=-1.0, max=1.0)  # Bound the output to [-1, 1]
+
+         return x
+
+     def remove_weight_norm(self):
+         try:
+             print("Removing weight norm...")
+             for l in self.ups:
+                 for l_i in l:
+                     remove_weight_norm(l_i)
+             for l in self.resblocks:
+                 l.remove_weight_norm()
+             remove_weight_norm(self.conv_pre)
+             remove_weight_norm(self.conv_post)
+         except ValueError:
+             print("[INFO] Model already removed weight norm. Skipping!")
+
+     # Additional methods for huggingface_hub support
+     def _save_pretrained(self, save_directory: Path) -> None:
+         """Save weights and config.json from a Pytorch model to a local directory."""
+
+         model_path = save_directory / "bigvgan_generator.pt"
+         torch.save({"generator": self.state_dict()}, model_path)
+
+         config_path = save_directory / "config.json"
+         with open(config_path, "w") as config_file:
+             json.dump(self.h, config_file, indent=4)
+
+     @classmethod
+     def _from_pretrained(
+         cls,
+         *,
+         model_id: str,
+         revision: str,
+         cache_dir: str,
+         force_download: bool,
+         proxies: Optional[Dict],
+         resume_download: bool,
+         local_files_only: bool,
+         token: Union[str, bool, None],
+         map_location: str = "cpu",  # Additional argument
+         strict: bool = False,  # Additional argument
+         use_cuda_kernel: bool = False,
+         **model_kwargs,
+     ):
+         """Load Pytorch pretrained weights and return the loaded model."""
+
+         # Download and load hyperparameters (h) used by BigVGAN
+         if os.path.isdir(model_id):
+             print("Loading config.json from local directory")
+             config_file = os.path.join(model_id, "config.json")
+         else:
+             config_file = hf_hub_download(
+                 repo_id=model_id,
+                 filename="config.json",
+                 revision=revision,
+                 cache_dir=cache_dir,
+                 force_download=force_download,
+                 proxies=proxies,
+                 resume_download=resume_download,
+                 token=token,
+                 local_files_only=local_files_only,
+             )
+         h = load_hparams_from_json(config_file)
+
+         # Instantiate BigVGAN using h
+         if use_cuda_kernel:
+             print(
+                 "[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
+             )
+             print(
+                 "[WARNING] You need nvcc and ninja installed on your system, matching the PyTorch build you are using, to build the kernel. If not, the model will fail to initialize or will generate incorrect waveforms!"
+             )
+             print(
+                 "[WARNING] For details, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
+             )
+         model = cls(h, use_cuda_kernel=use_cuda_kernel)
+
+         # Download and load pretrained generator weight
+         if os.path.isdir(model_id):
+             print("Loading weights from local directory")
+             model_file = os.path.join(model_id, "bigvgan_generator.pt")
+         else:
+             print(f"Loading weights from {model_id}")
+             model_file = hf_hub_download(
+                 repo_id=model_id,
+                 filename="bigvgan_generator.pt",
+                 revision=revision,
+                 cache_dir=cache_dir,
+                 force_download=force_download,
+                 proxies=proxies,
+                 resume_download=resume_download,
+                 token=token,
+                 local_files_only=local_files_only,
+             )
+
+         checkpoint_dict = torch.load(model_file, map_location=map_location)
+
+         try:
+             model.load_state_dict(checkpoint_dict["generator"])
+         except RuntimeError:
+             print(
+                 "[INFO] The pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
+             )
+             model.remove_weight_norm()
+             model.load_state_dict(checkpoint_dict["generator"])
+
+         return model
+
+
+ class Generator(torch.nn.Module):
+     def __init__(self, config_file, ckpt_path):
+         super().__init__()
+         with open(config_file) as f:
+             json_config = json.load(f)
+         self.h = AttrDict(json_config)
+         self.decoder = BigVGAN(self.h)
+         if ckpt_path.endswith(".safetensors"):
+             checkpoint_dict = load_file(ckpt_path)
+         else:
+             checkpoint_dict = torch.load(ckpt_path, map_location="cpu")
+         self.decoder.load_state_dict(checkpoint_dict["generator"])
+         self.decoder.remove_weight_norm()
+         self.decoder.eval()
+
+     def decode_audio(self, latents, overlap=5, chunk_size=20):
+         # Chunked decoding
+         hop_size = chunk_size - overlap
+         total_size = latents.shape[2]
+         batch_size = latents.shape[0]
+         chunks = []
+         for i in range(0, total_size - chunk_size + 1, hop_size):
+             chunk = latents[:, :, i : i + chunk_size]
+             chunks.append(chunk)
+         if i + chunk_size != total_size:
+             # Final chunk
+             chunk = latents[:, :, -chunk_size:]
+             chunks.append(chunk)
+         chunks = torch.stack(chunks)
+         num_chunks = chunks.shape[0]
+         # samples_per_latent is just the downsampling ratio
+         samples_per_latent = 9600
+         # Create an empty waveform; we will populate it with chunks as we decode them
+         y_size = total_size * samples_per_latent
+         y_final = torch.zeros((batch_size, 1, y_size)).to(latents.device)
+         for i in range(num_chunks):
+             x_chunk = chunks[i, :]
+             # Decode the chunk
+             y_chunk = self.decoder(x_chunk)
+             # Figure out where to put the audio along the time axis
+             if i == num_chunks - 1:
+                 # The final chunk always goes at the end
+                 t_end = y_size
+                 t_start = t_end - y_chunk.shape[2]
+             else:
+                 t_start = i * hop_size * samples_per_latent
+                 t_end = t_start + chunk_size * samples_per_latent
+             # Trim the edges of the overlaps
+             ol = (overlap // 2) * samples_per_latent
+             chunk_start = 0
+             chunk_end = y_chunk.shape[2]
+             if i > 0:
+                 # No overlap trimming at the start of the first chunk
+                 t_start += ol
+                 chunk_start += ol
+             if i < num_chunks - 1:
+                 # No overlap trimming at the end of the last chunk
+                 t_end -= ol
+                 chunk_end -= ol
+             # Paste the chunked audio into the y_final output audio
+             y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
+         return y_final
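Not part of the commit: a tiny worked example of the chunk layout `decode_audio` produces, useful for sanity-checking the overlap arithmetic. Values are chosen for illustration; `samples_per_latent=9600` matches the constant in the method.

```python
chunk_size, overlap, total_size = 20, 5, 53   # latent frames
hop_size = chunk_size - overlap               # 15

starts = list(range(0, total_size - chunk_size + 1, hop_size))  # [0, 15, 30]
if starts[-1] + chunk_size != total_size:
    starts.append(total_size - chunk_size)                      # tail chunk at 33

samples_per_latent = 9600
ol = (overlap // 2) * samples_per_latent      # 19200 samples trimmed per shared edge
# The first chunk keeps its left edge, the final chunk keeps its right edge,
# and interior edges are trimmed by `ol`. Where the tail chunk still overlaps
# its predecessor, the later paste simply overwrites the earlier one, so the
# full range [0, total_size * samples_per_latent) is covered.
print(starts)  # [0, 15, 30, 33]
```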
bigvgan/utils.py ADDED
@@ -0,0 +1,59 @@
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+ # LICENSE is in incl_licenses directory.
+
+ import glob
+ import os
+ import torch
+ from torch.nn.utils import weight_norm
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def apply_weight_norm(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         weight_norm(m)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def load_checkpoint(filepath, device):
+     assert os.path.isfile(filepath)
+     print(f"Loading '{filepath}'")
+     checkpoint_dict = torch.load(filepath, map_location=device)
+     print("Complete.")
+     return checkpoint_dict
+
+
+ def save_checkpoint(filepath, obj):
+     print(f"Saving checkpoint to {filepath}")
+     torch.save(obj, filepath)
+     print("Complete.")
+
+
+ def scan_checkpoint(cp_dir, prefix, renamed_file=None):
+     # Fall back to the original scanning logic first
+     pattern = os.path.join(cp_dir, prefix + "????????")
+     cp_list = glob.glob(pattern)
+
+     if len(cp_list) > 0:
+         last_checkpoint_path = sorted(cp_list)[-1]
+         print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
+         return last_checkpoint_path
+
+     # If no pattern-based checkpoints are found, check for the renamed file
+     if renamed_file:
+         renamed_path = os.path.join(cp_dir, renamed_file)
+         if os.path.isfile(renamed_path):
+             print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
+             return renamed_path
+
+     return None
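A short usage sketch (the directory layout is hypothetical): `scan_checkpoint` first matches step-numbered files such as `g_00020000`, and only falls back to a fixed filename when none are found.

```python
# Hypothetical checkpoint directory layout:
#   exp/g_00010000, exp/g_00020000  -> returns 'exp/g_00020000'
#   exp/bigvgan_generator.pt only   -> returned via the renamed_file fallback
ckpt = scan_checkpoint("exp", "g_", renamed_file="bigvgan_generator.pt")
if ckpt is not None:
    state = load_checkpoint(ckpt, device="cpu")
```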
diffrhythm2/__init__.py ADDED
File without changes
diffrhythm2/backbones/__init__.py ADDED
File without changes
diffrhythm2/backbones/dit.py ADDED
@@ -0,0 +1,222 @@
+ # Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from __future__ import annotations
+
+ import torch
+ import math
+ from torch import nn
+
+ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaConfig
+ from .llama_nar import LlamaNARDecoderLayer
+
+
+ class TextEmbedding(nn.Module):
+     def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
+         super().__init__()
+         self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token
+
+     def forward(self, text: int["b nt"]):  # noqa: F722
+         text = self.text_embed(text)  # b n -> b n d
+         return text
+
+
+ class InputEmbedding(nn.Module):
+     def __init__(self, cond_dim, out_dim):
+         super().__init__()
+         self.proj = nn.Linear(cond_dim, cond_dim)
+         self.proj_2 = nn.Linear(cond_dim, out_dim)
+
+     def forward(self, x, style_emb, time_emb):  # noqa: F722
+         style_emb = style_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
+         x_orig = x
+         x = x + style_emb + time_emb
+         x = self.proj(x) + x_orig
+         x = self.proj_2(x)
+         return x
+
+
+ class AdaLayerNormZero_Final(nn.Module):
+     def __init__(self, dim, cond_dim):
+         super().__init__()
+
+         self.silu = nn.SiLU()
+         self.linear = nn.Linear(cond_dim, dim * 2)
+
+         self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+     def forward(self, x, emb):
+         emb = self.linear(self.silu(emb))
+         scale, shift = torch.chunk(emb, 2, dim=-1)
+
+         x = self.norm(x) * (1 + scale) + shift
+         return x
+
+
+ class SinusPositionEmbedding(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+
+     def forward(self, x, scale=1000):
+         device = x.device
+         half_dim = self.dim // 2
+         emb = math.log(10000) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+         emb = scale * x.unsqueeze(-1) * emb.unsqueeze(0).unsqueeze(0)
+         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+         return emb
+
+     def numel(self):
+         return 0
+
+
+ class TimestepEmbedding(nn.Module):
+     def __init__(self, dim, freq_embed_dim=256):
+         super().__init__()
+         self.time_embed = SinusPositionEmbedding(freq_embed_dim)
+         self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
+
+     def forward(self, timestep: float["b"]):  # noqa: F821
+         time_hidden = self.time_embed(timestep)
+         time_hidden = time_hidden.to(timestep.dtype)
+         time = self.time_mlp(time_hidden)  # b d
+         return time
+
+
+ class DiT(nn.Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         depth=8,
+         heads=8,
+         ff_mult=4,
+         mel_dim=100,
+         text_num_embeds=256,
+         conv_layers=0,
+         long_skip_connection=False,
+         use_flex_attn=False,
+         repa_depth=-1,
+         repa_dims=[1024],
+         **kwargs,
+     ):
+         super().__init__()
+
+         cond_dim = 512
+         self.time_embed = TimestepEmbedding(cond_dim)
+         self.text_embed = TextEmbedding(text_num_embeds, cond_dim, conv_layers=conv_layers)
+         self.input_embed = InputEmbedding(cond_dim, dim)
+
+         self.latent_embed = torch.nn.Sequential(
+             nn.Linear(mel_dim, cond_dim),
+             nn.Linear(cond_dim, cond_dim),
+         )
+
+         self.dim = dim
+         self.depth = depth
+         self.use_flex_attn = use_flex_attn
+
+         llama_config = LlamaConfig(
+             hidden_size=dim,
+             num_attention_heads=heads,
+             intermediate_size=dim * ff_mult,
+             hidden_act="silu",
+             max_position_embeddings=4096,
+         )
+         self.rotary_embed = LlamaRotaryEmbedding(config=llama_config)
+         llama_config._attn_implementation = "sdpa"
+         self.transformer_blocks = nn.ModuleList(
+             [LlamaNARDecoderLayer(llama_config, layer_idx=i, use_flex_attn=self.use_flex_attn) for i in range(depth)]
+         )
+         self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
+
+         self.norm_out = AdaLayerNormZero_Final(dim, cond_dim)  # final modulation
+         self.proj_out = nn.Linear(dim, mel_dim)
+
+         self.repa_depth = repa_depth
+         self.repa_dims = repa_dims
+         self.projectors = None
+         if self.repa_depth > 0:
+             self.projectors = nn.ModuleList(
+                 [
+                     nn.Sequential(
+                         nn.Linear(self.dim, self.dim * 2),
+                         nn.SiLU(),
+                         nn.Linear(self.dim * 2, self.dim * 2),
+                         nn.SiLU(),
+                         nn.Linear(self.dim * 2, repa_dim),
+                     )
+                     for repa_dim in self.repa_dims
+                 ]
+             )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         time: torch.Tensor,
+         position_ids: torch.Tensor,
+         style_prompt: torch.Tensor,
+         attn_mask: torch.Tensor,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         past_key_value=None,
+     ):
+         """
+         Args:
+             x: [b, n, d]
+             time: [b, n, 1]
+             position_ids: [b, n]
+             style_prompt: [b, 512]
+             attn_mask: [b, 1, n, n]
+         """
+         batch, seq_len = x.shape[0], x.shape[1]
+         t = self.time_embed(time)
+         c = t  # [B, T, dim]
+
+         x = self.input_embed(x, style_prompt, c)
+
+         if self.long_skip_connection is not None:
+             residual = x
+
+         position_embeddings = self.rotary_embed(x, position_ids)
+
+         attn_weights = []
+         if not use_cache:
+             past_key_value = None
+
+         repa_res = None
+         for i, block in enumerate(self.transformer_blocks):
+             res = block(
+                 x,
+                 attention_mask=attn_mask,
+                 position_embeddings=position_embeddings,
+                 output_attentions=output_attentions,
+                 past_key_value=past_key_value,
+                 use_cache=use_cache,
+             )
+             x = res.pop(0)
+             if output_attentions:
+                 attn_weights.append(res.pop(0))
+             if use_cache:
+                 past_key_value = res.pop(0)
+             if i == self.repa_depth - 1:
+                 repa_res = x
+
+         if self.long_skip_connection is not None:
+             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
+
+         x = self.norm_out(x, c)
+         output = self.proj_out(x)
+
+         return output, attn_weights, past_key_value
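For intuition, a shape-only sketch of the timestep-embedding path above (not part of the commit). `dim=512` mirrors the `cond_dim` constant in `DiT.__init__`; the `[b, n]` timestep layout is an assumption consistent with `SinusPositionEmbedding.forward`, which appends the feature axis itself.

```python
import torch
from diffrhythm2.backbones.dit import TimestepEmbedding

temb = TimestepEmbedding(dim=512)   # sinusoidal features -> Linear -> SiLU -> Linear
t = torch.rand(2, 100)              # [b, n] per-token flow-matching times in [0, 1]
c = temb(t)
print(c.shape)                      # torch.Size([2, 100, 512])
```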
diffrhythm2/backbones/flex_attention.py ADDED
@@ -0,0 +1,237 @@
+ # Copyright 2025 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import Optional, Tuple, Union
+
+ import torch
+
+ from torch.nn.attention.flex_attention import BlockMask, flex_attention
+ from torch.nn.attention.flex_attention import (
+     create_block_mask as create_block_causal_mask_flex,
+ )
+
+
+ class WrappedFlexAttention:
+     """
+     A singleton class, so that flex attention is compiled only once, when it is first called.
+     """
+
+     _instance = None
+     _is_flex_compiled = False
+     _compiled_flex_attention = None
+
+     def __new__(cls, *args, **kwargs):
+         if cls._instance is None:
+             # Create a new instance if one doesn't already exist
+             cls._instance = super().__new__(cls)
+         return cls._instance
+
+     @torch.compiler.disable(recursive=False)
+     def __init__(self, training):
+         """
+         Initialize or update the singleton instance.
+         """
+         if not self._is_flex_compiled or training != self.training:
+             # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
+             # cause errors during training. The suggested fix is to compile with
+             # "max-autotune-no-cudagraphs"; see https://github.com/pytorch/pytorch/issues/146260
+             self.training = training
+             if torch.__version__.split("+")[0] == "2.6.0" and training:
+                 self._compiled_flex_attention = torch.compile(
+                     flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
+                 )
+             else:
+                 self._compiled_flex_attention = torch.compile(flex_attention)
+             self._is_flex_compiled = True
+
+     def __call__(self):
+         return self._compiled_flex_attention
+
+
+ Offset = Union[torch.Tensor, int]
+
+
+ def make_flex_block_causal_mask(
+     attention_mask_2d: torch.Tensor,
+     attention_chunk_size: Optional[int] = None,
+     query_length=None,
+     key_length=None,
+     offsets: Optional[Tuple[Offset, Offset]] = None,
+ ) -> "BlockMask":
+     """
+     Create a block-causal document mask for a batch of sequences, both packed and unpacked.
+     Creates the block-causal logic and passes it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
+     The resulting BlockMask is a compressed representation of the full block-causal
+     mask. BlockMask is essential for performant computation of flex attention.
+     See: https://pytorch.org/blog/flexattention/
+
+     Args:
+         attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
+             of shape (batch_size, total_seq_len), e.g.
+
+             For unpacked sequences:
+             [[1, 1, 1, 1, 0, 0, 0],
+              [1, 1, 1, 1, 1, 0, 0]]
+
+             For packed sequences:
+             [[1, 1, 1, 2, 2, 2, 0],
+              [1, 1, 2, 2, 2, 3, 3]]
+
+     Returns:
+         BlockMask
+     """
+     batch_size, total_seq_len = attention_mask_2d.shape
+     if not key_length:
+         key_length = total_seq_len
+     if not query_length:
+         query_length = total_seq_len
+     attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, key_length))
+     device = attention_mask_2d.device
+     document_ids = attention_mask_2d.clone()
+
+     if attention_chunk_size is not None:
+         # We create an arange, then // by the chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
+         document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (attention_chunk_size)
+
+     # Instead of passing a tensor mask, flex attention requires a mask_mod function
+     # that determines which elements of QK^T should be included in the attention
+     # computation prior to the softmax. For sample packing, we need the logic
+     # for both the causal mask and the document mask. See PyTorch's official
+     # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
+     def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+         """
+         Defines the logic of a block-causal mask by combining both a standard causal mask
+         and a block-diagonal document mask.
+
+         See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
+         for an illustration.
+         """
+         causal_mask = q_idx >= kv_idx  # not valid when decoding
+         document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
+         padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
+         final_mask = causal_mask & padding_mask & document_mask
+         return final_mask
+
+     if offsets is not None:
+         q_offset = offsets[0]
+         kv_offset = offsets[1]
+
+         def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+             offset_q = q_idx + q_offset
+             offset_kv = kv_idx + kv_offset
+             return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
+     else:
+         mask_mod = causal_mask_mod
+     return create_block_causal_mask_flex(
+         mask_mod=mask_mod,
+         B=batch_size,
+         H=None,  # attention head
+         Q_LEN=query_length,
+         KV_LEN=key_length,
+         device=device,
+         _compile=True,
+     )
+
+
+ @torch.compiler.disable(recursive=False)
+ def compile_friendly_flex_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     training=False,
+     **kwargs,
+ ) -> torch.Tensor:
+     # The first call initialises the singleton wrapper object; the second call
+     # invokes the object's __call__ to return the compiled flex attention
+     flex_attention_compiled = WrappedFlexAttention(training)()
+     return flex_attention_compiled(
+         query,
+         key,
+         value,
+         **kwargs,
+     )
+
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+ def flex_attention_forward(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Union[torch.Tensor, "BlockMask"],
+     training: bool = True,
+     scaling: Optional[float] = None,
+     softcap: Optional[float] = None,
+     head_mask: Optional[torch.Tensor] = None,
+     **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     block_mask = None
+     causal_mask = None
+
+     block_mask = attention_mask
+     # if isinstance(attention_mask, BlockMask):
+     #     block_mask = attention_mask
+     # else:
+     #     causal_mask = attention_mask
+
+     if causal_mask is not None:
+         causal_mask = causal_mask[:, :, :, : key.shape[-2]]
+
+     def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
+         if softcap is not None:
+             score = softcap * torch.tanh(score / softcap)
+         if causal_mask is not None:
+             score = score + causal_mask[batch_idx][0][q_idx][kv_idx]
+         if head_mask is not None:
+             score = score + head_mask[batch_idx][head_idx][0][0]
+         return score
+
+     enable_gqa = True
+     num_local_query_heads = query.shape[1]
+
+     # When running TP this helps:
+     if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
+         key = repeat_kv(key, query.shape[1] // key.shape[1])
+         value = repeat_kv(value, query.shape[1] // value.shape[1])
+         enable_gqa = False
+
+     kernel_options = kwargs.get("kernel_options", None)
+     attn_output, attention_weights = compile_friendly_flex_attention(
+         query,
+         key,
+         value,
+         score_mod=score_mod,
+         block_mask=block_mask,
+         enable_gqa=enable_gqa,
+         scale=scaling,
+         kernel_options=kernel_options,
+         # Last time checked on PyTorch == 2.5.1: flex attention always computes the lse regardless.
+         # For simplification, we thus always return it, as no additional computation is introduced.
+         return_lse=True,
+         training=training,
+     )
+     # lse is returned in float32
+     attention_weights = attention_weights.to(value.dtype)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attention_weights
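A small usage sketch of `make_flex_block_causal_mask` (not part of the commit; requires a PyTorch build with flex attention, i.e. torch >= 2.5). The packed-sequence mask mirrors the docstring above.

```python
import torch
from diffrhythm2.backbones.flex_attention import make_flex_block_causal_mask

# One row packing two documents (ids 1 and 2) followed by right padding (0).
attention_mask_2d = torch.tensor([[1, 1, 1, 2, 2, 2, 0, 0]])
block_mask = make_flex_block_causal_mask(attention_mask_2d)
# `block_mask` can then be handed to flex_attention_forward as `attention_mask`:
# tokens attend causally and only within their own document.
```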
diffrhythm2/backbones/llama_attention.py ADDED
@@ -0,0 +1,451 @@
1
+ # Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import math
19
+
20
+ from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, LlamaRMSNorm
21
+ from transformers.models.llama.modeling_llama import Cache, StaticCache, FlashAttentionKwargs, Unpack
22
+ from transformers.models.llama.modeling_llama import (
23
+ apply_rotary_pos_emb,
24
+ repeat_kv,
25
+ _flash_attention_forward,
26
+ is_flash_attn_greater_or_equal_2_10
27
+ )
28
+ from transformers.models.llama.modeling_llama import logger
29
+ from typing import Optional, Tuple
30
+
31
+ try:
32
+ from .flex_attention import flex_attention_forward
33
+ except:
34
+ pass
35
+
36
+ class LlamaAttention(nn.Module):
37
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
38
+
39
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
40
+ super().__init__()
41
+ self.config = config
42
+ self.layer_idx = layer_idx
43
+ if layer_idx is None:
44
+ logger.warning_once(
45
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
46
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
47
+ "when creating this class."
48
+ )
49
+
50
+ self.attention_dropout = config.attention_dropout
51
+ self.hidden_size = config.hidden_size
52
+ self.num_heads = config.num_attention_heads
53
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
54
+ self.num_key_value_heads = config.num_key_value_heads
55
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
56
+ self.max_position_embeddings = config.max_position_embeddings
57
+ self.rope_theta = config.rope_theta
58
+ self.is_causal = False
59
+
60
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
61
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
62
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
63
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
64
+
65
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
66
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
67
+ self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
68
+ self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
69
+
70
+ def forward(
71
+ self,
72
+ hidden_states: torch.Tensor,
73
+ attention_mask: Optional[torch.Tensor] = None,
74
+ position_ids: Optional[torch.LongTensor] = None,
75
+ past_key_value: Optional[Cache] = None,
76
+ output_attentions: bool = False,
77
+ use_cache: bool = False,
78
+ cache_position: Optional[torch.LongTensor] = None,
79
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
80
+ **kwargs,
81
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
82
+ bsz, q_len, _ = hidden_states.size()
83
+
84
+ query_states = self.q_proj(hidden_states)
85
+ key_states = self.k_proj(hidden_states)
86
+ value_states = self.v_proj(hidden_states)
87
+
88
+ # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
89
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
90
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
91
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
92
+
93
+ query_states = self.q_norm(query_states)
94
+ key_states = self.k_norm(key_states)
95
+
96
+ if position_embeddings is None:
97
+ logger.warning_once(
98
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
99
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
100
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
101
+ "removed and `position_embeddings` will be mandatory."
102
+ )
103
+ cos, sin = self.rotary_emb(value_states, position_ids)
104
+ else:
105
+ cos, sin = position_embeddings
106
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
107
+
108
+ if past_key_value is not None:
109
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
110
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
111
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
112
+
113
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
114
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
115
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
116
+
117
+ if attention_mask is not None: # no matter the length, we just slice it
118
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
119
+ if attention_mask.dtype != torch.bool:
120
+ attn_weights = attn_weights + causal_mask
121
+ else:
122
+ attn_weights = torch.masked_fill(attn_weights, ~causal_mask, float("-inf"))
123
+
124
+ # upcast attention to fp32
125
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
126
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
127
+ attn_output = torch.matmul(attn_weights, value_states)
128
+
129
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
130
+ raise ValueError(
131
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
132
+ f" {attn_output.size()}"
133
+ )
134
+
135
+ attn_output = attn_output.transpose(1, 2).contiguous()
136
+
137
+ attn_output = attn_output.reshape(bsz, q_len, -1)
138
+
139
+ attn_output = self.o_proj(attn_output)
140
+
141
+ if not output_attentions:
142
+ attn_weights = None
143
+
144
+ return attn_output, attn_weights, past_key_value
145
+
146
+
147
+ class LlamaFlashAttention2(LlamaAttention):
148
+ """
149
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
150
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
151
+ flash attention and deal with padding tokens in case the input contains any of them.
152
+ """
153
+
154
+ def __init__(self, *args, **kwargs):
155
+ super().__init__(*args, **kwargs)
156
+
157
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
158
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
159
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
160
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
161
+
162
+ def forward(
163
+ self,
164
+ hidden_states: torch.Tensor,
165
+ attention_mask: Optional[torch.LongTensor] = None,
166
+ position_ids: Optional[torch.LongTensor] = None,
167
+ past_key_value: Optional[Cache] = None,
168
+ output_attentions: bool = False,
169
+ use_cache: bool = False,
170
+ cache_position: Optional[torch.LongTensor] = None,
171
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
172
+ **kwargs: Unpack[FlashAttentionKwargs],
173
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
174
+ if isinstance(past_key_value, StaticCache):
175
+ raise ValueError(
176
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
177
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
178
+ )
179
+
180
+ output_attentions = False
181
+
182
+ bsz, q_len, _ = hidden_states.size()
183
+
184
+ query_states = self.q_proj(hidden_states)
185
+ key_states = self.k_proj(hidden_states)
186
+ value_states = self.v_proj(hidden_states)
187
+
188
+ # Flash attention requires the input to have the shape
189
+ # batch_size x seq_length x head_dim x hidden_dim
190
+ # therefore we just need to keep the original shape
191
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
192
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
193
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
194
+
195
+ query_states = self.q_norm(query_states)
196
+ key_states = self.k_norm(key_states)
197
+
198
+ if position_embeddings is None:
199
+ logger.warning_once(
200
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
201
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
202
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
203
+ "removed and `position_embeddings` will be mandatory."
204
+ )
205
+ cos, sin = self.rotary_emb(value_states, position_ids)
206
+ else:
207
+ cos, sin = position_embeddings
208
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
209
+
210
+ if past_key_value is not None:
211
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
212
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
213
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
214
+
215
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
216
+ # to be able to avoid many of these transpose/reshape/view.
217
+ query_states = query_states.transpose(1, 2)
218
+ key_states = key_states.transpose(1, 2)
219
+ value_states = value_states.transpose(1, 2)
220
+
221
+ dropout_rate = self.attention_dropout if self.training else 0.0
222
+
223
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
224
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
225
+ # cast them back in the correct dtype just to be sure everything works as expected.
226
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
227
+ # in fp32. (LlamaRMSNorm handles it correctly)
228
+
229
+ input_dtype = query_states.dtype
230
+ if input_dtype == torch.float32:
231
+ if torch.is_autocast_enabled():
232
+ target_dtype = torch.get_autocast_gpu_dtype()
233
+ # Handle the case where the model is quantized
234
+ elif hasattr(self.config, "_pre_quantization_dtype"):
235
+ target_dtype = self.config._pre_quantization_dtype
236
+ else:
237
+ target_dtype = self.q_proj.weight.dtype
238
+
239
+ logger.warning_once(
240
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
241
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
242
+ f" {target_dtype}."
243
+ )
244
+
245
+ query_states = query_states.to(target_dtype)
246
+ key_states = key_states.to(target_dtype)
247
+ value_states = value_states.to(target_dtype)
248
+
249
+ attn_output = _flash_attention_forward(
250
+ query_states,
251
+ key_states,
252
+ value_states,
253
+ attention_mask,
254
+ q_len,
255
+ position_ids=position_ids,
256
+ dropout=dropout_rate,
257
+ sliding_window=getattr(self, "sliding_window", None),
258
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
259
+ is_causal=self.is_causal,
260
+ **kwargs,
261
+ )
262
+
263
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
264
+ attn_output = self.o_proj(attn_output)
265
+
266
+ if not output_attentions:
267
+ attn_weights = None
268
+
269
+ return attn_output, attn_weights, past_key_value
270
+
271
+
272
+ class LlamaSdpaAttention(LlamaAttention):
273
+ """
274
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
275
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
276
+ SDPA API.
277
+ """
278
+
279
+ # Adapted from LlamaAttention.forward
280
+ def forward(
281
+ self,
282
+ hidden_states: torch.Tensor,
283
+ attention_mask: Optional[torch.Tensor] = None,
284
+ position_ids: Optional[torch.LongTensor] = None,
285
+ past_key_value: Optional[Cache] = None,
286
+ output_attentions: bool = False,
287
+ use_cache: bool = False,
288
+ cache_position: Optional[torch.LongTensor] = None,
289
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
290
+ **kwargs,
291
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
292
+ if output_attentions:
293
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
294
+ logger.warning_once(
295
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
296
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
297
+ )
298
+ return super().forward(
299
+ hidden_states=hidden_states,
300
+ attention_mask=attention_mask,
301
+ position_ids=position_ids,
302
+ past_key_value=past_key_value,
303
+ output_attentions=output_attentions,
304
+ use_cache=use_cache,
305
+ cache_position=cache_position,
306
+ position_embeddings=position_embeddings,
307
+ )
308
+
309
+ bsz, q_len, _ = hidden_states.size()
310
+
311
+ query_states = self.q_proj(hidden_states)
312
+ key_states = self.k_proj(hidden_states)
313
+ value_states = self.v_proj(hidden_states)
314
+
315
+ # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
316
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
317
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
318
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
319
+
320
+ query_states = self.q_norm(query_states)
321
+ key_states = self.k_norm(key_states)
322
+
323
+ if position_embeddings is None:
324
+ logger.warning_once(
325
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
326
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
327
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
328
+ "removed and `position_embeddings` will be mandatory."
329
+ )
330
+ cos, sin = self.rotary_emb(value_states, position_ids)
331
+ else:
332
+ cos, sin = position_embeddings
333
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         if past_key_value is not None:
+             # sin and cos are specific to RoPE models; cache_position needed for the static cache
+             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+         key_states = repeat_kv(key_states, self.num_key_value_groups)
+         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+         causal_mask = attention_mask
+         if attention_mask is not None:
+             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+         # Reference: https://github.com/pytorch/pytorch/issues/112577.
+         if query_states.device.type == "cuda" and causal_mask is not None:
+             query_states = query_states.contiguous()
+             key_states = key_states.contiguous()
+             value_states = value_states.contiguous()
+
+         # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+         is_causal = True if causal_mask is None and q_len > 1 else False
+
+         attn_output = torch.nn.functional.scaled_dot_product_attention(
+             query_states,
+             key_states,
+             value_states,
+             attn_mask=causal_mask,
+             dropout_p=self.attention_dropout if self.training else 0.0,
+             is_causal=is_causal,
+         )
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.view(bsz, q_len, -1)
+
+         attn_output = self.o_proj(attn_output)
+
+         return attn_output, None, past_key_value
+
+
+ class LlamaFlexAttention(LlamaAttention):
+     """
+     Llama attention module using torch FlexAttention via `flex_attention_forward`. This module inherits from
+     `LlamaAttention` as the weights of the module stay untouched; only the forward pass is adapted to the
+     FlexAttention API.
+     """
+
+     # Adapted from LlamaAttention.forward
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Cache] = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
+         query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+         dtype = query_states.dtype
+
+         query_states = self.q_norm(query_states).to(dtype)
+         key_states = self.k_norm(key_states).to(dtype)
+
+         if position_embeddings is None:
+             logger.warning_once(
+                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                 "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+                 "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+                 "removed and `position_embeddings` will be mandatory."
+             )
+             cos, sin = self.rotary_emb(value_states, position_ids)
+         else:
+             cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         if past_key_value is not None:
+             # sin and cos are specific to RoPE models; cache_position needed for the static cache
+             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+         attn_output, attn_weight = flex_attention_forward(
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             training=self.training,
+         )
+
+         attn_output = attn_output.view(bsz, q_len, -1)
+         attn_output = self.o_proj(attn_output)
+
+         return attn_output, attn_weight, past_key_value
+
+
+ LLAMA_ATTENTION_CLASSES = {
+     "eager": LlamaAttention,
+     "flash_attention_2": LlamaFlashAttention2,
+     "flex_attention": LlamaFlexAttention,
+     "sdpa": LlamaSdpaAttention,
+ }
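
A minimal sketch of the dispatch pattern behind `LLAMA_ATTENTION_CLASSES`: the attention implementation string (taken from the config, or forced by a flag as `LlamaNARDecoderLayer` does below) selects which class to instantiate. The class and config names in the sketch are illustrative stand-ins, not the repo's real classes.

```python
class EagerAttention:
    def __init__(self, config, layer_idx):
        self.config, self.layer_idx = config, layer_idx

class FlexAttention(EagerAttention):
    pass

ATTENTION_CLASSES = {"eager": EagerAttention, "flex_attention": FlexAttention}

class DummyConfig:
    _attn_implementation = "eager"

config = DummyConfig()
impl = config._attn_implementation
use_flex_attn = True
if use_flex_attn:                     # same override as in LlamaNARDecoderLayer.__init__
    impl = "flex_attention"
attn = ATTENTION_CLASSES[impl](config=config, layer_idx=0)
print(type(attn).__name__)            # -> FlexAttention
```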
diffrhythm2/backbones/llama_nar.py ADDED
@@ -0,0 +1,140 @@
+ # Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from transformers import LlamaConfig
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
+ from .llama_attention import LLAMA_ATTENTION_CLASSES
+
+
+ # sinusoidal positional encoding
+ class SinusoidalPosEmb(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+
+     def forward(self, x):
+         device = x.device
+         half_dim = self.dim // 2
+         emb = math.log(10000) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+         emb = x[:, None] * emb[None, :]
+         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+         return emb
+
+
+ class LlamaAdaptiveRMSNorm(nn.Module):
+     def __init__(self, hidden_size=1024, eps=1e-6, dim_cond=1024):
+         super().__init__()
+         self.to_weight = nn.Linear(dim_cond, hidden_size)
+         nn.init.zeros_(self.to_weight.weight)
+         nn.init.ones_(self.to_weight.bias)
+         self.variance_epsilon = eps
+         self._is_hf_initialized = True  # disable automatic init
+
+     def forward(self, hidden_states, cond_embedding):
+         input_dtype = hidden_states.dtype
+         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+         weight = self.to_weight(cond_embedding)
+         if len(weight.shape) == 2:
+             weight = weight.unsqueeze(1)
+
+         return (weight * hidden_states).to(input_dtype)
+
+
+ class LlamaNARDecoderLayer(LlamaDecoderLayer):
+     def __init__(self, config: LlamaConfig, layer_idx: int, use_flex_attn: bool = False):
+         """Override to swap in the attention implementation selected by `use_flex_attn`."""
+         super().__init__(config, layer_idx)  # init attention, mlp, etc.
+         _attn_implementation = config._attn_implementation
+         if use_flex_attn:
+             _attn_implementation = "flex_attention"
+             # _attn_implementation = "flash_attention_2"
+         self.self_attn = LLAMA_ATTENTION_CLASSES[_attn_implementation](config=config, layer_idx=layer_idx)
+         # self.input_layernorm = LlamaAdaptiveRMSNorm(
+         #     config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+         # )
+         # self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+         #     config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+         # )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+     ) -> Tuple[
+         torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+     ]:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+         """
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_embeddings=position_embeddings,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = [hidden_states]
+
+         if output_attentions:
+             outputs += [self_attn_weights]
+
+         if use_cache:
+             outputs += [present_key_value]
+
+         return outputs
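
To make the embedding above concrete, a short self-contained check of `SinusoidalPosEmb`'s shape contract: a batch of scalar positions `(B,)` maps to `(B, dim)` sin/cos features.

```python
import math
import torch

dim = 8
x = torch.tensor([0.0, 0.5, 1.0])                # (B,) positions or timesteps
half_dim = dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim) * -emb)   # (half_dim,) frequencies
emb = x[:, None] * emb[None, :]                  # (B, half_dim)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)  # (B, dim)
print(emb.shape)  # torch.Size([3, 8])
```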
diffrhythm2/cache_utils.py ADDED
@@ -0,0 +1,154 @@
+ # Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from contextlib import contextmanager
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import torch
+ from transformers.cache_utils import Cache
+
+
+ class BlockFlowMatchingCache(Cache):
+     def __init__(
+         self,
+         text_lengths: Optional[torch.Tensor] = None,
+         block_size: Optional[int] = None,
+         num_history_block: Optional[int] = None,
+     ) -> None:
+         super().__init__()
+         self._seen_tokens = 0
+         self.text_key_cache: List[torch.Tensor] = []
+         self.text_value_cache: List[torch.Tensor] = []
+         self.key_cache: List[torch.Tensor] = []
+         self.value_cache: List[torch.Tensor] = []
+         self.text_lengths = text_lengths
+         self.block_size = block_size
+         self.num_history_block = num_history_block
+         self.is_cache_text = False
+         self.is_storage_cache = False
+         assert (
+             self.num_history_block is None or self.block_size is not None
+         ), "num_history_block and block_size must be set at the same time."
+
+     @contextmanager
+     def cache_text(self):
+         self.is_cache_text = True
+         try:
+             yield self
+         finally:
+             self.is_cache_text = False
+
+     @contextmanager
+     def cache_context(self):
+         self.is_storage_cache = True
+         try:
+             yield self
+         finally:
+             self.is_storage_cache = False
+
+     def update(
+         self,
+         key_states: torch.Tensor,
+         value_states: torch.Tensor,
+         layer_idx: int,
+         cache_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+         Parameters:
+             key_states (`torch.Tensor`):
+                 The new key states to cache.
+             value_states (`torch.Tensor`):
+                 The new value states to cache.
+             layer_idx (`int`):
+                 The index of the layer to cache the states for.
+             cache_kwargs (`Dict[str, Any]`, *optional*):
+                 Additional arguments for the cache subclass; unused here.
+
+         Return:
+             A tuple containing the updated key and value states.
+         """
+         # cache text
+         if self.is_cache_text:
+             if self.text_lengths is None:
+                 self.text_lengths = torch.LongTensor([key_states.shape[-2]] * key_states.shape[0])
+             self.text_key_cache.append(key_states)
+             self.text_value_cache.append(value_states)
+             return self.text_key_cache[layer_idx], self.text_value_cache[layer_idx]
+
+         # Update the number of seen tokens
+         if layer_idx == 0:
+             self._seen_tokens += key_states.shape[-2]
+
+         # Update the cache
+         if key_states is not None:
+             if len(self.key_cache) <= layer_idx:
+                 # There may be skipped layers, fill them with empty lists
+                 for _ in range(len(self.key_cache), layer_idx + 1):
+                     self.key_cache.append([])
+                     self.value_cache.append([])
+             cached_key_state = self.key_cache[layer_idx]
+             cached_value_state = self.value_cache[layer_idx]
+             if len(cached_key_state) != 0:
+                 key_states = torch.cat([cached_key_state, key_states], dim=-2)
+                 value_states = torch.cat([cached_value_state, value_states], dim=-2)
+                 if self.num_history_block is not None:
+                     # keep only the most recent history blocks plus the current block
+                     history_length = self.block_size * (self.num_history_block + 1)
+                     key_states = key_states[:, :, -history_length:, :]
+                     value_states = value_states[:, :, -history_length:, :]
+             if self.is_storage_cache:
+                 self.key_cache[layer_idx] = key_states
+                 self.value_cache[layer_idx] = value_states
+
+         k_s = []
+         v_s = []
+
+         text_key_cache = (
+             self.text_key_cache[layer_idx]
+             if len(self.text_key_cache) > layer_idx
+             else torch.zeros(key_states.shape[0], key_states.shape[1], 0, key_states.shape[3], device=key_states.device, dtype=key_states.dtype)
+         )
+         text_value_cache = (
+             self.text_value_cache[layer_idx]
+             if len(self.text_value_cache) > layer_idx
+             else torch.zeros(value_states.shape[0], value_states.shape[1], 0, value_states.shape[3], device=value_states.device, dtype=value_states.dtype)
+         )
+         for b in range(self.text_lengths.shape[0]):
+             k_s.append(torch.cat([text_key_cache[b][:, :self.text_lengths[b], :], key_states[b]], dim=-2))
+             v_s.append(torch.cat([text_value_cache[b][:, :self.text_lengths[b], :], value_states[b]], dim=-2))
+         k_s = torch.nn.utils.rnn.pad_sequence(k_s, batch_first=True)
+         v_s = torch.nn.utils.rnn.pad_sequence(v_s, batch_first=True)
+
+         return k_s, v_s
+
+     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+         # TODO: deprecate this function in favor of `cache_position`
+         is_empty_layer = (
+             len(self.key_cache) == 0  # no cache in any layer
+             or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+             or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
+         )
+         layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+         return layer_seq_length
+
+     def get_max_cache_shape(self) -> Optional[int]:
+         """Returns the maximum sequence length of the cache object; this dynamic cache has no maximum length."""
+         return None
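
A hedged usage sketch of how `update` routes key/value states depending on the active context manager. Shapes are `[B, num_heads, T, head_dim]`, the tensors are dummies, and the import path assumes the repo layout above.

```python
import torch
from diffrhythm2.cache_utils import BlockFlowMatchingCache

cache = BlockFlowMatchingCache(
    text_lengths=torch.LongTensor([4]), block_size=2, num_history_block=1
)

k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
with cache.cache_text():        # store the text-prompt K/V once per layer
    cache.update(k, v, layer_idx=0)

kb = torch.randn(1, 2, 2, 8)    # one block of latents
vb = torch.randn(1, 2, 2, 8)
with cache.cache_context():     # persist this block's K/V for later blocks
    k_all, v_all = cache.update(kb, vb, layer_idx=0)
print(k_all.shape)              # text (4) + block (2) -> torch.Size([1, 2, 6, 8])
```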
diffrhythm2/cfm.py ADDED
@@ -0,0 +1,425 @@
+ # Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import torch
+ from torch import nn
+ from torchdiffeq import odeint
+ from torch.nn.attention.flex_attention import create_block_mask
+
+ from .backbones.dit import DiT
+ from .cache_utils import BlockFlowMatchingCache
+
+
+ def all_mask(b, h, q_idx, kv_idx):
+     # FlexAttention mask_mod that allows attention everywhere (always True)
+     return q_idx == q_idx
+
+
+ class CFM(nn.Module):
+     def __init__(
+         self,
+         transformer: DiT,
+         sigma=0.0,
+         odeint_kwargs: dict = dict(
+             # atol = 1e-5,
+             # rtol = 1e-5,
+             method="euler"  # 'midpoint'
+             # method="adaptive_heun"
+         ),
+         odeint_options: dict = dict(
+             min_step=0.05
+         ),
+         num_channels=None,
+         block_size=None,
+         num_history_block=None,
+     ):
+         super().__init__()
+
+         self.num_channels = num_channels
+
+         # transformer
+         self.transformer = transformer
+         dim = transformer.dim
+         self.dim = dim
+
+         # conditional flow related
+         self.sigma = sigma
+
+         # sampling related
+         self.odeint_kwargs = odeint_kwargs
+         print(f"ODE SOLVER: {self.odeint_kwargs['method']}")
+
+         self.odeint_options = odeint_options
+         self.block_size = block_size
+         self.num_history_block = num_history_block
+         if self.num_history_block is not None and self.num_history_block <= 0:
+             self.num_history_block = None
+
+         print(f"block_size: {self.block_size}; num_history_block: {self.num_history_block}")
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     @torch.no_grad()
+     def sample_block_cache(
+         self,
+         text,
+         duration,
+         style_prompt,
+         steps=32,
+         cfg_strength=1.0,
+         odeint_method="euler",
+     ):
+         self.eval()
+
+         batch = text.shape[0]
+         device = self.device
+         num_blocks = duration // self.block_size + (duration % self.block_size > 0)
+
+         text_emb = self.transformer.text_embed(text)
+         cfg_text_emb = self.transformer.text_embed(torch.zeros_like(text))
+         text_lens = torch.LongTensor([text_emb.shape[1]]).to(device)
+         clean_emb_stream = torch.zeros(batch, 0, self.num_channels, device=device, dtype=text_emb.dtype)
+         noisy_lens = torch.LongTensor([self.block_size]).to(device)
+         block_iterator = range(num_blocks)
+
+         # create cache
+         kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
+         cfg_kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
+         cache_time = torch.tensor([1], device=device)[:, None].repeat(batch, self.block_size).to(style_prompt.dtype)
+
+         # generate text cache
+         text_time = torch.tensor([-1], device=device)[:, None].repeat(batch, text_emb.shape[1]).to(style_prompt.dtype)
+         text_position_ids = torch.arange(0, text_emb.shape[1], device=device)[None, :].repeat(batch, 1)
+         text_attn_mask = torch.ones(batch, 1, text_emb.shape[1], text_emb.shape[1], device=device).bool()
+         # text_attn_mask = create_block_mask(
+         #     all_mask,
+         #     B=batch,
+         #     H=None,
+         #     Q_LEN=text_emb.shape[1],
+         #     KV_LEN=text_emb.shape[1],
+         # )
+
+         if text_emb.shape[1] != 0:
+             with kv_cache.cache_text():
+                 _, _, kv_cache = self.transformer(
+                     x=text_emb,
+                     time=text_time,
+                     attn_mask=text_attn_mask,
+                     position_ids=text_position_ids,
+                     style_prompt=style_prompt,
+                     use_cache=True,
+                     past_key_value=kv_cache,
+                 )
+             with cfg_kv_cache.cache_text():
+                 _, _, cfg_kv_cache = self.transformer(
+                     x=cfg_text_emb,
+                     time=text_time,
+                     attn_mask=text_attn_mask,
+                     position_ids=text_position_ids,
+                     style_prompt=torch.zeros_like(style_prompt),
+                     use_cache=True,
+                     past_key_value=cfg_kv_cache,
+                 )
+
+         end_pos = 0
+         for bid in block_iterator:
+             clean_lens = torch.LongTensor([clean_emb_stream.shape[1]]).to(device)
+
+             # all-ones mask, [B, 1, Q, KV]
+             attn_mask = torch.ones(batch, 1, noisy_lens.max(), (text_lens + clean_lens + noisy_lens).max(), device=device).bool()
+             # attn_mask = create_block_mask(
+             #     all_mask,
+             #     B=batch,
+             #     H=None,
+             #     Q_LEN=noisy_lens.max(),
+             #     KV_LEN=(text_lens + clean_lens + noisy_lens).max(),
+             # )
+
+             # generate position ids
+             position_ids = torch.arange(0, (clean_lens + noisy_lens).max(), device=device)[None, :].repeat(batch, 1)
+             position_ids = position_ids[:, -noisy_lens.max():]
+
+             # core sample fn
+             def fn(t, x):
+                 noisy_embed = self.transformer.latent_embed(x)
+
+                 if t.ndim == 0:
+                     t = t.repeat(batch)
+                 time = t[:, None].repeat(1, noisy_lens.max())
+
+                 pred, *_ = self.transformer(
+                     x=noisy_embed,
+                     time=time,
+                     attn_mask=attn_mask,
+                     position_ids=position_ids,
+                     style_prompt=style_prompt,
+                     use_cache=True,
+                     past_key_value=kv_cache,
+                 )
+                 if cfg_strength < 1e-5:
+                     return pred
+
+                 null_pred, *_ = self.transformer(
+                     x=noisy_embed,
+                     time=time,
+                     attn_mask=attn_mask,
+                     position_ids=position_ids,
+                     style_prompt=torch.zeros_like(style_prompt),
+                     use_cache=True,
+                     past_key_value=cfg_kv_cache,
+                 )
+
+                 return pred + (pred - null_pred) * cfg_strength
+
+             # generate time grid
+             noisy_emb = torch.randn(batch, self.block_size, self.num_channels, device=device, dtype=style_prompt.dtype)
+             t_start = 0
+             t_set = torch.linspace(t_start, 1, steps, device=device, dtype=noisy_emb.dtype)
+
+             # sampling
+             outputs = odeint(fn, noisy_emb, t_set, method=odeint_method)
+             sampled = outputs[-1]
+
+             # generate next kv cache
+             cache_embed = self.transformer.latent_embed(sampled)
+             with kv_cache.cache_context():
+                 _, _, kv_cache = self.transformer(
+                     x=cache_embed,
+                     time=cache_time,
+                     attn_mask=attn_mask,
+                     position_ids=position_ids,
+                     style_prompt=style_prompt,
+                     use_cache=True,
+                     past_key_value=kv_cache,
+                 )
+             with cfg_kv_cache.cache_context():
+                 _, _, cfg_kv_cache = self.transformer(
+                     x=cache_embed,
+                     time=cache_time,
+                     attn_mask=attn_mask,
+                     position_ids=position_ids,
+                     style_prompt=torch.zeros_like(style_prompt),
+                     use_cache=True,
+                     past_key_value=cfg_kv_cache,
+                 )
+
+             # push new block
+             clean_emb_stream = torch.cat([clean_emb_stream, sampled], dim=1)
+
+             # detect the all-ones EOS frame by MSE distance, then trim trailing EOS frames
+             pos = -1
+             curr_frame = clean_emb_stream[:, pos, :]
+             eos = torch.ones_like(curr_frame)
+             last_kl = torch.nn.functional.mse_loss(curr_frame, eos)
+             if last_kl.abs() <= 0.05:
+                 while last_kl.abs() <= 0.05 and abs(pos) < clean_emb_stream.shape[1]:
+                     pos -= 1
+                     curr_frame = clean_emb_stream[:, pos, :]
+                     last_kl = torch.nn.functional.mse_loss(curr_frame, eos)
+                 end_pos = clean_emb_stream.shape[1] + pos
+                 break
+             else:
+                 end_pos = clean_emb_stream.shape[1]
+
+         clean_emb_stream = clean_emb_stream[:, :end_pos, :]
+
+         return clean_emb_stream
+
+     def sample_cache_stream(
+         self,
+         decoder,
+         text,
+         duration,
+         style_prompt,
+         steps=32,
+         cfg_strength=1.0,
+         seed: int | None = None,
+         chunk_size=10,
+         overlap=2,
+         odeint_method="euler",
+     ):
+         self.eval()
+
+         batch = text.shape[0]
+         device = self.device
+         num_blocks = duration // self.block_size + (duration % self.block_size > 0)
+
+         text_emb = self.transformer.text_embed(text)
+         cfg_text_emb = self.transformer.text_embed(torch.zeros_like(text))
+         text_lens = torch.LongTensor([text_emb.shape[1]]).to(device)
+         clean_emb_stream = torch.zeros(batch, 0, self.num_channels, device=device, dtype=text_emb.dtype)
+         noisy_lens = torch.LongTensor([self.block_size]).to(device)
+         block_iterator = range(num_blocks)
+
+         # create cache
+         kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
+         cfg_kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
+         cache_time = torch.tensor([1], device=device)[:, None].repeat(batch, self.block_size).to(style_prompt.dtype)
+
+         # generate text cache
+         text_time = torch.tensor([-1], device=device)[:, None].repeat(batch, text_emb.shape[1]).to(style_prompt.dtype)
+         text_position_ids = torch.arange(0, text_emb.shape[1], device=device)[None, :].repeat(batch, 1)
+         text_attn_mask = torch.ones(batch, 1, text_emb.shape[1], text_emb.shape[1], device=device).bool()
+
+         if text_emb.shape[1] != 0:
+             with kv_cache.cache_text():
+                 _, _, kv_cache = self.transformer(
+                     x=text_emb,
+                     time=text_time,
+                     attn_mask=text_attn_mask,
+                     position_ids=text_position_ids,
+                     style_prompt=style_prompt,
+                     use_cache=True,
+                     past_key_value=kv_cache,
+                 )
+             with cfg_kv_cache.cache_text():
+                 _, _, cfg_kv_cache = self.transformer(
+                     x=cfg_text_emb,
+                     time=text_time,
+                     attn_mask=text_attn_mask,
+                     position_ids=text_position_ids,
+                     style_prompt=torch.zeros_like(style_prompt),
+                     use_cache=True,
+                     past_key_value=cfg_kv_cache,
+                 )
+
+         end_pos = 0
+         last_decoder_pos = 0
+         decode_audio = []
+         for bid in block_iterator:
+             clean_lens = torch.LongTensor([clean_emb_stream.shape[1]]).to(device)
+
+             # all-ones mask, [B, 1, Q, KV]
+             attn_mask = torch.ones(batch, 1, noisy_lens.max(), (text_lens + clean_lens + noisy_lens).max(), device=device).bool()
+
+             # generate position ids
+             position_ids = torch.arange(0, (clean_lens + noisy_lens).max(), device=device)[None, :].repeat(batch, 1)
+             position_ids = position_ids[:, -noisy_lens.max():]
+
+             # core sample fn
+             def fn(t, x):
+                 noisy_embed = self.transformer.latent_embed(x)
+
+                 if t.ndim == 0:
+                     t = t.repeat(batch)
+                 time = t[:, None].repeat(1, noisy_lens.max())
+
+                 pred, *_ = self.transformer(
+                     x=noisy_embed,
+                     time=time,
+                     attn_mask=attn_mask,
+                     position_ids=position_ids,
+                     style_prompt=style_prompt,
+                     use_cache=True,
+                     past_key_value=kv_cache,
+                 )
+                 if cfg_strength < 1e-5:
+                     return pred
+
+                 null_pred, *_ = self.transformer(
+                     x=noisy_embed,
+                     time=time,
+                     attn_mask=attn_mask,
+                     position_ids=position_ids,
+                     style_prompt=torch.zeros_like(style_prompt),
+                     use_cache=True,
+                     past_key_value=cfg_kv_cache,
+                 )
+
+                 return pred + (pred - null_pred) * cfg_strength
+
+             # generate time grid
+             noisy_emb = torch.randn(batch, self.block_size, self.num_channels, device=device, dtype=style_prompt.dtype)
+             t_start = 0
+             t_set = torch.linspace(t_start, 1, steps, device=device, dtype=noisy_emb.dtype)
+
+             # sampling
+             outputs = odeint(fn, noisy_emb, t_set, method=odeint_method)
+             sampled = outputs[-1]
+
+             # generate next kv cache
+             cache_embed = self.transformer.latent_embed(sampled)
+             with kv_cache.cache_context():
+                 _, _, kv_cache = self.transformer(
+                     x=cache_embed,
+                     time=cache_time,
+                     attn_mask=attn_mask,
+                     position_ids=position_ids,
+                     style_prompt=style_prompt,
+                     use_cache=True,
+                     past_key_value=kv_cache,
+                 )
+             with cfg_kv_cache.cache_context():
+                 _, _, cfg_kv_cache = self.transformer(
+                     x=cache_embed,
+                     time=cache_time,
+                     attn_mask=attn_mask,
+                     position_ids=position_ids,
+                     style_prompt=torch.zeros_like(style_prompt),
+                     use_cache=True,
+                     past_key_value=cfg_kv_cache,
+                 )
+
+             # push new block
+             clean_emb_stream = torch.cat([clean_emb_stream, sampled], dim=1)
+
+             # detect the all-ones EOS frame by MSE distance, then trim trailing EOS frames
+             pos = -1
+             curr_frame = clean_emb_stream[:, pos, :]
+             eos = torch.ones_like(curr_frame)
+             last_kl = torch.nn.functional.mse_loss(curr_frame, eos)
+             if last_kl.abs() <= 0.05:
+                 while last_kl.abs() <= 0.05 and abs(pos) < clean_emb_stream.shape[1]:
+                     pos -= 1
+                     curr_frame = clean_emb_stream[:, pos, :]
+                     last_kl = torch.nn.functional.mse_loss(curr_frame, eos)
+                 end_pos = clean_emb_stream.shape[1] + pos
+                 break
+             else:
+                 end_pos = clean_emb_stream.shape[1]
+             # stream out a chunk once enough new frames have accumulated
+             if end_pos - last_decoder_pos >= chunk_size:
+                 start = max(0, last_decoder_pos - overlap)
+                 overlap_frame = max(0, last_decoder_pos - start)
+                 latent = clean_emb_stream[:, start:end_pos, :]
+                 audio = decoder.decoder(latent.transpose(1, 2))  # [B, C, T]
+                 audio = audio[:, :, overlap_frame * 9600:]
+                 yield audio
+                 last_decoder_pos = end_pos
+
+         # decode whatever remains after the last full chunk
+         clean_emb_stream = clean_emb_stream[:, :end_pos, :]
+         start = max(0, last_decoder_pos - overlap)
+         overlap = max(0, last_decoder_pos - start)
+         latent = clean_emb_stream[:, start:end_pos, :]
+         audio = decoder.decoder(latent.transpose(1, 2))  # [B, C, T]
+         audio = audio[:, :, overlap * 9600:]
+         audio = torch.cat([audio, torch.zeros(audio.shape[0], audio.shape[1], 5, device=audio.device, dtype=audio.dtype)], dim=-1)
+         yield audio
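
The sampling math in both methods reduces to an Euler integration of the learned flow from t=0 (noise) to t=1, with classifier-free guidance mixing a conditional and an unconditional velocity. A toy, self-contained sketch of that update rule (`model` is a stand-in velocity field, not the DiT):

```python
import torch

def model(x, t, cond):
    return cond - x  # toy velocity field pulling x toward `cond`

cfg_strength = 2.0
steps = 32
x = torch.randn(1, 4)       # noise at t=0
cond = torch.ones(1, 4)     # "style" condition
null = torch.zeros(1, 4)    # null condition for the unconditional branch
dt = 1.0 / (steps - 1)
for i in range(steps - 1):
    t = i * dt
    pred = model(x, t, cond)
    null_pred = model(x, t, null)
    v = pred + (pred - null_pred) * cfg_strength  # same CFG rule as in `fn`
    x = x + v * dt                                # one Euler step, like odeint(..., method="euler")
```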
g2p/__init__.py ADDED
File without changes
g2p/g2p/__init__.py ADDED
@@ -0,0 +1,87 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import json
+ import re
+
+ from tokenizers import Tokenizer
+
+ from g2p.g2p import cleaners
+ from g2p.g2p.text_tokenizers import TextTokenizer
+ from g2p.language_segmentation import LangSegment as LS
+
+ LangSegment = LS()
+
+
+ class PhonemeBpeTokenizer:
+     def __init__(self, vacab_path="./f5_tts/g2p/g2p/vocab.json"):
+         self.lang2backend = {
+             "zh": "cmn",
+             "ja": "ja",
+             "en": "en-us",
+             "fr": "fr-fr",
+             "ko": "ko",
+             "de": "de",
+         }
+         self.text_tokenizers = {}
+         self.int_text_tokenizers()
+
+         with open(vacab_path, "r") as f:
+             json_data = f.read()
+         data = json.loads(json_data)
+         self.vocab = data["vocab"]
+         LangSegment.setfilters(["en", "zh", "ja", "ko", "fr", "de"])
+
+     def int_text_tokenizers(self):
+         for key, value in self.lang2backend.items():
+             self.text_tokenizers[key] = TextTokenizer(language=value)
+
+     def tokenize(self, text, sentence, language):
+         # 1. convert text to phonemes
+         phonemes = []
+         if language == "auto":
+             seglist = LangSegment.getTexts(text)
+             tmp_ph = []
+             for seg in seglist:
+                 tmp_ph.append(
+                     self._clean_text(
+                         seg["text"], sentence, seg["lang"], ["cjekfd_cleaners"]
+                     )
+                 )
+             phonemes = "|_|".join(tmp_ph)
+         else:
+             phonemes = self._clean_text(text, sentence, language, ["cjekfd_cleaners"])
+
+         # 2. tokenize phonemes
+         phoneme_tokens = self.phoneme2token(phonemes)
+
+         # 3. decode tokens [optional]
+         # decoded_text = self.tokenizer.decode(phoneme_tokens)
+
+         return phonemes, phoneme_tokens
+
+     def _clean_text(self, text, sentence, language, cleaner_names):
+         for name in cleaner_names:
+             cleaner = getattr(cleaners, name)
+             if not cleaner:
+                 raise Exception("Unknown cleaner: %s" % name)
+             text = cleaner(text, sentence, language, self.text_tokenizers)
+         return text
+
+     def phoneme2token(self, phonemes):
+         tokens = []
+         if isinstance(phonemes, list):
+             for phone in phonemes:
+                 phone = phone.split("\t")[0]
+                 phonemes_split = phone.split("|")
+                 tokens.append(
+                     [self.vocab[p] for p in phonemes_split if p in self.vocab]
+                 )
+         else:
+             phonemes = phonemes.split("\t")[0]
+             phonemes_split = phonemes.split("|")
+             tokens = [self.vocab[p] for p in phonemes_split if p in self.vocab]
+         return tokens
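
A hedged usage sketch of the tokenizer above; it assumes the vocab.json shipped in this repo and a working espeak-ng backend for `TextTokenizer`, so the path below is an assumption, not a guarantee.

```python
from g2p.g2p import PhonemeBpeTokenizer

tokenizer = PhonemeBpeTokenizer(vacab_path="g2p/g2p/vocab.json")
phonemes, tokens = tokenizer.tokenize("hello world", sentence=None, language="en")
print(phonemes)  # "|"-separated IPA phoneme string
print(tokens)    # integer ids looked up in vocab.json
```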
g2p/g2p/chinese_model_g2p.py ADDED
@@ -0,0 +1,213 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+ import json
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader, Dataset
+ from transformers import BertTokenizer
+ from transformers.models.bert.modeling_bert import *
+ from onnxruntime import InferenceSession, GraphOptimizationLevel, SessionOptions
+
+
+ class PolyDataset(Dataset):
+     def __init__(self, words, labels, word_pad_idx=0, label_pad_idx=-1):
+         self.dataset = self.preprocess(words, labels)
+         self.word_pad_idx = word_pad_idx
+         self.label_pad_idx = label_pad_idx
+
+     def preprocess(self, origin_sentences, origin_labels):
+         """
+         Maps tokens and tags to their indices and stores them in the dict data.
+         examples:
+             word: ['[CLS]', '浙', '商', '银', '行', '企', '业', '信', '贷', '部']
+             sentence: ([101, 3851, 1555, 7213, 6121, 821, 689, 928, 6587, 6956],
+                        array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
+             label: [3, 13, 13, 13, 0, 0, 0, 0, 0]
+         """
+         data = []
+         labels = []
+         sentences = []
+         # tokenize
+         for line in origin_sentences:
+             # replace each token by its index;
+             # we can not use encode_plus because our sentences are aligned to labels in list type
+             words = []
+             word_lens = []
+             for token in line:
+                 words.append(token)
+                 word_lens.append(1)
+             token_start_idxs = 1 + np.cumsum([0] + word_lens[:-1])
+             sentences.append(((words, token_start_idxs), 0))
+         for tag in origin_labels:
+             labels.append(tag)
+
+         for sentence, label in zip(sentences, labels):
+             data.append((sentence, label))
+         return data
+
+     def __getitem__(self, idx):
+         """sample data to get batch"""
+         word = self.dataset[idx][0]
+         label = self.dataset[idx][1]
+         return [word, label]
+
+     def __len__(self):
+         """get dataset size"""
+         return len(self.dataset)
+
+     def collate_fn(self, batch):
+         sentences = [x[0][0] for x in batch]
+         ori_sents = [x[0][1] for x in batch]
+         labels = [x[1] for x in batch]
+         batch_len = len(sentences)
+
+         # compute length of longest sentence in batch
+         max_len = max([len(s[0]) for s in sentences])
+         max_label_len = 0
+         batch_data = np.ones((batch_len, max_len))
+         batch_label_starts = []
+
+         # padding and aligning
+         for j in range(batch_len):
+             cur_len = len(sentences[j][0])
+             batch_data[j][:cur_len] = sentences[j][0]
+             label_start_idx = sentences[j][-1]
+             label_starts = np.zeros(max_len)
+             label_starts[[idx for idx in label_start_idx if idx < max_len]] = 1
+             batch_label_starts.append(label_starts)
+             max_label_len = max(int(sum(label_starts)), max_label_len)
+
+         # padding label
+         batch_labels = self.label_pad_idx * np.ones((batch_len, max_label_len))
+         batch_pmasks = self.label_pad_idx * np.ones((batch_len, max_label_len))
+         for j in range(batch_len):
+             cur_tags_len = len(labels[j])
+             batch_labels[j][:cur_tags_len] = labels[j]
+             batch_pmasks[j][:cur_tags_len] = [
+                 1 if item > 0 else 0 for item in labels[j]
+             ]
+
+         # convert data to torch LongTensors
+         batch_data = torch.tensor(batch_data, dtype=torch.long)
+         batch_label_starts = torch.tensor(batch_label_starts, dtype=torch.long)
+         batch_labels = torch.tensor(batch_labels, dtype=torch.long)
+         batch_pmasks = torch.tensor(batch_pmasks, dtype=torch.long)
+         return [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
+
+
+ class BertPolyPredict:
+     def __init__(self, bert_model, jsonr_file, json_file):
+         self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)
+         with open(jsonr_file, "r", encoding="utf8") as fp:
+             self.pron_dict = json.load(fp)
+         with open(json_file, "r", encoding="utf8") as fp:
+             self.pron_dict_id_2_pinyin = json.load(fp)
+         self.num_polyphone = len(self.pron_dict)
+         self.device = "cpu"
+         self.polydataset = PolyDataset
+         options = SessionOptions()  # initialize session options
+         options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+         print(os.path.join(bert_model, "poly_bert_model.onnx"))
+         self.session = InferenceSession(
+             os.path.join(bert_model, "poly_bert_model.onnx"),
+             sess_options=options,
+             providers=[
+                 "CUDAExecutionProvider",
+                 "CPUExecutionProvider",
+             ],
+         )
+         # disable session.run() fallback mechanism: it prevents a reset of the execution provider
+         self.session.disable_fallback()
+
+     def predict_process(self, txt_list):
+         word_test, label_test, texts_test = self.get_examples_po(txt_list)
+         data = self.polydataset(word_test, label_test)
+         predict_loader = DataLoader(
+             data, batch_size=1, shuffle=False, collate_fn=data.collate_fn
+         )
+         pred_tags = self.predict_onnx(predict_loader)
+         return pred_tags
+
+     def predict_onnx(self, dev_loader):
+         pred_tags = []
+         with torch.no_grad():
+             for idx, batch_samples in enumerate(dev_loader):
+                 # [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
+                 batch_data, batch_label_starts, batch_labels, batch_pmasks, _ = (
+                     batch_samples
+                 )
+                 # shift tensors to the target device
+                 batch_data = batch_data.to(self.device)
+                 batch_label_starts = batch_label_starts.to(self.device)
+                 batch_labels = batch_labels.to(self.device)
+                 batch_pmasks = batch_pmasks.to(self.device)
+                 batch_data = np.asarray(batch_data, dtype=np.int32)
+                 batch_pmasks = np.asarray(batch_pmasks, dtype=np.int32)
+                 batch_output = self.session.run(
+                     output_names=["outputs"], input_feed={"input_ids": batch_data}
+                 )[0]
+                 label_masks = batch_pmasks == 1
+                 batch_labels = batch_labels.to("cpu").numpy()
+                 for i, indices in enumerate(np.argmax(batch_output, axis=2)):
+                     for j, idx in enumerate(indices):
+                         if label_masks[i][j]:
+                             pred_tags.append(self.pron_dict_id_2_pinyin[str(idx + 1)])
+         return pred_tags
+
+     def get_examples_po(self, text_list):
+         word_list = []
+         label_list = []
+         sentence_list = []
+         id = 0
+         for line in [text_list]:
+             sentence = line[0]
+             words = []
+             tokens = line[0]
+             index = line[-1]
+             front = index
+             back = len(tokens) - index - 1
+             labels = [0] * front + [1] + [0] * back
+             words = ["[CLS]"] + [item for item in sentence]
+             words = self.tokenizer.convert_tokens_to_ids(words)
+             word_list.append(words)
+             label_list.append(labels)
+             sentence_list.append(sentence)
+
+             id += 1
+             assert len(labels) + 1 == len(
+                 words
+             ), "Number of labels does not match number of words"
+             assert len(labels) == len(
+                 sentence
+             ), "Number of labels does not match number of sentences"
+             assert len(word_list) == len(
+                 label_list
+             ), "Number of label sentences does not match number of word sentences"
+         return word_list, label_list, text_list
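
A hedged sketch of driving `BertPolyPredict`. `get_examples_po` expects `[sentence, index_of_polyphonic_char]`, and the model/dict paths below assume the `g2p/sources/g2p_chinese_model` files in this repo (which polydict file maps ids to pinyin is an assumption here, not confirmed by the source).

```python
from g2p.g2p.chinese_model_g2p import BertPolyPredict

predictor = BertPolyPredict(
    bert_model="g2p/sources/g2p_chinese_model",
    jsonr_file="g2p/sources/g2p_chinese_model/polydict_r.json",
    json_file="g2p/sources/g2p_chinese_model/polydict.json",
)
# disambiguate the polyphone at character position 1 of the sentence
pred = predictor.predict_process(["银行", 1])
print(pred)  # a pinyin candidate for 行 in this context
```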
g2p/g2p/cleaners.py ADDED
@@ -0,0 +1,31 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from g2p.g2p.japanese import japanese_to_ipa
+ from g2p.g2p.mandarin import chinese_to_ipa
+ from g2p.g2p.english import english_to_ipa
+ from g2p.g2p.french import french_to_ipa
+ from g2p.g2p.korean import korean_to_ipa
+ from g2p.g2p.german import german_to_ipa
+
+
+ def cjekfd_cleaners(text, sentence, language, text_tokenizers):
+     if language == "zh":
+         return chinese_to_ipa(text, sentence, text_tokenizers["zh"])
+     elif language == "ja":
+         return japanese_to_ipa(text, text_tokenizers["ja"])
+     elif language == "en":
+         return english_to_ipa(text, text_tokenizers["en"])
+     elif language == "fr":
+         return french_to_ipa(text, text_tokenizers["fr"])
+     elif language == "ko":
+         return korean_to_ipa(text, text_tokenizers["ko"])
+     elif language == "de":
+         return german_to_ipa(text, text_tokenizers["de"])
+     else:
+         raise Exception("Unknown language: %s" % language)
g2p/g2p/english.py ADDED
@@ -0,0 +1,202 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import re
+ from unidecode import unidecode
+ import inflect
+
+ """
+ Text cleaning
+ """
+ _inflect = inflect.engine()
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
+ _percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
+ _fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
+ _number_re = re.compile(r"[0-9]+")
+
+ # List of (regular expression, replacement) pairs for abbreviations:
+ _abbreviations = [
+     (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
+     for x in [
+         ("mrs", "misess"),
+         ("mr", "mister"),
+         ("dr", "doctor"),
+         ("st", "saint"),
+         ("co", "company"),
+         ("jr", "junior"),
+         ("maj", "major"),
+         ("gen", "general"),
+         ("drs", "doctors"),
+         ("rev", "reverend"),
+         ("lt", "lieutenant"),
+         ("hon", "honorable"),
+         ("sgt", "sergeant"),
+         ("capt", "captain"),
+         ("esq", "esquire"),
+         ("ltd", "limited"),
+         ("col", "colonel"),
+         ("ft", "fort"),
+         ("etc", "et cetera"),
+         ("btw", "by the way"),
+     ]
+ ]
+
+ _special_map = [
+     ("t|ɹ", "tɹ"),
+     ("d|ɹ", "dɹ"),
+     ("t|s", "ts"),
+     ("d|z", "dz"),
+     ("ɪ|ɹ", "ɪɹ"),
+     ("ɐ", "ɚ"),
+     ("ᵻ", "ɪ"),
+     ("əl", "l"),
+     ("x", "k"),
+     ("ɬ", "l"),
+     ("ʔ", "t"),
+     ("n̩", "n"),
+     ("oː|ɹ", "oːɹ"),
+ ]
+
+
+ def expand_abbreviations(text):
+     for regex, replacement in _abbreviations:
+         text = re.sub(regex, replacement, text)
+     return text
+
+
+ def _remove_commas(m):
+     return m.group(1).replace(",", "")
+
+
+ def _expand_decimal_point(m):
+     return m.group(1).replace(".", " point ")
+
+
+ def _expand_percent(m):
+     return m.group(1).replace("%", " percent ")
+
+
+ def _expand_dollars(m):
+     match = m.group(1)
+     parts = match.split(".")
+     if len(parts) > 2:
+         return " " + match + " dollars "  # Unexpected format
+     dollars = int(parts[0]) if parts[0] else 0
+     cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+     if dollars and cents:
+         dollar_unit = "dollar" if dollars == 1 else "dollars"
+         cent_unit = "cent" if cents == 1 else "cents"
+         return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
+     elif dollars:
+         dollar_unit = "dollar" if dollars == 1 else "dollars"
+         return " %s %s " % (dollars, dollar_unit)
+     elif cents:
+         cent_unit = "cent" if cents == 1 else "cents"
+         return " %s %s " % (cents, cent_unit)
+     else:
+         return " zero dollars "
+
+
+ def fraction_to_words(numerator, denominator):
+     if numerator == 1 and denominator == 2:
+         return " one half "
+     if numerator == 1 and denominator == 4:
+         return " one quarter "
+     if denominator == 2:
+         return " " + _inflect.number_to_words(numerator) + " halves "
+     if denominator == 4:
+         return " " + _inflect.number_to_words(numerator) + " quarters "
+     return (
+         " "
+         + _inflect.number_to_words(numerator)
+         + " "
+         + _inflect.ordinal(_inflect.number_to_words(denominator))
+         + " "
+     )
+
+
+ def _expand_fraction(m):
+     numerator = int(m.group(1))
+     denominator = int(m.group(2))
+     return fraction_to_words(numerator, denominator)
+
+
+ def _expand_ordinal(m):
+     return " " + _inflect.number_to_words(m.group(0)) + " "
+
+
+ def _expand_number(m):
+     num = int(m.group(0))
+     if num > 1000 and num < 3000:
+         if num == 2000:
+             return " two thousand "
+         elif num > 2000 and num < 2010:
+             return " two thousand " + _inflect.number_to_words(num % 100) + " "
+         elif num % 100 == 0:
+             return " " + _inflect.number_to_words(num // 100) + " hundred "
+         else:
+             return (
+                 " "
+                 + _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(
+                     ", ", " "
+                 )
+                 + " "
+             )
+     else:
+         return " " + _inflect.number_to_words(num, andword="") + " "
+
+
+ # Normalize numbers pronunciation
+ def normalize_numbers(text):
+     text = re.sub(_comma_number_re, _remove_commas, text)
+     text = re.sub(_pounds_re, r"\1 pounds", text)
+     text = re.sub(_dollars_re, _expand_dollars, text)
+     text = re.sub(_fraction_re, _expand_fraction, text)
+     text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+     text = re.sub(_percent_number_re, _expand_percent, text)
+     text = re.sub(_ordinal_re, _expand_ordinal, text)
+     text = re.sub(_number_re, _expand_number, text)
+     return text
+
+
+ def _english_to_ipa(text):
+     # text = unidecode(text).lower()
+     text = expand_abbreviations(text)
+     text = normalize_numbers(text)
+     return text
+
+
+ # special map
+ def special_map(text):
+     for regex, replacement in _special_map:
+         regex = regex.replace("|", r"\|")
+         while re.search(r"(^|[_|]){}([_|]|$)".format(regex), text):
+             text = re.sub(
+                 r"(^|[_|]){}([_|]|$)".format(regex), r"\1{}\2".format(replacement), text
+             )
+     # text = re.sub(r'([,.!?])', r'|\1', text)
+     return text
+
+
+ # Add some special operations
+ def english_to_ipa(text, text_tokenizer):
+     if type(text) == str:
+         text = _english_to_ipa(text)
+     else:
+         text = [_english_to_ipa(t) for t in text]
+     phonemes = text_tokenizer(text)
+     if phonemes[-1] in "p⁼ʰmftnlkxʃs`ɹaoəɛɪeɑʊŋiuɥwæjː":
+         phonemes += "|_"
+     if type(text) == str:
+         return special_map(phonemes)
+     else:
+         result_ph = []
+         for phone in phonemes:
+             result_ph.append(special_map(phone))
+         return result_ph
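
A quick demonstration of the number-normalization pipeline defined above (requires the `inflect` package; exact whitespace in the output varies).

```python
from g2p.g2p.english import normalize_numbers

print(normalize_numbers("I paid $2.50 for 3/4 of it on the 2nd day of 2008."))
# -> roughly: "I paid 2 dollars, 50 cents for three quarters of it
#              on the second day of two thousand eight."
```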
g2p/g2p/french.py ADDED
@@ -0,0 +1,149 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import re
+
+ """
+ Text cleaning
+ """
+ # List of (regular expression, replacement) pairs for abbreviations in French:
+ _abbreviations = [
+     (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+     for x in [
+         ("M", "monsieur"),
+         ("Mlle", "mademoiselle"),
+         ("Mlles", "mesdemoiselles"),
+         ("Mme", "Madame"),
+         ("Mmes", "Mesdames"),
+         ("N.B", "nota bene"),
+         ("p.c.q", "parce que"),
+         ("Pr", "professeur"),
+         ("qqch", "quelque chose"),
+         ("rdv", "rendez-vous"),
+         ("max", "maximum"),
+         ("min", "minimum"),
+         ("no", "numéro"),
+         ("adr", "adresse"),
+         ("dr", "docteur"),
+         ("st", "saint"),
+         ("co", "companie"),
+         ("jr", "junior"),
+         ("sgt", "sergent"),
+         ("capt", "capitain"),
+         ("col", "colonel"),
+         ("av", "avenue"),
+         ("av. J.-C", "avant Jésus-Christ"),
+         ("apr. J.-C", "après Jésus-Christ"),
+         ("art", "article"),
+         ("boul", "boulevard"),
+         ("c.-à-d", "c’est-à-dire"),
+         ("etc", "et cetera"),
+         ("ex", "exemple"),
+         ("excl", "exclusivement"),
+     ]
+ ] + [
+     (re.compile("\\b%s" % x[0]), x[1])
+     for x in [
+         ("Mlle", "mademoiselle"),
+         ("Mlles", "mesdemoiselles"),
+         ("Mme", "Madame"),
+         ("Mmes", "Mesdames"),
+     ]
+ ]
+
+ rep_map = {
+     ":": ",",
+     ";": ",",
+     ",": ",",
+     "。": ".",
+     "!": "!",
+     "?": "?",
+     "\n": ".",
+     "·": ",",
+     "、": ",",
+     "...": ".",
+     "…": ".",
+     "$": ".",
+     "“": "",
+     "”": "",
+     "‘": "",
+     "’": "",
+     "(": "",
+     ")": "",
+     "(": "",
+     ")": "",
+     "《": "",
+     "》": "",
+     "【": "",
+     "】": "",
+     "[": "",
+     "]": "",
+     "—": "",
+     "~": "-",
+     "~": "-",
+     "「": "",
+     "」": "",
+     "¿": "",
+     "¡": "",
+ }
+
+
+ def collapse_whitespace(text):
+     # Regular expression matching whitespace:
+     _whitespace_re = re.compile(r"\s+")
+     return re.sub(_whitespace_re, " ", text).strip()
+
+
+ def remove_punctuation_at_begin(text):
+     return re.sub(r"^[,.!?]+", "", text)
+
+
+ def remove_aux_symbols(text):
+     text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
+     return text
+
+
+ def replace_symbols(text):
+     text = text.replace(";", ",")
+     text = text.replace("-", " ")
+     text = text.replace(":", ",")
+     text = text.replace("&", " et ")
+     return text
+
+
+ def expand_abbreviations(text):
+     for regex, replacement in _abbreviations:
+         text = re.sub(regex, replacement, text)
+     return text
+
+
+ def replace_punctuation(text):
+     pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
+     replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
+     return replaced_text
+
+
+ def text_normalize(text):
+     text = expand_abbreviations(text)
+     text = replace_punctuation(text)
+     text = replace_symbols(text)
+     text = remove_aux_symbols(text)
+     text = remove_punctuation_at_begin(text)
+     text = collapse_whitespace(text)
+     text = re.sub(r"([^\.,!\?\-…])$", r"\1", text)
+     return text
+
+
+ def french_to_ipa(text, text_tokenizer):
+     if type(text) == str:
+         text = text_normalize(text)
+         phonemes = text_tokenizer(text)
+         return phonemes
+     else:
+         for i, t in enumerate(text):
+             text[i] = text_normalize(t)
+         return text_tokenizer(text)
g2p/g2p/german.py ADDED
@@ -0,0 +1,94 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import re
+
+ """
+ Text cleaning
+ """
+ rep_map = {
+     ":": ",",
+     ";": ",",
+     ",": ",",
+     "。": ".",
+     "!": "!",
+     "?": "?",
+     "\n": ".",
+     "·": ",",
+     "、": ",",
+     "...": ".",
+     "…": ".",
+     "$": ".",
+     "“": "",
+     "”": "",
+     "‘": "",
+     "’": "",
+     "(": "",
+     ")": "",
+     "(": "",
+     ")": "",
+     "《": "",
+     "》": "",
+     "【": "",
+     "】": "",
+     "[": "",
+     "]": "",
+     "—": "",
+     "~": "-",
+     "~": "-",
+     "「": "",
+     "」": "",
+     "¿": "",
+     "¡": "",
+ }
+
+
+ def collapse_whitespace(text):
+     # Regular expression matching whitespace:
+     _whitespace_re = re.compile(r"\s+")
+     return re.sub(_whitespace_re, " ", text).strip()
+
+
+ def remove_punctuation_at_begin(text):
+     return re.sub(r"^[,.!?]+", "", text)
+
+
+ def remove_aux_symbols(text):
+     text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
+     return text
+
+
+ def replace_symbols(text):
+     text = text.replace(";", ",")
+     text = text.replace("-", " ")
+     text = text.replace(":", ",")
+     return text
+
+
+ def replace_punctuation(text):
+     pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
+     replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
+     return replaced_text
+
+
+ def text_normalize(text):
+     text = replace_punctuation(text)
+     text = replace_symbols(text)
+     text = remove_aux_symbols(text)
+     text = remove_punctuation_at_begin(text)
+     text = collapse_whitespace(text)
+     text = re.sub(r"([^\.,!\?\-…])$", r"\1", text)
+     return text
+
+
+ def german_to_ipa(text, text_tokenizer):
+     if type(text) == str:
+         text = text_normalize(text)
+         phonemes = text_tokenizer(text)
+         return phonemes
+     else:
+         for i, t in enumerate(text):
+             text[i] = text_normalize(t)
+         return text_tokenizer(text)
g2p/g2p/japanese.py ADDED
@@ -0,0 +1,816 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import io, re, os, sys, time, argparse, pdb, json
+ from io import StringIO
+ from typing import Optional
+ import numpy as np
+ import traceback
+ import pyopenjtalk
+ from pykakasi import kakasi
+
+ punctuation = [",", ".", "!", "?", ":", ";", "'", "…"]
+
+ jp_xphone2ipa = [
+     " a a",
+     " i i",
+     " u ɯ",
+     " e e",
+     " o o",
+     " a: aː",
+     " i: iː",
+     " u: ɯː",
+     " e: eː",
+     " o: oː",
+     " k k",
+     " s s",
+     " t t",
+     " n n",
+     " h ç",
+     " f ɸ",
+     " m m",
+     " y j",
+     " r ɾ",
+     " w ɰᵝ",
+     " N ɴ",
+     " g g",
+     " j d ʑ",
+     " z z",
+     " d d",
+     " b b",
+     " p p",
+     " q q",
+     " v v",
+     " : :",
+     " by b j",
+     " ch t ɕ",
+     " dy d e j",
+     " ty t e j",
+     " gy g j",
+     " gw g ɯ",
+     " hy ç j",
+     " ky k j",
+     " kw k ɯ",
+     " my m j",
+     " ny n j",
+     " py p j",
+     " ry ɾ j",
+     " sh ɕ",
+     " ts t s ɯ",
+ ]
+
+ _mora_list_minimum: list[tuple[str, Optional[str], str]] = [
+     ("ヴォ", "v", "o"),
+     ("ヴェ", "v", "e"),
+     ("ヴィ", "v", "i"),
+     ("ヴァ", "v", "a"),
+     ("ヴ", "v", "u"),
+     ("ン", None, "N"),
+     ("ワ", "w", "a"),
+     ("ロ", "r", "o"),
+     ("レ", "r", "e"),
+     ("ル", "r", "u"),
+     ("リョ", "ry", "o"),
+     ("リュ", "ry", "u"),
+     ("リャ", "ry", "a"),
+     ("リェ", "ry", "e"),
+     ("リ", "r", "i"),
+     ("ラ", "r", "a"),
+     ("ヨ", "y", "o"),
+     ("ユ", "y", "u"),
+     ("ヤ", "y", "a"),
+     ("モ", "m", "o"),
+     ("メ", "m", "e"),
+     ("ム", "m", "u"),
+     ("ミョ", "my", "o"),
+     ("ミュ", "my", "u"),
+     ("ミャ", "my", "a"),
+     ("ミェ", "my", "e"),
+     ("ミ", "m", "i"),
+     ("マ", "m", "a"),
+     ("ポ", "p", "o"),
+     ("ボ", "b", "o"),
+     ("ホ", "h", "o"),
+     ("ペ", "p", "e"),
+     ("ベ", "b", "e"),
+     ("ヘ", "h", "e"),
+     ("プ", "p", "u"),
+     ("ブ", "b", "u"),
+     ("フォ", "f", "o"),
+     ("フェ", "f", "e"),
+     ("フィ", "f", "i"),
+     ("ファ", "f", "a"),
+     ("フ", "f", "u"),
+     ("ピョ", "py", "o"),
+     ("ピュ", "py", "u"),
+     ("ピャ", "py", "a"),
+     ("ピェ", "py", "e"),
+     ("ピ", "p", "i"),
+     ("ビョ", "by", "o"),
+     ("ビュ", "by", "u"),
+     ("ビャ", "by", "a"),
+     ("ビェ", "by", "e"),
+     ("ビ", "b", "i"),
+     ("ヒョ", "hy", "o"),
+     ("ヒュ", "hy", "u"),
+     ("ヒャ", "hy", "a"),
+     ("ヒェ", "hy", "e"),
+     ("ヒ", "h", "i"),
+     ("パ", "p", "a"),
+     ("バ", "b", "a"),
+     ("ハ", "h", "a"),
+     ("ノ", "n", "o"),
+     ("ネ", "n", "e"),
+     ("ヌ", "n", "u"),
+     ("ニョ", "ny", "o"),
+     ("ニュ", "ny", "u"),
+     ("ニャ", "ny", "a"),
+     ("ニェ", "ny", "e"),
+     ("ニ", "n", "i"),
+     ("ナ", "n", "a"),
+     ("ドゥ", "d", "u"),
+     ("ド", "d", "o"),
+     ("トゥ", "t", "u"),
+     ("ト", "t", "o"),
+     ("デョ", "dy", "o"),
+     ("デュ", "dy", "u"),
+     ("デャ", "dy", "a"),
+     # ("デェ", "dy", "e"),
+     ("ディ", "d", "i"),
+     ("デ", "d", "e"),
+     ("テョ", "ty", "o"),
+     ("テュ", "ty", "u"),
+     ("テャ", "ty", "a"),
+     ("ティ", "t", "i"),
+     ("テ", "t", "e"),
+     ("ツォ", "ts", "o"),
+     ("ツェ", "ts", "e"),
+     ("ツィ", "ts", "i"),
+     ("ツァ", "ts", "a"),
+     ("ツ", "ts", "u"),
+     ("ッ", None, "q"),  # changed from "cl" to "q"
+     ("チョ", "ch", "o"),
+     ("チュ", "ch", "u"),
+     ("チャ", "ch", "a"),
+     ("チェ", "ch", "e"),
+     ("チ", "ch", "i"),
+     ("ダ", "d", "a"),
+     ("タ", "t", "a"),
+     ("ゾ", "z", "o"),
+     ("ソ", "s", "o"),
+     ("ゼ", "z", "e"),
+     ("セ", "s", "e"),
+     ("ズィ", "z", "i"),
+     ("ズ", "z", "u"),
+     ("スィ", "s", "i"),
+     ("ス", "s", "u"),
+     ("ジョ", "j", "o"),
+     ("ジュ", "j", "u"),
+     ("ジャ", "j", "a"),
+     ("ジェ", "j", "e"),
+     ("ジ", "j", "i"),
+     ("ショ", "sh", "o"),
+     ("シュ", "sh", "u"),
+     ("シャ", "sh", "a"),
+     ("シェ", "sh", "e"),
+     ("シ", "sh", "i"),
+     ("ザ", "z", "a"),
+     ("サ", "s", "a"),
+     ("ゴ", "g", "o"),
+     ("コ", "k", "o"),
+     ("ゲ", "g", "e"),
+     ("ケ", "k", "e"),
+     ("グヮ", "gw", "a"),
+     ("グ", "g", "u"),
+     ("クヮ", "kw", "a"),
+     ("ク", "k", "u"),
+     ("ギョ", "gy", "o"),
+     ("ギュ", "gy", "u"),
+     ("ギャ", "gy", "a"),
+     ("ギェ", "gy", "e"),
+     ("ギ", "g", "i"),
+     ("キョ", "ky", "o"),
+     ("キュ", "ky", "u"),
+     ("キャ", "ky", "a"),
+     ("キェ", "ky", "e"),
+     ("キ", "k", "i"),
+     ("ガ", "g", "a"),
+     ("カ", "k", "a"),
+     ("オ", None, "o"),
+     ("エ", None, "e"),
+     ("ウォ", "w", "o"),
+     ("ウェ", "w", "e"),
+     ("ウィ", "w", "i"),
+     ("ウ", None, "u"),
+     ("イェ", "y", "e"),
+     ("イ", None, "i"),
+     ("ア", None, "a"),
+ ]
+
+ _mora_list_additional: list[tuple[str, Optional[str], str]] = [
+     ("ヴョ", "by", "o"),
+     ("ヴュ", "by", "u"),
+     ("ヴャ", "by", "a"),
+     ("ヲ", None, "o"),
+     ("ヱ", None, "e"),
+     ("ヰ", None, "i"),
+     ("ヮ", "w", "a"),
+     ("ョ", "y", "o"),
+     ("ュ", "y", "u"),
+     ("ヅ", "z", "u"),
+     ("ヂ", "j", "i"),
+     ("ヶ", "k", "e"),
+     ("ャ", "y", "a"),
+     ("ォ", None, "o"),
+     ("ェ", None, "e"),
+     ("ゥ", None, "u"),
+     ("ィ", None, "i"),
+     ("ァ", None, "a"),
+ ]
+
+ # Example: "vo" -> "ヴォ", "a" -> "ア"
+ mora_phonemes_to_mora_kata: dict[str, str] = {
+     (consonant or "") + vowel: kana for [kana, consonant, vowel] in _mora_list_minimum
+ }
+
+ # Example: "ヴォ" -> ("v", "o"), "ア" -> (None, "a")
+ mora_kata_to_mora_phonemes: dict[str, tuple[Optional[str], str]] = {
+     kana: (consonant, vowel)
+     for [kana, consonant, vowel] in _mora_list_minimum + _mora_list_additional
+ }
+
+
+ # dictionary used during normalization to convert symbols
+ rep_map = {
+     ":": ":",
+     ";": ";",
+     ",": ",",
+     "。": ".",
+     "!": "!",
+     "?": "?",
+     "\n": ".",
+     ".": ".",
+     "⋯": "…",
+     "···": "…",
+     "・・・": "…",
+     "·": ",",
+     "・": ",",
+     "•": ",",
+     "、": ",",
+     "$": ".",
+     # "“": "'",
+     # "”": "'",
+     # '"': "'",
+     "‘": "'",
+     "’": "'",
+     # "(": "'",
+     # ")": "'",
+     # "(": "'",
+     # ")": "'",
+     # "《": "'",
+     # "》": "'",
+     # "【": "'",
+     # "】": "'",
+     # "[": "'",
+     # "]": "'",
+     # "——": "-",
+     # "−": "-",
+     # "-": "-",
+     # "『": "'",
+     # "』": "'",
+     # "〈": "'",
+     # "〉": "'",
+     # "«": "'",
+     # "»": "'",
+     # # "~": "-",  # changed: now treated as the long-vowel mark "ー"
+     # # "~": "-",  # changed: now treated as the long-vowel mark "ー"
+     # "「": "'",
+     # "」": "'",
+ }
+
+
+ def _numeric_feature_by_regex(regex, s):
+     match = re.search(regex, s)
+     if match is None:
+         return -50
+     return int(match.group(1))
+
+
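+ # Illustrative sketch of the helper above on a (hypothetical) fragment of a
+ # full-context label; -50 is the sentinel for "feature absent":
+ # >>> _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", "xx/A:-3+1+4xx")
+ # -3
+ # >>> _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", "no match here")
+ # -50
+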
+ def replace_punctuation(text: str) -> str:
+     """Normalize punctuation to ".", ",", "!", "?", "'", "-" and keep only the
+     characters OpenJTalk can read: kanji, hiragana, katakana, the Latin
+     alphabet, and Greek letters.
+     """
+     pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
+     # print("before: ", text)
+     # replace punctuation via the mapping table
+     replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
+
+     replaced_text = re.sub(
+         # hiragana, katakana, kanji
+         r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
+         # half-width alphabet (upper and lower case)
+         + r"\u0041-\u005A\u0061-\u007A"
+         # full-width alphabet (upper and lower case)
+         + r"\uFF21-\uFF3A\uFF41-\uFF5A"
+         # Greek letters
+         + r"\u0370-\u03FF\u1F00-\u1FFF"
+         # "!", "?", "…", ",", ".", "'", "-" (note: "…" has already been converted to "...")
+         + "".join(punctuation) + r"]+",
+         # delete everything not listed above
+         "",
+         replaced_text,
+     )
+     # print("after: ", replaced_text)
+     return replaced_text
+
+
+ def fix_phone_tone(phone_tone_list: list[tuple[str, int]]) -> list[tuple[str, int]]:
+     """
+     Clamp the tone (accent) values in `phone_tone_list` to the range 0-1.
+     Example: [(a, 0), (i, -1), (u, -1)] -> [(a, 1), (i, 0), (u, 0)]
+     """
+     tone_values = set(tone for _, tone in phone_tone_list)
+     if len(tone_values) == 1:
+         assert tone_values == {0}, tone_values
+         return phone_tone_list
+     elif len(tone_values) == 2:
+         if tone_values == {0, 1}:
+             return phone_tone_list
+         elif tone_values == {-1, 0}:
+             return [
+                 (letter, 0 if tone == -1 else 1) for letter, tone in phone_tone_list
+             ]
+         else:
+             raise ValueError(f"Unexpected tone values: {tone_values}")
+     else:
+         raise ValueError(f"Unexpected tone values: {tone_values}")
+
+
+ def fix_phone_tone_wplen(phone_tone_list, word_phone_length_list):
+     phones = []
+     tones = []
+     w_p_len = []
+     p_len = len(phone_tone_list)
+     idx = 0
+     w_idx = 0
+     while idx < p_len:
+         offset = 0
+         if phone_tone_list[idx] == "▁":
+             w_p_len.append(w_idx + 1)
+
+         curr_w_p_len = word_phone_length_list[w_idx]
+         for i in range(curr_w_p_len):
+             p, t = phone_tone_list[idx]
+             if p == ":" and len(phones) > 0:
+                 if phones[-1][-1] != ":":
+                     phones[-1] += ":"
+                     offset -= 1
+             else:
+                 phones.append(p)
+                 tones.append(str(t))
+             idx += 1
+             if idx >= p_len:
+                 break
+         w_p_len.append(curr_w_p_len + offset)
+         w_idx += 1
+     # print(w_p_len)
+     return phones, tones, w_p_len
+
+
+ def g2phone_tone_wo_punct(prosodies) -> list[tuple[str, int]]:
+     """
+     Return a list of (phoneme, accent) pairs (accent is 0 or 1) for the text.
+     All non-phoneme symbols such as "!", ".", "?" disappear (pause marks are
+     not kept either); they are re-inserted later by `align_tones()`.
+     Note that "っ" becomes "q" (not "cl"), while "ん" stays "N".
+     Example: "こんにちは、世界ー。。元気?!" ->
+     [('k', 0), ('o', 0), ('N', 1), ('n', 1), ('i', 1), ('ch', 1), ('i', 1), ('w', 1), ('a', 1), ('s', 1), ('e', 1), ('k', 0), ('a', 0), ('i', 0), ('i', 0), ('g', 1), ('e', 1), ('N', 0), ('k', 0), ('i', 0)]
+     """
+     result: list[tuple[str, int]] = []
+     current_phrase: list[tuple[str, int]] = []
+     current_tone = 0
+     last_accent = ""
+     for i, letter in enumerate(prosodies):
+         # handling of special symbols
+
+         # sentence-start symbol: ignore
+         if letter == "^":
+             assert i == 0, "Unexpected ^"
+         # symbols that end an accent phrase
+         elif letter in ("$", "?", "_", "#"):
+             # flush the held phrase, clamping accent values to 0-1
+             result.extend(fix_phone_tone(current_phrase))
+             # end-of-sentence symbols: ignore (mid-sentence questions become "_")
+             if letter in ("$", "?"):
+                 assert i == len(prosodies) - 1, f"Unexpected {letter}"
+             # what remains is "_" (pause) and "#" (accent-phrase boundary);
+             # neither is kept, we just prepare for the next accent phrase.
+
+             current_phrase = []
+             # tones rise/fall relative to 0 (negative values are fixed by fix_phone_tone above)
+             current_tone = 0
+             last_accent = ""
+         # accent-rise symbol
+         elif letter == "[":
+             if last_accent != letter:
+                 current_tone = current_tone + 1
+             last_accent = letter
+         # accent-fall symbol
+         elif letter == "]":
+             if last_accent != letter:
+                 current_tone = current_tone - 1
+             last_accent = letter
+         # everything else is an ordinary phoneme
+         else:
+             if letter == "cl":  # handle "っ"
+                 letter = "q"
+             current_phrase.append((letter, current_tone))
+     return result
+
+
+ def handle_long(sep_phonemes: list[list[str]]) -> list[list[str]]:
+     for i in range(len(sep_phonemes)):
+         if sep_phonemes[i][0] == "ー":
+             # sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
+             sep_phonemes[i][0] = ":"
+         if "ー" in sep_phonemes[i]:
+             for j in range(len(sep_phonemes[i])):
+                 if sep_phonemes[i][j] == "ー":
+                     # sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
+                     sep_phonemes[i][j] = ":"
+     return sep_phonemes
+
+
+ def handle_long_word(sep_phonemes: list[list[str]]) -> list[list[str]]:
+     res = []
+     for i in range(len(sep_phonemes)):
+         if sep_phonemes[i][0] == "ー":
+             sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
+             # sep_phonemes[i][0] = ':'
+         if "ー" in sep_phonemes[i]:
+             for j in range(len(sep_phonemes[i])):
+                 if sep_phonemes[i][j] == "ー":
+                     sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
+                     # sep_phonemes[i][j] = ':'
+         res.append(sep_phonemes[i])
+         res.append("▁")
+     return res
+
+
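+ # Illustrative sketch: a mid-word "ー" is replaced by the preceding phoneme,
+ # and a "▁" word separator is appended after every word:
+ # >>> handle_long_word([["s", "o", "ー"], ["k", "a"]])
+ # [['s', 'o', 'o'], '▁', ['k', 'a'], '▁']
+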
+ def align_tones(
+     phones_with_punct: list[str], phone_tone_list: list[tuple[str, int]]
+ ) -> list[tuple[str, int]]:
+     """
+     Example:
+     …私は、、そう思う。
+     phones_with_punct:
+     [".", ".", ".", "w", "a", "t", "a", "sh", "i", "w", "a", ",", ",", "s", "o", "o", "o", "m", "o", "u", "."]
+     phone_tone_list:
+     [("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0)]
+     Return:
+     [(".", 0), (".", 0), (".", 0), ("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), (",", 0), (",", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0), (".", 0)]
+     """
+     result: list[tuple[str, int]] = []
+     tone_index = 0
+     for phone in phones_with_punct:
+         if tone_index >= len(phone_tone_list):
+             # leftover punctuation -> append (punctuation, 0)
+             result.append((phone, 0))
+         elif phone == phone_tone_list[tone_index][0]:
+             # matches the current phoneme in phone_tone_list -> take its tone
+             result.append((phone, phone_tone_list[tone_index][1]))
+             # advance the search index by one
+             tone_index += 1
+         elif phone in punctuation or phone == "▁":
+             # phone is punctuation -> append (phone, 0)
+             result.append((phone, 0))
+         else:
+             print(f"phones: {phones_with_punct}")
+             print(f"phone_tone_list: {phone_tone_list}")
+             print(f"result: {result}")
+             print(f"tone_index: {tone_index}")
+             print(f"phone: {phone}")
+             raise ValueError(f"Unexpected phone: {phone}")
+     return result
+
+
+ def kata2phoneme_list(text: str) -> list[str]:
+     """
+     Take `text` (in principle katakana) and convert it, as-is, into a list of
+     phoneme symbols.
+     Notes:
+     - If `text` is punctuation (possibly a single character), return it as a
+       one-element list without further processing.
+     - A leading run of "ー" is kept as "ー" (it is handled by `handle_long()`).
+     - A "ー" inside the text becomes the last phoneme of the preceding symbol.
+     Examples:
+     `ーーソーナノカーー` -> ["ー", "ー", "s", "o", "o", "n", "a", "n", "o", "k", "a", "a", "a"]
+     `?` -> ["?"]
+     """
+     if text in punctuation:
+         return [text]
+     # check that `text` consists only of katakana (including "ー")
+     if re.fullmatch(r"[\u30A0-\u30FF]+", text) is None:
+         raise ValueError(f"Input must be katakana only: {text}")
+     sorted_keys = sorted(mora_kata_to_mora_phonemes.keys(), key=len, reverse=True)
+     pattern = "|".join(map(re.escape, sorted_keys))
+
+     def mora2phonemes(mora: str) -> str:
+         consonant, vowel = mora_kata_to_mora_phonemes[mora]
+         if consonant is None:
+             return f" {vowel}"
+         return f" {consonant} {vowel}"
+
+     spaced_phonemes = re.sub(pattern, lambda m: mora2phonemes(m.group()), text)
+
+     # handle the long-vowel mark "ー"
+     long_pattern = r"(\w)(ー*)"
+     long_replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2))
+     spaced_phonemes = re.sub(long_pattern, long_replacement, spaced_phonemes)
+     # spaced_phonemes += ' ▁'
+     return spaced_phonemes.strip().split(" ")
+
+
+ def frontend2phoneme(labels, drop_unvoiced_vowels=False):
+     N = len(labels)
+
+     phones = []
+     for n in range(N):
+         lab_curr = labels[n]
+         # print(lab_curr)
+         # current phoneme
+         p3 = re.search(r"\-(.*?)\+", lab_curr).group(1)
+
+         # treat unvoiced vowels as normal vowels
+         if drop_unvoiced_vowels and p3 in "AEIOU":
+             p3 = p3.lower()
+
+         # deal with sil at the beginning and the end of text
+         if p3 == "sil":
+             # assert n == 0 or n == N - 1
+             # if n == 0:
+             #     phones.append("^")
+             # elif n == N - 1:
+             #     # check question form or not
+             #     e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr)
+             #     if e3 == 0:
+             #         phones.append("$")
+             #     elif e3 == 1:
+             #         phones.append("?")
+             continue
+         elif p3 == "pau":
+             phones.append("_")
+             continue
+         else:
+             phones.append(p3)
+
+         # accent type and position info (forward or backward)
+         a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr)
+         a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr)
+         a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr)
+
+         # number of mora in accent phrase
+         f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr)
+
+         a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1])
+         # accent phrase border
+         # print(p3, a1, a2, a3, f1, a2_next, lab_curr)
+         if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl":
+             phones.append("#")
+         # pitch falling
+         elif a1 == 0 and a2_next == a2 + 1 and a2 != f1:
+             phones.append("]")
+         # pitch rising
+         elif a2 == 1 and a2_next == 2:
+             phones.append("[")
+
+     # phones = ' '.join(phones)
+     return phones
+
+
+ class JapanesePhoneConverter(object):
+     def __init__(self, lexicon_path=None, ipa_dict_path=None):
+         # lexicon_lines = open(lexicon_path, 'r', encoding='utf-8').readlines()
+         # self.lexicon = {}
+         # self.single_dict = {}
+         # self.double_dict = {}
+         # for curr_line in lexicon_lines:
+         #     k,v = curr_line.strip().split('+',1)
+         #     self.lexicon[k] = v
+         #     if len(k) == 2:
+         #         self.double_dict[k] = v
+         #     elif len(k) == 1:
+         #         self.single_dict[k] = v
+         self.ipa_dict = {}
+         for curr_line in jp_xphone2ipa:
+             k, v = curr_line.strip().split(" ", 1)
+             self.ipa_dict[k] = re.sub(r"\s", "", v)
+         # kakasi1 = kakasi()
+         # kakasi1.setMode("H","K")
+         # kakasi1.setMode("J","K")
+         # kakasi1.setMode("r","Hepburn")
+         self.japan_JH2K = kakasi()
+         self.table = {ord(f): ord(t) for f, t in zip("67", "_¯")}
+
+     def text2sep_kata(self, parsed) -> tuple[list[str], list[str], list]:
+         """
+         Take `parsed` (OpenJTalk's analysis of text already normalized by
+         `text_normalize`), split it into words, and return a tuple of the word
+         list and the list of readings (katakana or a single symbol) per word.
+         The word split is used by `word2ph` in `g2p()` to decide how many
+         phoneme symbols to assign to each character.
+         Example:
+         `私はそう思う!って感じ?` ->
+         ["私", "は", "そう", "思う", "!", "って", "感じ", "?"], ["ワタシ", "ワ", "ソー", "オモウ", "!", "ッテ", "カンジ", "?"]
+         """
+         # parsed: OpenJTalk's analysis result
+         sep_text: list[str] = []
+         sep_kata: list[str] = []
+         fix_parsed = []
+         i = 0
+         while i <= len(parsed) - 1:
+             # word: the actual word string
+             # yomi: its reading, with the devoicing mark "’" removed
+             # print(parsed)
+             yomi = parsed[i]["pron"]
+             tmp_parsed = parsed[i]
+             if i != len(parsed) - 1 and parsed[i + 1]["string"] in [
+                 "々",
+                 "ゝ",
+                 "ヽ",
+                 "ゞ",
+                 "ヾ",
+                 "゛",
+             ]:
+                 word = parsed[i]["string"] + parsed[i + 1]["string"]
+                 i += 1
+             else:
+                 word = parsed[i]["string"]
+             word, yomi = replace_punctuation(word), yomi.replace("’", "")
+             """
+             The possible values of `yomi` here should be:
+             - `word` is an ordinary word -> its ordinary katakana reading
+               (katakana, possibly including the long-vowel mark, e.g. `アー`)
+             - `word` starts with `ー` -> e.g. `ーラー` or `ーーー`
+             - `word` is punctuation or whitespace -> `、`
+             - `word` is `?` -> `?` (full-width)
+             Unreadable input such as Cyrillic or Arabic script also gives `、`,
+             but normalization should prevent that case. The original code also
+             handled an empty `yomi`, which should not happen either.
+             The only case that needs handling is `yomi == "、"`.
+             """
+             assert yomi != "", f"Empty yomi: {word}"
+             if yomi == "、":
+                 # word has been normalized, so it is one of `.`, `,`, `!`, `'`, `-`
+                 if word not in (
+                     ".",
+                     ",",
+                     "!",
+                     "'",
+                     "-",
+                     "?",
+                     ":",
+                     ";",
+                     "…",
+                     "",
+                 ):
+                     # this happens when pyopenjtalk cannot read the characters
+                     # print(
+                     #     "{}Cannot read:{}, yomi:{}, new_word:{};".format(
+                     #         parsed, word, yomi, self.japan_JH2K.convert(word)[0]["kana"]
+                     #     )
+                     # )
+                     # raise ValueError(word)
+                     word = self.japan_JH2K.convert(word)[0]["kana"]
+                     # print(word, self.japan_JH2K.convert(word)[0]['kana'], kata2phoneme_list(self.japan_JH2K.convert(word)[0]['kana']))
+                     tmp_parsed["pron"] = word
+                 # yomi = "-"
+                 # word = ','
+                 # keep yomi as the original symbol
+                 # else:
+                 #     parsed[i]['pron'] = parsed[i]["string"]
+                 yomi = word
+             elif yomi == "?":
+                 assert word == "?", f"yomi `?` comes from: {word}"
+                 yomi = "?"
+             if word == "":
+                 i += 1
+                 continue
+             sep_text.append(word)
+             sep_kata.append(yomi)
+             # print(word, yomi, parts)
+             fix_parsed.append(tmp_parsed)
+             i += 1
+         # print(sep_text, sep_kata)
+         return sep_text, sep_kata, fix_parsed
+
+     def getSentencePhone(self, sentence, blank_mode=True, phoneme_mode=False):
+         # print("origin:", sentence)
+         words = []
+         words_phone_len = []
+         short_char_flag = False
+         output_duration_flag = []
+         output_before_sil_flag = []
+         normed_text = []
+         sentence = sentence.strip().strip("'")
+         sentence = re.sub(r"\s+", "", sentence)
+         output_res = []
+         failed_words = []
+         last_long_pause = 4
+         last_word = None
+         frontend_text = pyopenjtalk.run_frontend(sentence)
+         # print("frontend_text: ", frontend_text)
+         try:
+             frontend_text = pyopenjtalk.estimate_accent(frontend_text)
+         except Exception:
+             pass
+         # print("estimate_accent: ", frontend_text)
+         # sep_text: list of words
+         # sep_kata: list of katakana readings, one per word
+         sep_text, sep_kata, frontend_text = self.text2sep_kata(frontend_text)
+         # print("sep_text: ", sep_text)
+         # print("sep_kata: ", sep_kata)
+         # print("frontend_text: ", frontend_text)
+         # sep_phonemes: list of phoneme lists, one per word
+         sep_phonemes = handle_long_word([kata2phoneme_list(i) for i in sep_kata])
+         # print("sep_phonemes: ", sep_phonemes)
+
+         pron_text = [x["pron"].strip().replace("’", "") for x in frontend_text]
+         # pdb.set_trace()
+         prosodys = pyopenjtalk.make_label(frontend_text)
+         prosodys = frontend2phoneme(prosodys, drop_unvoiced_vowels=True)
+         # print("prosodys: ", ' '.join(prosodys))
+         # print("pron_text: ", pron_text)
+         normed_text = [x["string"].strip() for x in frontend_text]
+         # list of (phoneme, accent) tuples with all punctuation removed
+         phone_tone_list_wo_punct = g2phone_tone_wo_punct(prosodys)
+         # print("phone_tone_list_wo_punct: ", phone_tone_list_wo_punct)
+
+         # phone_w_punct: sep_phonemes concatenated, with punctuation kept as-is
+         phone_w_punct: list[str] = []
+         w_p_len = []
+         for i in sep_phonemes:
+             phone_w_punct += i
+             w_p_len.append(len(i))
+         phone_w_punct = phone_w_punct[:-1]
+         # build accent info that includes punctuation from the accent info without it
+         # print("phone_w_punct: ", phone_w_punct)
+         # print("phone_tone_list_wo_punct: ", phone_tone_list_wo_punct)
+         phone_tone_list = align_tones(phone_w_punct, phone_tone_list_wo_punct)
+
+         jp_item = {}
+         jp_p = ""
+         jp_t = ""
+         # mye rye pye bye nye
+         # je she
+         # print(phone_tone_list)
+         for p, t in phone_tone_list:
+             if p in self.ipa_dict:
+                 curr_p = self.ipa_dict[p]
+                 jp_p += curr_p
+                 jp_t += str(t + 6) * len(curr_p)
+             elif p in punctuation:
+                 jp_p += p
+                 jp_t += "0"
+             elif p == "▁":
+                 jp_p += p
+                 jp_t += " "
+             else:
+                 print(p, t)
+                 jp_p += "|"
+                 jp_t += "0"
+         # return phones, tones, w_p_len
+         jp_p = jp_p.replace("▁", " ")
+         jp_t = jp_t.translate(self.table)
+         jp_l = ""
+         for t in jp_t:
+             if t == " ":
+                 jp_l += " "
+             else:
+                 jp_l += "2"
+         # print(jp_p)
+         # print(jp_t)
+         # print(jp_l)
+         # print(len(jp_p_len), sum(w_p_len), len(jp_p), sum(jp_p_len))
+         assert len(jp_p) == len(jp_t) and len(jp_p) == len(jp_l)
+
+         jp_item["jp_p"] = jp_p.replace("| |", "|").rstrip("|")
+         jp_item["jp_t"] = jp_t
+         jp_item["jp_l"] = jp_l
+         jp_item["jp_normed_text"] = " ".join(normed_text)
+         jp_item["jp_pron_text"] = " ".join(pron_text)
+         # jp_item['jp_ruoma'] = sep_phonemes
+         # print(len(normed_text), len(sep_phonemes))
+         # print(normed_text)
+         return jp_item
+
+
+ jpc = JapanesePhoneConverter()
+
+
+ def japanese_to_ipa(text, text_tokenizer):
+     # phonemes = text_tokenizer(text)
+     if isinstance(text, str):
+         return jpc.getSentencePhone(text)["jp_p"]
+     else:
+         result_ph = []
+         for t in text:
+             result_ph.append(jpc.getSentencePhone(t)["jp_p"])
+         return result_ph
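+
+ # Minimal usage sketch (illustrative; the exact output depends on the
+ # installed pyopenjtalk dictionary, and `text_tokenizer` is unused here):
+ # >>> japanese_to_ipa("こんにちは", None)
+ # roughly 'koɴnitɕiɰᵝa', following the IPA table at the top of this module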
g2p/g2p/korean.py ADDED
@@ -0,0 +1,81 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import re
+
+ """
+ Text clean time
+ """
+ english_dictionary = {
+     "KOREA": "코리아",
+     "IDOL": "아이돌",
+     "IT": "아이티",
+     "IQ": "아이큐",
+     "UP": "업",
+     "DOWN": "다운",
+     "PC": "피씨",
+     "CCTV": "씨씨티비",
+     "SNS": "에스엔에스",
+     "AI": "에이아이",
+     "CEO": "씨이오",
+     "A": "에이",
+     "B": "비",
+     "C": "씨",
+     "D": "디",
+     "E": "이",
+     "F": "에프",
+     "G": "지",
+     "H": "에이치",
+     "I": "아이",
+     "J": "제이",
+     "K": "케이",
+     "L": "엘",
+     "M": "엠",
+     "N": "엔",
+     "O": "오",
+     "P": "피",
+     "Q": "큐",
+     "R": "알",
+     "S": "에스",
+     "T": "티",
+     "U": "유",
+     "V": "브이",
+     "W": "더블유",
+     "X": "엑스",
+     "Y": "와이",
+     "Z": "제트",
+ }
+
+
+ def normalize(text):
+     text = text.strip()
+     text = re.sub(
+         "[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]", "", text
+     )
+     text = normalize_english(text)
+     text = text.lower()
+     return text
+
+
+ def normalize_english(text):
+     def fn(m):
+         word = m.group()
+         if word in english_dictionary:
+             return english_dictionary.get(word)
+         return word
+
+     text = re.sub("([A-Za-z]+)", fn, text)
+     return text
+
+
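+ # Illustrative sketch of the English-to-Hangul substitution above:
+ # >>> normalize_english("IT 업계의 CEO")
+ # '아이티 업계의 씨이오'
+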
+ def korean_to_ipa(text, text_tokenizer):
+     if isinstance(text, str):
+         text = normalize(text)
+         phonemes = text_tokenizer(text)
+         return phonemes
+     else:
+         for i, t in enumerate(text):
+             text[i] = normalize(t)
+         return text_tokenizer(text)
g2p/g2p/mandarin.py ADDED
@@ -0,0 +1,597 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import re
+ import jieba
+ import cn2an
+ from pypinyin import lazy_pinyin, BOPOMOFO
+ from typing import List
+ from g2p.g2p.chinese_model_g2p import BertPolyPredict
+ from g2p.utils.front_utils import *
+ import os
+
+ # from g2pw import G2PWConverter
+
+
+ # set blank level, {0:"none",1:"char", 2:"word"}
+ BLANK_LEVEL = 0
+
+ # conv = G2PWConverter(style='pinyin', enable_non_tradional_chinese=True)
+ resource_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ poly_all_class_path = os.path.join(
+     resource_path, "sources", "g2p_chinese_model", "polychar.txt"
+ )
+ if not os.path.exists(poly_all_class_path):
+     print(
+         "Incorrect path for polyphonic character class dictionary: {}, please check...".format(
+             poly_all_class_path
+         )
+     )
+     exit()
+ poly_dict = generate_poly_lexicon(poly_all_class_path)
+
+ # Set up G2PW model parameters
+ g2pw_poly_model_path = os.path.join(resource_path, "sources", "g2p_chinese_model")
+ if not os.path.exists(g2pw_poly_model_path):
+     print(
+         "Incorrect path for g2pw polyphonic character model: {}, please check...".format(
+             g2pw_poly_model_path
+         )
+     )
+     exit()
+
+ json_file_path = os.path.join(
+     resource_path, "sources", "g2p_chinese_model", "polydict.json"
+ )
+ if not os.path.exists(json_file_path):
+     print(
+         "Incorrect path for g2pw id to pinyin dictionary: {}, please check...".format(
+             json_file_path
+         )
+     )
+     exit()
+
+ jsonr_file_path = os.path.join(
+     resource_path, "sources", "g2p_chinese_model", "polydict_r.json"
+ )
+ if not os.path.exists(jsonr_file_path):
+     print(
+         "Incorrect path for g2pw pinyin to id dictionary: {}, please check...".format(
+             jsonr_file_path
+         )
+     )
+     exit()
+
+ g2pw_poly_predict = BertPolyPredict(
+     g2pw_poly_model_path, jsonr_file_path, json_file_path
+ )
+
+
+ """
+ Text clean time
+ """
+ # List of (Latin alphabet, bopomofo) pairs:
+ _latin_to_bopomofo = [
+     (re.compile("%s" % x[0], re.IGNORECASE), x[1])
+     for x in [
+         ("a", "ㄟˉ"),
+         ("b", "ㄅㄧˋ"),
+         ("c", "ㄙㄧˉ"),
+         ("d", "ㄉㄧˋ"),
+         ("e", "ㄧˋ"),
+         ("f", "ㄝˊㄈㄨˋ"),
+         ("g", "ㄐㄧˋ"),
+         ("h", "ㄝˇㄑㄩˋ"),
+         ("i", "ㄞˋ"),
+         ("j", "ㄐㄟˋ"),
+         ("k", "ㄎㄟˋ"),
+         ("l", "ㄝˊㄛˋ"),
+         ("m", "ㄝˊㄇㄨˋ"),
+         ("n", "ㄣˉ"),
+         ("o", "ㄡˉ"),
+         ("p", "ㄆㄧˉ"),
+         ("q", "ㄎㄧㄡˉ"),
+         ("r", "ㄚˋ"),
+         ("s", "ㄝˊㄙˋ"),
+         ("t", "ㄊㄧˋ"),
+         ("u", "ㄧㄡˉ"),
+         ("v", "ㄨㄧˉ"),
+         ("w", "ㄉㄚˋㄅㄨˋㄌㄧㄡˋ"),
+         ("x", "ㄝˉㄎㄨˋㄙˋ"),
+         ("y", "ㄨㄞˋ"),
+         ("z", "ㄗㄟˋ"),
+     ]
+ ]
+
+ # List of (bopomofo, ipa) pairs:
+ _bopomofo_to_ipa = [
+     (re.compile("%s" % x[0]), x[1])
+     for x in [
+         ("ㄅㄛ", "p⁼wo"),
+         ("ㄆㄛ", "pʰwo"),
+         ("ㄇㄛ", "mwo"),
+         ("ㄈㄛ", "fwo"),
+         ("ㄧㄢ", "|jɛn"),
+         ("ㄩㄢ", "|ɥæn"),
+         ("ㄧㄣ", "|in"),
+         ("ㄩㄣ", "|ɥn"),
+         ("ㄧㄥ", "|iŋ"),
+         ("ㄨㄥ", "|ʊŋ"),
+         ("ㄩㄥ", "|jʊŋ"),
+         # Add
+         ("ㄧㄚ", "|ia"),
+         ("ㄧㄝ", "|iɛ"),
+         ("ㄧㄠ", "|iɑʊ"),
+         ("ㄧㄡ", "|ioʊ"),
+         ("ㄧㄤ", "|iɑŋ"),
+         ("ㄨㄚ", "|ua"),
+         ("ㄨㄛ", "|uo"),
+         ("ㄨㄞ", "|uaɪ"),
+         ("ㄨㄟ", "|ueɪ"),
+         ("ㄨㄢ", "|uan"),
+         ("ㄨㄣ", "|uən"),
+         ("ㄨㄤ", "|uɑŋ"),
+         ("ㄩㄝ", "|ɥɛ"),
+         # End
+         ("ㄅ", "p⁼"),
+         ("ㄆ", "pʰ"),
+         ("ㄇ", "m"),
+         ("ㄈ", "f"),
+         ("ㄉ", "t⁼"),
+         ("ㄊ", "tʰ"),
+         ("ㄋ", "n"),
+         ("ㄌ", "l"),
+         ("ㄍ", "k⁼"),
+         ("ㄎ", "kʰ"),
+         ("ㄏ", "x"),
+         ("ㄐ", "tʃ⁼"),
+         ("ㄑ", "tʃʰ"),
+         ("ㄒ", "ʃ"),
+         ("ㄓ", "ts`⁼"),
+         ("ㄔ", "ts`ʰ"),
+         ("ㄕ", "s`"),
+         ("ㄖ", "ɹ`"),
+         ("ㄗ", "ts⁼"),
+         ("ㄘ", "tsʰ"),
+         ("ㄙ", "|s"),
+         ("ㄚ", "|a"),
+         ("ㄛ", "|o"),
+         ("ㄜ", "|ə"),
+         ("ㄝ", "|ɛ"),
+         ("ㄞ", "|aɪ"),
+         ("ㄟ", "|eɪ"),
+         ("ㄠ", "|ɑʊ"),
+         ("ㄡ", "|oʊ"),
+         ("ㄢ", "|an"),
+         ("ㄣ", "|ən"),
+         ("ㄤ", "|ɑŋ"),
+         ("ㄥ", "|əŋ"),
+         ("ㄦ", "əɹ"),
+         ("ㄧ", "|i"),
+         ("ㄨ", "|u"),
+         ("ㄩ", "|ɥ"),
+         ("ˉ", "→|"),
+         ("ˊ", "↑|"),
+         ("ˇ", "↓↑|"),
+         ("ˋ", "↓|"),
+         ("˙", "|"),
+     ]
+ ]
+ must_not_er_words = {"女儿", "老儿", "男儿", "少儿", "小儿"}
+
+ word_pinyin_dict = {}
+ with open(
+     os.path.join(resource_path, "sources", "chinese_lexicon.txt"), "r", encoding="utf-8"
+ ) as fread:
+     txt_list = fread.readlines()
+     for txt in txt_list:
+         word, pinyin = txt.strip().split("\t")
+         word_pinyin_dict[word] = pinyin
+     fread.close()
+
+ pinyin_2_bopomofo_dict = {}
+ with open(
+     os.path.join(resource_path, "sources", "pinyin_2_bpmf.txt"), "r", encoding="utf-8"
+ ) as fread:
+     txt_list = fread.readlines()
+     for txt in txt_list:
+         pinyin, bopomofo = txt.strip().split("\t")
+         pinyin_2_bopomofo_dict[pinyin] = bopomofo
+     fread.close()
+
+ tone_dict = {
+     "0": "˙",
+     "5": "˙",
+     "1": "",
+     "2": "ˊ",
+     "3": "ˇ",
+     "4": "ˋ",
+ }
+
+ bopomofos2pinyin_dict = {}
+ with open(
+     os.path.join(resource_path, "sources", "bpmf_2_pinyin.txt"), "r", encoding="utf-8"
+ ) as fread:
+     txt_list = fread.readlines()
+     for txt in txt_list:
+         v, k = txt.strip().split("\t")
+         bopomofos2pinyin_dict[k] = v
+     fread.close()
+
+
+ def bpmf_to_pinyin(text):
+     bopomofo_list = text.split("|")
+     pinyin_list = []
+     for info in bopomofo_list:
+         pinyin = ""
+         for c in info:
+             if c in bopomofos2pinyin_dict:
+                 pinyin += bopomofos2pinyin_dict[c]
+         if len(pinyin) == 0:
+             continue
+         if pinyin[-1] not in "01234":
+             pinyin += "1"
+         if pinyin[:-1] == "ve":
+             pinyin = "y" + pinyin
+         if pinyin[:-1] == "sh":
+             pinyin = pinyin[:-1] + "i" + pinyin[-1]
+         if pinyin == "sh":
+             pinyin = pinyin[:-1] + "i"
+         if pinyin[:-1] == "s":
+             pinyin = "si" + pinyin[-1]
+         if pinyin[:-1] == "c":
+             pinyin = "ci" + pinyin[-1]
+         if pinyin[:-1] == "i":
+             pinyin = "yi" + pinyin[-1]
+         if pinyin[:-1] == "iou":
+             pinyin = "you" + pinyin[-1]
+         if pinyin[:-1] == "ien":
+             pinyin = "yin" + pinyin[-1]
+         if "iou" in pinyin and pinyin[-4:-1] == "iou":
+             pinyin = pinyin[:-4] + "iu" + pinyin[-1]
+         if "uei" in pinyin:
+             if pinyin[:-1] == "uei":
+                 pinyin = "wei" + pinyin[-1]
+             elif pinyin[-4:-1] == "uei":
+                 pinyin = pinyin[:-4] + "ui" + pinyin[-1]
+         if "uen" in pinyin and pinyin[-4:-1] == "uen":
+             if pinyin[:-1] == "uen":
+                 pinyin = "wen" + pinyin[-1]
+             elif pinyin[-4:-1] == "uen":
+                 pinyin = pinyin[:-4] + "un" + pinyin[-1]
+         if "van" in pinyin and pinyin[-4:-1] == "van":
+             if pinyin[:-1] == "van":
+                 pinyin = "yuan" + pinyin[-1]
+             elif pinyin[-4:-1] == "van":
+                 pinyin = pinyin[:-4] + "uan" + pinyin[-1]
+         if "ueng" in pinyin and pinyin[-5:-1] == "ueng":
+             pinyin = pinyin[:-5] + "ong" + pinyin[-1]
+         if pinyin[:-1] == "veng":
+             pinyin = "yong" + pinyin[-1]
+         if "veng" in pinyin and pinyin[-5:-1] == "veng":
+             pinyin = pinyin[:-5] + "iong" + pinyin[-1]
+         if pinyin[:-1] == "ieng":
+             pinyin = "ying" + pinyin[-1]
+         if pinyin[:-1] == "u":
+             pinyin = "wu" + pinyin[-1]
+         if pinyin[:-1] == "v":
+             pinyin = "yv" + pinyin[-1]
+         if pinyin[:-1] == "ing":
+             pinyin = "ying" + pinyin[-1]
+         if pinyin[:-1] == "z":
+             pinyin = "zi" + pinyin[-1]
+         if pinyin[:-1] == "zh":
+             pinyin = "zhi" + pinyin[-1]
+         if pinyin[0] == "u":
+             pinyin = "w" + pinyin[1:]
+         if pinyin[0] == "i":
+             pinyin = "y" + pinyin[1:]
+         pinyin = pinyin.replace("ien", "in")
+
+         pinyin_list.append(pinyin)
+     return " ".join(pinyin_list)
+
+
+ # Convert numbers to Chinese pronunciation
+ def number_to_chinese(text):
+     # numbers = re.findall(r'\d+(?:\.?\d+)?', text)
+     # for number in numbers:
+     #     text = text.replace(number, cn2an.an2cn(number), 1)
+     text = cn2an.transform(text, "an2cn")
+     return text
+
+
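+ # Illustrative sketch (cn2an's "an2cn" mode converts Arabic numerals in
+ # place; the exact wording depends on the cn2an version):
+ # >>> number_to_chinese("我有2个苹果")
+ # '我有二个苹果'
+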
+ def normalization(text):
+     text = text.replace(",", ",")
+     text = text.replace("。", ".")
+     text = text.replace("!", "!")
+     text = text.replace("?", "?")
+     text = text.replace(";", ";")
+     text = text.replace(":", ":")
+     text = text.replace("、", ",")
+     text = text.replace("‘", "'")
+     text = text.replace("’", "'")
+     text = text.replace("⋯", "…")
+     text = text.replace("···", "…")
+     text = text.replace("・・・", "…")
+     text = text.replace("...", "…")
+     text = re.sub(r"\s+", "", text)
+     text = re.sub(r"[^\u4e00-\u9fff\s_,\.\?!;:\'…]", "", text)
+     text = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", text)
+     return text
+
+
+ def change_tone(bopomofo: str, tone: str) -> str:
+     if bopomofo[-1] not in "˙ˊˇˋ":
+         bopomofo = bopomofo + tone
+     else:
+         bopomofo = bopomofo[:-1] + tone
+     return bopomofo
+
+
+ def er_sandhi(word: str, bopomofos: List[str]) -> List[str]:
+     if len(word) > 1 and word[-1] == "儿" and word not in must_not_er_words:
+         bopomofos[-1] = change_tone(bopomofos[-1], "˙")
+     return bopomofos
+
+
+ def bu_sandhi(word: str, bopomofos: List[str]) -> List[str]:
+     valid_char = set(word)
+     if len(valid_char) == 1 and "不" in valid_char:
+         pass
+     elif word in ["不字"]:
+         pass
+     elif len(word) == 3 and word[1] == "不" and bopomofos[1][:-1] == "ㄅㄨ":
+         bopomofos[1] = bopomofos[1][:-1] + "˙"
+     else:
+         for i, char in enumerate(word):
+             if (
+                 i + 1 < len(bopomofos)
+                 and char == "不"
+                 and i + 1 < len(word)
+                 and 0 < len(bopomofos[i + 1])
+                 and bopomofos[i + 1][-1] == "ˋ"
+             ):
+                 bopomofos[i] = bopomofos[i][:-1] + "ˊ"
+     return bopomofos
+
+
+ def yi_sandhi(word: str, bopomofos: List[str]) -> List[str]:
+     punc = ":,;。?!“”‘’':,;.?!()(){}【】[]-~`、 "
+     if word.find("一") != -1 and any(
+         [item.isnumeric() for item in word if item != "一"]
+     ):
+         for i in range(len(word)):
+             if (
+                 i == 0
+                 and word[0] == "一"
+                 and len(word) > 1
+                 and word[1]
+                 not in [
+                     "零",
+                     "一",
+                     "二",
+                     "三",
+                     "四",
+                     "五",
+                     "六",
+                     "七",
+                     "八",
+                     "九",
+                     "十",
+                 ]
+             ):
+                 if len(bopomofos[0]) > 0 and bopomofos[1][-1] in ["ˋ", "˙"]:
+                     bopomofos[0] = change_tone(bopomofos[0], "ˊ")
+                 else:
+                     bopomofos[0] = change_tone(bopomofos[0], "ˋ")
+             elif word[i] == "一":
+                 bopomofos[i] = change_tone(bopomofos[i], "")
+         return bopomofos
+     elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
+         bopomofos[1] = change_tone(bopomofos[1], "˙")
+     elif word.startswith("第一"):
+         bopomofos[1] = change_tone(bopomofos[1], "")
+     elif word.startswith("一月") or word.startswith("一日") or word.startswith("一号"):
+         bopomofos[0] = change_tone(bopomofos[0], "")
+     else:
+         for i, char in enumerate(word):
+             if char == "一" and i + 1 < len(word):
+                 if (
+                     len(bopomofos) > i + 1
+                     and len(bopomofos[i + 1]) > 0
+                     and bopomofos[i + 1][-1] in {"ˋ"}
+                 ):
+                     bopomofos[i] = change_tone(bopomofos[i], "ˊ")
+                 else:
+                     if word[i + 1] not in punc:
+                         bopomofos[i] = change_tone(bopomofos[i], "ˋ")
+                     else:
+                         pass
+     return bopomofos
+
+
+ def merge_bu(seg: List) -> List:
+     new_seg = []
+     last_word = ""
+     for word in seg:
+         if word != "不":
+             if last_word == "不":
+                 word = last_word + word
+             new_seg.append(word)
+         last_word = word
+     return new_seg
+
+
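+ # Illustrative sketch: "不" is merged into the following word so that tone
+ # sandhi can be applied to the pair (a trailing "不" is dropped by design):
+ # >>> merge_bu(["他", "不", "知道"])
+ # ['他', '不知道']
+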
+ def merge_er(seg: List) -> List:
+     new_seg = []
+     for i, word in enumerate(seg):
+         if i - 1 >= 0 and word == "儿":
+             new_seg[-1] = new_seg[-1] + seg[i]
+         else:
+             new_seg.append(word)
+     return new_seg
+
+
+ def merge_yi(seg: List) -> List:
+     new_seg = []
+     # function 1
+     for i, word in enumerate(seg):
+         if (
+             i - 1 >= 0
+             and word == "一"
+             and i + 1 < len(seg)
+             and seg[i - 1] == seg[i + 1]
+         ):
+             if i - 1 < len(new_seg):
+                 new_seg[i - 1] = new_seg[i - 1] + "一" + new_seg[i - 1]
+             else:
+                 new_seg.append(word)
+                 new_seg.append(seg[i + 1])
+         else:
+             if i - 2 >= 0 and seg[i - 1] == "一" and seg[i - 2] == word:
+                 continue
+             else:
+                 new_seg.append(word)
+     seg = new_seg
+     new_seg = []
+     isnumeric_flag = False
+     for i, word in enumerate(seg):
+         if all([item.isnumeric() for item in word]) and not isnumeric_flag:
+             isnumeric_flag = True
+             new_seg.append(word)
+         else:
+             new_seg.append(word)
+     seg = new_seg
+     new_seg = []
+     # function 2
+     for i, word in enumerate(seg):
+         if new_seg and new_seg[-1] == "一":
+             new_seg[-1] = new_seg[-1] + word
+         else:
+             new_seg.append(word)
+     return new_seg
+
+
+ # Word Segmentation, and convert Chinese pronunciation to pinyin (bopomofo)
+ def chinese_to_bopomofo(text_short, sentence):
+     # bopomofos = conv(text_short)
+     words = jieba.lcut(text_short, cut_all=False)
+     words = merge_yi(words)
+     words = merge_bu(words)
+     words = merge_er(words)
+     text = ""
+
+     char_index = 0
+     for word in words:
+         bopomofos = []
+         if word in word_pinyin_dict and word not in poly_dict:
+             pinyin = word_pinyin_dict[word]
+             for py in pinyin.split(" "):
+                 if py[:-1] in pinyin_2_bopomofo_dict and py[-1] in tone_dict:
+                     bopomofos.append(
+                         pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
+                     )
+                     if BLANK_LEVEL == 1:
+                         bopomofos.append("_")
+                 else:
+                     bopomofos_lazy = lazy_pinyin(word, BOPOMOFO)
+                     bopomofos += bopomofos_lazy
+                     if BLANK_LEVEL == 1:
+                         bopomofos.append("_")
+         else:
+             for i in range(len(word)):
+                 c = word[i]
+                 if c in poly_dict:
+                     poly_pinyin = g2pw_poly_predict.predict_process(
+                         [text_short, char_index + i]
+                     )[0]
+                     py = poly_pinyin[2:-1]
+                     bopomofos.append(
+                         pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
+                     )
+                     if BLANK_LEVEL == 1:
+                         bopomofos.append("_")
+                 elif c in word_pinyin_dict:
+                     py = word_pinyin_dict[c]
+                     bopomofos.append(
+                         pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
+                     )
+                     if BLANK_LEVEL == 1:
+                         bopomofos.append("_")
+                 else:
+                     bopomofos.append(c)
+                     if BLANK_LEVEL == 1:
+                         bopomofos.append("_")
+         if BLANK_LEVEL == 2:
+             bopomofos.append("_")
+         char_index += len(word)
+
+         if (
+             len(word) == 3
+             and bopomofos[0][-1] == "ˇ"
+             and bopomofos[1][-1] == "ˇ"
+             and bopomofos[-1][-1] == "ˇ"
+         ):
+             bopomofos[0] = bopomofos[0] + "ˊ"
+             bopomofos[1] = bopomofos[1] + "ˊ"
+         if len(word) == 2 and bopomofos[0][-1] == "ˇ" and bopomofos[-1][-1] == "ˇ":
+             bopomofos[0] = bopomofos[0][:-1] + "ˊ"
+         bopomofos = bu_sandhi(word, bopomofos)
+         bopomofos = yi_sandhi(word, bopomofos)
+         bopomofos = er_sandhi(word, bopomofos)
+         if not re.search("[\u4e00-\u9fff]", word):
+             text += "|" + word
+             continue
+         for i in range(len(bopomofos)):
+             bopomofos[i] = re.sub(r"([\u3105-\u3129])$", r"\1ˉ", bopomofos[i])
+         if text != "":
+             text += "|"
+         text += "|".join(bopomofos)
+     return text
+
+
+ # Convert latin pronunciation to pinyin (bopomofo)
+ def latin_to_bopomofo(text):
+     for regex, replacement in _latin_to_bopomofo:
+         text = re.sub(regex, replacement, text)
+     return text
+
+
+ # Convert pinyin (bopomofo) to IPA
+ def bopomofo_to_ipa(text):
+     for regex, replacement in _bopomofo_to_ipa:
+         text = re.sub(regex, replacement, text)
+     return text
+
+
+ def _chinese_to_ipa(text, sentence):
+     text = re.sub(r"\s", "_", text)
+
+     text = number_to_chinese(text.strip())
+     text = normalization(text)
+     text = chinese_to_bopomofo(text, sentence)
+     # pinyin = bpmf_to_pinyin(text)
+     text = latin_to_bopomofo(text)
+     text = bopomofo_to_ipa(text)
+     text = re.sub("([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
+     text = re.sub("([s][⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
+     text = re.sub(r"^\||[^\w\s_,\.\?!;:\'…\|→↓↑⁼ʰ`]", "", text)
+     text = re.sub(r"([,\.\?!;:\'…])", r"|\1|", text)
+     text = re.sub(r"\|+", "|", text)
+     text = text.rstrip("|")
+     return text
+
+
+ # Convert Chinese to IPA
+ def chinese_to_ipa(text, sentence, text_tokenizer):
+     # phonemes = text_tokenizer(text.strip())
+     if isinstance(text, str):
+         return _chinese_to_ipa(text, sentence)
+     else:
+         result_ph = []
+         for t in text:
+             result_ph.append(_chinese_to_ipa(t, sentence))
+         return result_ph
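+
+ # Minimal usage sketch (illustrative; requires the lexicon and G2PW model
+ # files loaded above, and `text_tokenizer` is unused by this wrapper):
+ # >>> chinese_to_ipa("你好", "你好", None)
+ # approximately 'n|i↑|x|ɑʊ↓↑', assuming the lexicon gives "ni3 hao3"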
g2p/g2p/text_tokenizers.py ADDED
@@ -0,0 +1,84 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import re
+ import os
+ from typing import List, Pattern, Union
+ from phonemizer.utils import list2str, str2list
+ from phonemizer.backend import EspeakBackend
+ from phonemizer.backend.espeak.language_switch import LanguageSwitch
+ from phonemizer.backend.espeak.words_mismatch import WordMismatch
+ from phonemizer.punctuation import Punctuation
+ from phonemizer.separator import Separator
+
+
+ class TextTokenizer:
+     """Phonemize Text."""
+
+     def __init__(
+         self,
+         language="en-us",
+         backend="espeak",
+         separator=Separator(word="|_|", syllable="-", phone="|"),
+         preserve_punctuation=True,
+         with_stress: bool = False,
+         tie: Union[bool, str] = False,
+         language_switch: LanguageSwitch = "remove-flags",
+         words_mismatch: WordMismatch = "ignore",
+     ) -> None:
+         self.preserve_punctuation_marks = ",.?!;:'…"
+         self.backend = EspeakBackend(
+             language,
+             punctuation_marks=self.preserve_punctuation_marks,
+             preserve_punctuation=preserve_punctuation,
+             with_stress=with_stress,
+             tie=tie,
+             language_switch=language_switch,
+             words_mismatch=words_mismatch,
+         )
+
+         self.separator = separator
+
+     # convert chinese punctuation to english punctuation
+     def convert_chinese_punctuation(self, text: str) -> str:
+         text = text.replace(",", ",")
+         text = text.replace("。", ".")
+         text = text.replace("!", "!")
+         text = text.replace("?", "?")
+         text = text.replace(";", ";")
+         text = text.replace(":", ":")
+         text = text.replace("、", ",")
+         text = text.replace("‘", "'")
+         text = text.replace("’", "'")
+         text = text.replace("⋯", "…")
+         text = text.replace("···", "…")
+         text = text.replace("・・・", "…")
+         text = text.replace("...", "…")
+         return text
+
+     def __call__(self, text, strip=True) -> List[str]:
+
+         text_type = type(text)
+         normalized_text = []
+         for line in str2list(text):
+             line = self.convert_chinese_punctuation(line.strip())
+             line = re.sub(r"[^\w\s_,\.\?!;:\'…]", "", line)
+             line = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", line)
+             line = re.sub(r"\s+", " ", line)
+             normalized_text.append(line)
+         # print("Normalized text: ", normalized_text[0])
+         phonemized = self.backend.phonemize(
+             normalized_text, separator=self.separator, strip=strip, njobs=1
+         )
+         if text_type == str:
+             phonemized = re.sub(r"([,\.\?!;:\'…])", r"|\1|", list2str(phonemized))
+             phonemized = re.sub(r"\|+", "|", phonemized)
+             phonemized = phonemized.rstrip("|")
+         else:
+             for i in range(len(phonemized)):
+                 phonemized[i] = re.sub(r"([,\.\?!;:\'…])", r"|\1|", phonemized[i])
+                 phonemized[i] = re.sub(r"\|+", "|", phonemized[i])
+                 phonemized[i] = phonemized[i].rstrip("|")
+         return phonemized
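+
+ # Minimal usage sketch (illustrative; requires espeak-ng to be installed for
+ # the phonemizer backend):
+ # >>> tokenizer = TextTokenizer(language="en-us")
+ # >>> tokenizer("Hello, world.")
+ # a "|"-separated phoneme string (the exact output depends on the espeak version)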
g2p/g2p/vocab.json ADDED
@@ -0,0 +1,372 @@
+ {
+     "vocab": {
+         ",": 0,
+         ".": 1,
+         "?": 2,
+         "!": 3,
+         "_": 4,
+         "iː": 5,
+         "ɪ": 6,
+         "ɜː": 7,
+         "ɚ": 8,
+         "oːɹ": 9,
+         "ɔː": 10,
+         "ɔːɹ": 11,
+         "ɑː": 12,
+         "uː": 13,
+         "ʊ": 14,
+         "ɑːɹ": 15,
+         "ʌ": 16,
+         "ɛ": 17,
+         "æ": 18,
+         "eɪ": 19,
+         "aɪ": 20,
+         "ɔɪ": 21,
+         "aʊ": 22,
+         "oʊ": 23,
+         "ɪɹ": 24,
+         "ɛɹ": 25,
+         "ʊɹ": 26,
+         "p": 27,
+         "b": 28,
+         "t": 29,
+         "d": 30,
+         "k": 31,
+         "ɡ": 32,
+         "f": 33,
+         "v": 34,
+         "θ": 35,
+         "ð": 36,
+         "s": 37,
+         "z": 38,
+         "ʃ": 39,
+         "ʒ": 40,
+         "h": 41,
+         "tʃ": 42,
+         "dʒ": 43,
+         "m": 44,
+         "n": 45,
+         "ŋ": 46,
+         "j": 47,
+         "w": 48,
+         "ɹ": 49,
+         "l": 50,
+         "tɹ": 51,
+         "dɹ": 52,
+         "ts": 53,
+         "dz": 54,
+         "i": 55,
+         "ɔ": 56,
+         "ə": 57,
+         "ɾ": 58,
+         "iə": 59,
+         "r": 60,
+         "u": 61,
+         "oː": 62,
+         "ɛː": 63,
+         "ɪː": 64,
+         "aɪə": 65,
+         "aɪɚ": 66,
+         "ɑ̃": 67,
+         "ç": 68,
+         "ɔ̃": 69,
+         "ææ": 70,
+         "ɐɐ": 71,
+         "ɡʲ": 72,
+         "nʲ": 73,
+         "iːː": 74,
+
+         "p⁼": 75,
+         "pʰ": 76,
+         "t⁼": 77,
+         "tʰ": 78,
+         "k⁼": 79,
+         "kʰ": 80,
+         "x": 81,
+         "tʃ⁼": 82,
+         "tʃʰ": 83,
+         "ts`⁼": 84,
+         "ts`ʰ": 85,
+         "s`": 86,
+         "ɹ`": 87,
+         "ts⁼": 88,
+         "tsʰ": 89,
+         "p⁼wo": 90,
+         "p⁼wo→": 91,
+         "p⁼wo↑": 92,
+         "p⁼wo↓↑": 93,
+         "p⁼wo↓": 94,
+         "pʰwo": 95,
+         "pʰwo→": 96,
+         "pʰwo↑": 97,
+         "pʰwo↓↑": 98,
+         "pʰwo↓": 99,
+         "mwo": 100,
+         "mwo→": 101,
+         "mwo↑": 102,
+         "mwo↓↑": 103,
+         "mwo↓": 104,
+         "fwo": 105,
+         "fwo→": 106,
+         "fwo↑": 107,
+         "fwo↓↑": 108,
+         "fwo↓": 109,
+         "jɛn": 110,
+         "jɛn→": 111,
+         "jɛn↑": 112,
+         "jɛn↓↑": 113,
+         "jɛn↓": 114,
+         "ɥæn": 115,
+         "ɥæn→": 116,
+         "ɥæn↑": 117,
+         "ɥæn↓↑": 118,
+         "ɥæn↓": 119,
+         "in": 120,
+         "in→": 121,
+         "in↑": 122,
+         "in↓↑": 123,
+         "in↓": 124,
+         "ɥn": 125,
+         "ɥn→": 126,
+         "ɥn↑": 127,
+         "ɥn↓↑": 128,
+         "ɥn↓": 129,
+         "iŋ": 130,
+         "iŋ→": 131,
+         "iŋ↑": 132,
+         "iŋ↓↑": 133,
+         "iŋ↓": 134,
+         "ʊŋ": 135,
+         "ʊŋ→": 136,
+         "ʊŋ↑": 137,
+         "ʊŋ↓↑": 138,
+         "ʊŋ↓": 139,
+         "jʊŋ": 140,
+         "jʊŋ→": 141,
+         "jʊŋ↑": 142,
+         "jʊŋ↓↑": 143,
+         "jʊŋ↓": 144,
+         "ia": 145,
+         "ia→": 146,
+         "ia↑": 147,
+         "ia↓↑": 148,
+         "ia↓": 149,
+         "iɛ": 150,
+         "iɛ→": 151,
+         "iɛ↑": 152,
+         "iɛ↓↑": 153,
+         "iɛ↓": 154,
+         "iɑʊ": 155,
+         "iɑʊ→": 156,
+         "iɑʊ↑": 157,
+         "iɑʊ↓↑": 158,
+         "iɑʊ↓": 159,
+         "ioʊ": 160,
+         "ioʊ→": 161,
+         "ioʊ↑": 162,
+         "ioʊ↓↑": 163,
+         "ioʊ↓": 164,
+         "iɑŋ": 165,
+         "iɑŋ→": 166,
+         "iɑŋ↑": 167,
+         "iɑŋ↓↑": 168,
+         "iɑŋ↓": 169,
+         "ua": 170,
+         "ua→": 171,
+         "ua↑": 172,
+         "ua↓↑": 173,
+         "ua↓": 174,
+         "uo": 175,
+         "uo→": 176,
+         "uo↑": 177,
+         "uo↓↑": 178,
+         "uo↓": 179,
+         "uaɪ": 180,
+         "uaɪ→": 181,
+         "uaɪ↑": 182,
+         "uaɪ↓↑": 183,
+         "uaɪ↓": 184,
+         "ueɪ": 185,
+         "ueɪ→": 186,
+         "ueɪ↑": 187,
+         "ueɪ↓↑": 188,
+         "ueɪ↓": 189,
+         "uan": 190,
+         "uan→": 191,
+         "uan↑": 192,
+         "uan↓↑": 193,
+         "uan↓": 194,
+         "uən": 195,
+         "uən→": 196,
+         "uən↑": 197,
+         "uən↓↑": 198,
+         "uən↓": 199,
+         "uɑŋ": 200,
+         "uɑŋ→": 201,
+         "uɑŋ↑": 202,
+         "uɑŋ↓↑": 203,
+         "uɑŋ↓": 204,
+         "ɥɛ": 205,
+         "ɥɛ→": 206,
+         "ɥɛ↑": 207,
+         "ɥɛ↓↑": 208,
+         "ɥɛ↓": 209,
+         "a": 210,
+         "a→": 211,
+         "a↑": 212,
+         "a↓↑": 213,
+         "a↓": 214,
+         "o": 215,
+         "o→": 216,
+         "o↑": 217,
+         "o↓↑": 218,
+         "o↓": 219,
+         "ə→": 220,
+         "ə↑": 221,
+         "ə↓↑": 222,
+         "ə↓": 223,
+         "ɛ→": 224,
+         "ɛ↑": 225,
+         "ɛ↓↑": 226,
+         "ɛ↓": 227,
+         "aɪ→": 228,
+         "aɪ↑": 229,
+         "aɪ↓↑": 230,
+         "aɪ↓": 231,
+         "eɪ→": 232,
+         "eɪ↑": 233,
+         "eɪ↓↑": 234,
+         "eɪ↓": 235,
+         "ɑʊ": 236,
+         "ɑʊ→": 237,
+         "ɑʊ↑": 238,
+         "ɑʊ↓↑": 239,
+         "ɑʊ↓": 240,
+         "oʊ→": 241,
+         "oʊ↑": 242,
+         "oʊ↓↑": 243,
+         "oʊ↓": 244,
+         "an": 245,
+         "an→": 246,
+         "an↑": 247,
+         "an↓↑": 248,
+         "an↓": 249,
+         "ən": 250,
+         "ən→": 251,
+         "ən↑": 252,
+         "ən↓↑": 253,
+         "ən↓": 254,
+         "ɑŋ": 255,
+         "ɑŋ→": 256,
+         "ɑŋ↑": 257,
+         "ɑŋ↓↑": 258,
+         "ɑŋ↓": 259,
+         "əŋ": 260,
+         "əŋ→": 261,
+         "əŋ↑": 262,
+         "əŋ↓↑": 263,
+         "əŋ↓": 264,
+         "əɹ": 265,
+         "əɹ→": 266,
+         "əɹ↑": 267,
+         "əɹ↓↑": 268,
+         "əɹ↓": 269,
+         "i→": 270,
+         "i↑": 271,
+         "i↓↑": 272,
+         "i↓": 273,
+         "u→": 274,
+         "u↑": 275,
+         "u↓↑": 276,
+         "u↓": 277,
+         "ɥ": 278,
+         "ɥ→": 279,
+         "ɥ↑": 280,
+         "ɥ↓↑": 281,
+         "ɥ↓": 282,
+         "ts`⁼ɹ": 283,
+         "ts`⁼ɹ→": 284,
+         "ts`⁼ɹ↑": 285,
+         "ts`⁼ɹ↓↑": 286,
+         "ts`⁼ɹ↓": 287,
+         "ts`ʰɹ": 288,
+         "ts`ʰɹ→": 289,
+         "ts`ʰɹ↑": 290,
+         "ts`ʰɹ↓↑": 291,
+         "ts`ʰɹ↓": 292,
+         "s`ɹ": 293,
+         "s`ɹ→": 294,
+         "s`ɹ↑": 295,
+         "s`ɹ↓↑": 296,
+         "s`ɹ↓": 297,
+         "ɹ`ɹ": 298,
+         "ɹ`ɹ→": 299,
+         "ɹ`ɹ↑": 300,
+         "ɹ`ɹ↓↑": 301,
+         "ɹ`ɹ↓": 302,
+         "ts⁼ɹ": 303,
+         "ts⁼ɹ→": 304,
+         "ts⁼ɹ↑": 305,
+         "ts⁼ɹ↓↑": 306,
+         "ts⁼ɹ↓": 307,
+         "tsʰɹ": 308,
+         "tsʰɹ→": 309,
+         "tsʰɹ↑": 310,
+         "tsʰɹ↓↑": 311,
+         "tsʰɹ↓": 312,
+         "sɹ": 313,
+         "sɹ→": 314,
+         "sɹ↑": 315,
+         "sɹ↓↑": 316,
+         "sɹ↓": 317,
+
+         "ɯ": 318,
+         "e": 319,
+         "aː": 320,
+         "ɯː": 321,
+         "eː": 322,
+         "ç": 323,
+         "ɸ": 324,
+         "ɰᵝ": 325,
+         "ɴ": 326,
+         "g": 327,
+         "dʑ": 328,
+         "q": 329,
+         "ː": 330,
+         "bj": 331,
+         "tɕ": 332,
+         "dej": 333,
+         "tej": 334,
+         "gj": 335,
+         "gɯ": 336,
+         "çj": 337,
+         "kj": 338,
+         "kɯ": 339,
+         "mj": 340,
+         "nj": 341,
+         "pj": 342,
+         "ɾj": 343,
+         "ɕ": 344,
+         "tsɯ": 345,
+
+         "ɐ": 346,
+         "ɑ": 347,
+         "ɒ": 348,
+         "ɜ": 349,
+         "ɫ": 350,
+         "ʑ": 351,
+         "ʲ": 352,
+
+         "y": 353,
+         "ø": 354,
+         "œ": 355,
+         "ʁ": 356,
+         "̃": 357,
+         "ɲ": 358,
+
+         ":": 359,
+         ";": 360,
+         "'": 361,
+         "…": 362
+     }
+ }
g2p/g2p_generation.py ADDED
@@ -0,0 +1,134 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+ import sys
+
+ from g2p.g2p import PhonemeBpeTokenizer
+ from g2p.utils.g2p import phonemizer_g2p
+ import tqdm
+ from typing import List
+ import json
+ import re
+
+
+ def ph_g2p(text, language):
+
+     return phonemizer_g2p(text=text, language=language)
+
+
+ def g2p(text, sentence, language):
+
+     return text_tokenizer.tokenize(text=text, sentence=sentence, language=language)
+
+
+ def is_chinese(char):
+     if char >= "\u4e00" and char <= "\u9fa5":
+         return True
+     else:
+         return False
+
+
+ def is_alphabet(char):
+     if (char >= "\u0041" and char <= "\u005a") or (
+         char >= "\u0061" and char <= "\u007a"
+     ):
+         return True
+     else:
+         return False
+
+
+ def is_other(char):
+     if not (is_chinese(char) or is_alphabet(char)):
+         return True
+     else:
+         return False
+
+
+ def get_segment(text: str) -> List[str]:
+     # sentence --> [ch_part, en_part, ch_part, ...]
+     segments = []
+     types = []
+     flag = 0
+     temp_seg = ""
+     temp_lang = ""
+
+     # Determine the type of each character. type: blank, chinese, alphabet, number, unk and point.
+     for i, ch in enumerate(text):
+         if is_chinese(ch):
+             types.append("zh")
+         elif is_alphabet(ch):
+             types.append("en")
+         else:
+             types.append("other")
+
+     assert len(types) == len(text)
+
+     for i in range(len(types)):
+         # find the first char of the seg
+         if flag == 0:
+             temp_seg += text[i]
+             temp_lang = types[i]
+             flag = 1
+         else:
+             if temp_lang == "other":
+                 if types[i] == temp_lang:
+                     temp_seg += text[i]
+                 else:
+                     temp_seg += text[i]
+                     temp_lang = types[i]
+             else:
+                 if types[i] == temp_lang:
+                     temp_seg += text[i]
+                 elif types[i] == "other":
+                     temp_seg += text[i]
+                 else:
+                     segments.append((temp_seg, temp_lang))
+                     temp_seg = text[i]
+                     temp_lang = types[i]
+                     flag = 1
+
+     segments.append((temp_seg, temp_lang))
+     return segments
+
+
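+ # Illustrative sketch of the language segmentation above:
+ # >>> get_segment("你好hello世界")
+ # [('你好', 'zh'), ('hello', 'en'), ('世界', 'zh')]
+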
+ def chn_eng_g2p(text: str):
99
+ # now only en and ch
100
+ segments = get_segment(text)
101
+ all_phoneme = ""
102
+ all_tokens = []
103
+
104
+ for index in range(len(segments)):
105
+ seg = segments[index]
106
+ phoneme, token = g2p(seg[0], text, seg[1])
107
+ all_phoneme += phoneme + "|"
108
+ all_tokens += token
109
+
110
+ if seg[1] == "en" and index == len(segments) - 1 and all_phoneme[-2] == "_":
111
+ all_phoneme = all_phoneme[:-2]
112
+ all_tokens = all_tokens[:-1]
113
+ return all_phoneme, all_tokens
114
+
115
+
116
+ vocab_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "g2p/vocab.json")
117
+ text_tokenizer = PhonemeBpeTokenizer(vacab_path=vocab_path)
118
+ with open(vocab_path, "r") as f:
119
+ json_data = f.read()
120
+ data = json.loads(json_data)
121
+ vocab = data["vocab"]
122
+
123
+ if __name__ == '__main__':
124
+ phone, token = chn_eng_g2p("你好,hello world")
125
+ phone, token = chn_eng_g2p("你好,hello world, Bonjour, 테스트 해 보겠습니다, 五月雨緑")
126
+ print(phone)
127
+ print(token)
128
+
129
+ #phone, token = text_tokenizer.tokenize("你好,hello world, Bonjour, 테스트 해 보겠습니다, 五月雨緑", "", "auto")
130
+ phone, token = text_tokenizer.tokenize("緑", "", "auto")
131
+ #phone, token = text_tokenizer.tokenize("आइए इसका परीक्षण करें", "", "auto")
132
+ #phone, token = text_tokenizer.tokenize("आइए इसका परीक्षण करें", "", "other")
133
+ print(phone)
134
+ print(token)
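
A minimal usage sketch for the module above (an editor's illustration, not part of the commit), assuming the repository root is on sys.path and the tokenizer assets are in place: get_segment() splits mixed text into language-tagged runs, and chn_eng_g2p() returns the "|"-joined phoneme string plus its token ids.

from g2p.g2p_generation import chn_eng_g2p, get_segment

print(get_segment("他说hello"))  # [('他说', 'zh'), ('hello', 'en')]
phonemes, tokens = chn_eng_g2p("他说hello")
print(phonemes, tokens)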
g2p/language_segmentation/LangSegment.py ADDED
@@ -0,0 +1,865 @@
+ """
+ This file bundles language identification functions.
+
+ Modifications (fork): Copyright (c) 2021, Adrien Barbaresi.
+
+ Original code: Copyright (c) 2011 Marco Lui <[email protected]>.
+ Based on research by Marco Lui and Tim Baldwin.
+
+ See LICENSE file for more info.
+ https://github.com/adbar/py3langid
+
+ Projects:
+ https://github.com/juntaosun/LangSegment
+ """
+
+ import os
+ import re
+ import sys
+ import numpy as np
+ from collections import Counter
+ from collections import defaultdict
+
+ # import langid
+ # import py3langid as langid
+ # pip install py3langid==0.2.2
+
+ # Enable normalization of the language-prediction probabilities so that scores fall in the 0-1 range.
+ # langid disables probability normalization by default; for library use, instantiate your own
+ # LanguageIdentifier, as done below:
+ from py3langid.langid import LanguageIdentifier, MODEL_FILE
+
+ # Digital processing
+ try:from .utils.num import num2str
+ except ImportError:
+     try:from utils.num import num2str
+     except ImportError as e:
+         raise e
+
+ # -----------------------------------
+ # Changelog: the new version of the word segmentation is more accurate.
+ # -----------------------------------
+
+ # Word segmentation function:
+ # automatically identify and split the words (Chinese/English/Japanese/Korean) in an article or
+ # sentence according to language, making the text more suitable for TTS processing.
+ # This code is designed for front-end multilingual mixed-text annotation, multi-language mixed
+ # training, and inference in various TTS projects.
+ # The result mainly targets (Chinese = zh, Japanese = ja, English = en, Korean = ko), and can
+ # actually support mixed processing of up to 97 different languages.
+ # (1) Automatic segmentation: "韩语中的오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型"
+ # (2) Manual segmentation:    "你的名字叫<ja>佐々木?<ja>吗?"
+
+ # ===========================================================================================================
+ # Manual word segmentation tag specification: <language tag>text content</language tag>
+ # Tags need to appear in pairs, e.g. "<ja>佐々木<ja>" or "<ja>佐々木</ja>".
+ # Error demonstration: "Your name is <ja>佐々木。" — a single <ja> tag in a sentence is ignored and not processed.
+ # ===========================================================================================================
+
+ # ===========================================================================================================
+ # Speech Synthesis Markup Language (SSML): only its tags are supported here (not full XML).
+ # Want to support more SSML tags? PRs are welcome!
+ # Note: besides Chinese, this could also be modified to support multi-language SSML.
+ # ===========================================================================================================
+ # Chinese implementation:
+ # [SSML] <number>    = read as Chinese numerals, digit by digit
+ # [SSML] <telephone> = read digits as a Chinese telephone number, digit by digit
+ # [SSML] <currency>  = pronounce as an amount of money
+ # [SSML] <date>      = pronounce as a date; supports inputs such as 2024年08月24, 2024/8/24, 2024-08, 08-24, 24
+ # ===========================================================================================================
+ class LangSSML:
+
+     def __init__(self):
+         # Plain digits
+         self._zh_numerals_number = {
+             '0': '零',
+             '1': '一',
+             '2': '二',
+             '3': '三',
+             '4': '四',
+             '5': '五',
+             '6': '六',
+             '7': '七',
+             '8': '八',
+             '9': '九'
+         }
+
+     # Standardize 2024/8/24, 2024-08, 08-24, 24 to "year-month-day"
+     def _format_chinese_data(self, date_str:str):
+         # Normalize the date format
+         input_date = date_str
+         if date_str is None or date_str.strip() == "":return ""
+         date_str = re.sub(r"[\/\._|年|月]","-",date_str)
+         date_str = re.sub(r"日",r"",date_str)
+         date_arrs = date_str.split(' ')
+         if len(date_arrs) == 1 and ":" in date_arrs[0]:
+             time_str = date_arrs[0]
+             date_arrs = []
+         else:
+             time_str = date_arrs[1] if len(date_arrs) >=2 else ""
+         def nonZero(num,cn,func=None):
+             if func is not None:num=func(num)
+             return f"{num}{cn}" if num is not None and num != "" and num != "0" else ""
+         f_number = self.to_chinese_number
+         f_currency = self.to_chinese_currency
+         # year, month, day
+         year_month_day = ""
+         if len(date_arrs) > 0:
+             year, month, day = "","",""
+             parts = date_arrs[0].split('-')
+             if len(parts) == 3:  # YYYY-MM-DD
+                 year, month, day = parts
+             elif len(parts) == 2:  # MM-DD or YYYY-MM
+                 if len(parts[0]) == 4:  # year-month
+                     year, month = parts
+                 else:month, day = parts  # month-day
+             elif len(parts[0]) > 0:  # only a year, or a single month/day
+                 if len(parts[0]) == 4:
+                     year = parts[0]
+                 else:day = parts[0]
+             year,month,day = nonZero(year,"年",f_number),nonZero(month,"月",f_currency),nonZero(day,"日",f_currency)
+             year_month_day = re.sub(r"([年|月|日])+",r"\1",f"{year}{month}{day}")
+         # hours, minutes, seconds
+         time_str = re.sub(r"[\/\.\-:_]",":",time_str)
+         time_arrs = time_str.split(":")
+         hours, minutes, seconds = "","",""
+         if len(time_arrs) == 3:  # H/M/S
+             hours, minutes, seconds = time_arrs
+         elif len(time_arrs) == 2:  # H/M
+             hours, minutes = time_arrs
+         elif len(time_arrs[0]) > 0:hours = f'{time_arrs[0]}点'  # H
+         if len(time_arrs) > 1:
+             hours, minutes, seconds = nonZero(hours,"点",f_currency),nonZero(minutes,"分",f_currency),nonZero(seconds,"秒",f_currency)
+         hours_minutes_seconds = re.sub(r"([点|分|秒])+",r"\1",f"{hours}{minutes}{seconds}")
+         output_date = f"{year_month_day}{hours_minutes_seconds}"
+         return output_date
+
+     # [SSML] number = read as Chinese numerals, digit by digit
+     def to_chinese_number(self, num:str):
+         pattern = r'(\d+)'
+         zh_numerals = self._zh_numerals_number
+         arrs = re.split(pattern, num)
+         output = ""
+         for item in arrs:
+             if re.match(pattern,item):
+                 output += ''.join(zh_numerals[digit] if digit in zh_numerals else "" for digit in str(item))
+             else:output += item
+         output = output.replace(".","点")
+         return output
+
+     # [SSML] telephone = read digits as a Chinese telephone number, digit by digit
+     def to_chinese_telephone(self, num:str):
+         output = self.to_chinese_number(num.replace("+86",""))  # zh +86
+         output = output.replace("一","幺")
+         return output
+
+     # [SSML] currency = pronounce as an amount of money
+     # Digital processing from GPT_SoVITS num.py (thanks)
+     def to_chinese_currency(self, num:str):
+         pattern = r'(\d+)'
+         arrs = re.split(pattern, num)
+         output = ""
+         for item in arrs:
+             if re.match(pattern,item):
+                 output += num2str(item)
+             else:output += item
+         output = output.replace(".","点")
+         return output
+
+     # [SSML] date = pronounce as a date; supports 2024年08月24, 2024/8/24, 2024-08, 08-24, 24
+     def to_chinese_date(self, num:str):
+         chinese_date = self._format_chinese_data(num)
+         return chinese_date
+
+
+ class LangSegment:
+
+     def __init__(self):
+
+         self.langid = LanguageIdentifier.from_pickled_model(MODEL_FILE, norm_probs=True)
+
+         self._text_cache = None
+         self._text_lasts = None
+         self._text_langs = None
+         self._lang_count = None
+         self._lang_eos = None
+
+         # Customizable language matching tags; all of these forms are supported:
+         # <zh>你好<zh> , <ja>佐々木</ja> , <en>OK<en> , <ko>오빠</ko>
+         self.SYMBOLS_PATTERN = r'(<([a-zA-Z|-]*)>(.*?)<\/*[a-zA-Z|-]*>)'
+
+         # Language filter group: reserved languages can be specified here; languages not in the
+         # filter group are cleared. Combine the languages supported by your TTS as you like.
+         # The earlier a language appears in the list, the higher its priority.
+
+         # System default filter (ISO 639-1 codes):
+         # ----------------------------------------------------------------------------------------------------------------------------------
+         # "zh" Chinese, "en" English, "ja" Japanese, "ko" Korean, "fr" French, "vi" Vietnamese, "ru" Russian, "th" Thai
+         # ----------------------------------------------------------------------------------------------------------------------------------
+         self.DEFAULT_FILTERS = ["zh", "ja", "ko", "en"]
+
+         # User-defined filters
+         self.Langfilters = self.DEFAULT_FILTERS[:]  # make a copy
+
+         # Merge adjacent runs of the same language
+         self.isLangMerge = True
+
+         # Experimental support: you can add "fr" French and "vi" Vietnamese yourself.
+         # Enable via the API: self.setfilters(["zh", "en", "ja", "ko", "fr", "vi", "ru", "th"])
+
+         # Preview feature, automatically enabled or disabled, no settings required
+         self.EnablePreview = False
+
+         # In addition, abbreviated filters are supported: any combination of different languages
+         # works. Examples of the combinations you can specify are listed below.
+
+         # Chinese/Japanese priority threshold (score range 0 ~ 1): when the classifier score falls
+         # below the threshold (<0.89), the priority order in the filters is applied.
+         # Only the characters shared between Chinese and Japanese are handled by this
+         # confidence/priority rule.
+         self.LangPriorityThreshold = 0.89
+
+         # Langfilters = ["zh"]           # Chinese only
+         # Langfilters = ["en"]           # English only
+         # Langfilters = ["ja"]           # Japanese only
+         # Langfilters = ["ko"]           # Korean only
+         # Langfilters = ["zh_ja"]        # mixed Chinese/Japanese
+         # Langfilters = ["zh_en"]        # mixed Chinese/English
+         # Langfilters = ["ja_en"]        # mixed Japanese/English
+         # Langfilters = ["zh_ko"]        # mixed Chinese/Korean
+         # Langfilters = ["ja_ko"]        # mixed Japanese/Korean
+         # Langfilters = ["en_ko"]        # mixed English/Korean
+         # Langfilters = ["zh_ja_en"]     # mixed Chinese/Japanese/English
+         # Langfilters = ["zh_ja_en_ko"]  # mixed Chinese/Japanese/English/Korean
+         # More filter combinations as you please...
+
+         # Optional: keep the digit-pinyin format for Chinese, which makes it easier for the front
+         # end to edit pinyin phonemes for inference. Off (False) by default.
+         # When True, digit-pinyin inside brackets is kept and recognized as "zh" Chinese.
+         self.keepPinyin = False
+
+         # DEFINITION
+         self.PARSE_TAG = re.compile(r'(⑥\$*\d+[\d]{6,}⑥)')
+
+         self.LangSSML = LangSSML()
+
+     def _clears(self):
+         self._text_cache = None
+         self._text_lasts = None
+         self._text_langs = None
+         self._text_waits = None
+         self._lang_count = None
+         self._lang_eos = None
+
+     def _is_english_word(self, word):
+         return bool(re.match(r'^[a-zA-Z]+$', word))
+
+     def _is_chinese(self, word):
+         for char in word:
+             if '\u4e00' <= char <= '\u9fff':
+                 return True
+         return False
+
+     def _is_japanese_kana(self, word):
+         pattern = re.compile(r'[\u3040-\u309F\u30A0-\u30FF]+')
+         matches = pattern.findall(word)
+         return len(matches) > 0
+
+     def _insert_english_uppercase(self, word):
+         modified_text = re.sub(r'(?<!\b)([A-Z])', r' \1', word)
+         modified_text = modified_text.strip('-')
+         return modified_text + " "
+
+     def _split_camel_case(self, word):
+         return re.sub(r'(?<!^)(?=[A-Z])', ' ', word)
+
+     def _statistics(self, language, text):
+         # Language word statistics; Chinese characters usually occupy double bytes.
+         if self._lang_count is None or not isinstance(self._lang_count, defaultdict):
+             self._lang_count = defaultdict(int)
+         lang_count = self._lang_count
+         if not "|" in language:
+             lang_count[language] += int(len(text)*2) if language == "zh" else len(text)
+         self._lang_count = lang_count
+
+     def _clear_text_number(self, text):
+         if text == "\n":return text,False  # Keep Line Breaks
+         clear_text = re.sub(r'([^\w\s]+)','',re.sub(r'\n+','',text)).strip()
+         is_number = len(re.sub(re.compile(r'(\d+)'),'',clear_text)) == 0
+         return clear_text,is_number
+
+     def _saveData(self, words,language:str,text:str,score:float,symbol=None):
+         # Pre-detection
+         clear_text , is_number = self._clear_text_number(text)
+         # Merge the same language and save the results
+         preData = words[-1] if len(words) > 0 else None
+         if symbol is not None:pass
+         elif preData is not None and preData["symbol"] is None:
+             if len(clear_text) == 0:language = preData["lang"]
+             elif is_number == True:language = preData["lang"]
+             _ , pre_is_number = self._clear_text_number(preData["text"])
+             if (preData["lang"] == language):
+                 self._statistics(preData["lang"],text)
+                 text = preData["text"] + text
+                 preData["text"] = text
+                 return preData
+             elif pre_is_number == True:
+                 text = f'{preData["text"]}{text}'
+                 words.pop()
+         elif is_number == True:
+             priority_language = self._get_filters_string()[:2]
+             if priority_language in "ja-zh-en-ko-fr-vi":language = priority_language
+         data = {"lang":language,"text": text,"score":score,"symbol":symbol}
+         filters = self.Langfilters
+         if filters is None or len(filters) == 0 or "?" in language or \
+             language in filters or language in filters[0] or \
+             filters[0] == "*" or filters[0] in "alls-mixs-autos":
+             words.append(data)
+             self._statistics(data["lang"],data["text"])
+         return data
+
+     def _addwords(self, words,language,text,score,symbol=None):
+         if text == "\n":pass  # Keep Line Breaks
+         elif text is None or len(text.strip()) == 0:return True
+         if language is None:language = ""
+         language = language.lower()
+         if language == 'en':text = self._insert_english_uppercase(text)
+         # text = re.sub(r'[(())]', ',' , text)  # Keep it.
+         text_waits = self._text_waits
+         ispre_waits = len(text_waits)>0
+         preResult = text_waits.pop() if ispre_waits else None
+         if preResult is None:preResult = words[-1] if len(words) > 0 else None
+         if preResult and ("|" in preResult["lang"]):
+             pre_lang = preResult["lang"]
+             if language in pre_lang:preResult["lang"] = language = language.split("|")[0]
+             else:preResult["lang"]=pre_lang.split("|")[0]
+             if ispre_waits:preResult = self._saveData(words,preResult["lang"],preResult["text"],preResult["score"],preResult["symbol"])
+         pre_lang = preResult["lang"] if preResult else None
+         if ("|" in language) and (pre_lang and not pre_lang in language and not "…" in language):language = language.split("|")[0]
+         if "|" in language:self._text_waits.append({"lang":language,"text": text,"score":score,"symbol":symbol})
+         else:self._saveData(words,language,text,score,symbol)
+         return False
+
+     def _get_prev_data(self, words):
+         data = words[-1] if words and len(words) > 0 else None
+         if data:return (data["lang"] , data["text"])
+         return (None,"")
+
+     def _match_ending(self, input , index):
+         if input is None or len(input) == 0:return False,None
+         input = re.sub(r'\s+', '', input)
+         if len(input) == 0 or abs(index) > len(input):return False,None
+         ending_pattern = re.compile(r'([「」“”‘’"\'::。.!!?.?])')
+         return ending_pattern.match(input[index]),input[index]
+
+     def _cleans_text(self, cleans_text):
+         cleans_text = re.sub(r'(.*?)([^\w]+)', r'\1 ', cleans_text)
+         cleans_text = re.sub(r'(.)\1+', r'\1', cleans_text)
+         return cleans_text.strip()
+
+     def _mean_processing(self, text:str):
+         if text is None or (text.strip()) == "":return None , 0.0
+         arrs = self._split_camel_case(text).split(" ")
+         langs = []
+         for t in arrs:
+             if len(t.strip()) <= 3:continue
+             language, score = self.langid.classify(t)
+             langs.append({"lang":language})
+         if len(langs) == 0:return None , 0.0
+         return Counter([item['lang'] for item in langs]).most_common(1)[0][0],1.0
+
+     def _lang_classify(self, cleans_text):
+         language, score = self.langid.classify(cleans_text)
+         # fix: on Huggingface the score comes back as np.float32
+         if score is not None and isinstance(score, np.generic) and hasattr(score,"item"):
+             score = score.item()
+         score = round(score , 3)
+         return language, score
+
+     def _get_filters_string(self):
+         filters = self.Langfilters
+         return "-".join(filters).lower().strip() if filters is not None else ""
+
+     def _parse_language(self, words , segment):
+         LANG_JA = "ja"
+         LANG_ZH = "zh"
+         LANG_ZH_JA = f'{LANG_ZH}|{LANG_JA}'
+         LANG_JA_ZH = f'{LANG_JA}|{LANG_ZH}'
+         language = LANG_ZH
+         regex_pattern = re.compile(r'([^\w\s]+)')
+         lines = regex_pattern.split(segment)
+         lines_max = len(lines)
+         LANG_EOS = self._lang_eos
+         for index, text in enumerate(lines):
+             if len(text) == 0:continue
+             EOS = index >= (lines_max - 1)
+             nextId = index + 1
+             nextText = lines[nextId] if not EOS else ""
+             nextPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',nextText)).strip()) == 0
+             textPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',text)).strip()) == 0
+             if not EOS and (textPunc == True or ( len(nextText.strip()) >= 0 and nextPunc == True)):
+                 lines[nextId] = f'{text}{nextText}'
+                 continue
+             number_tags = re.compile(r'(⑥\d{6,}⑥)')
+             cleans_text = re.sub(number_tags, '' ,text)
+             cleans_text = re.sub(r'\d+', '' ,cleans_text)
+             cleans_text = self._cleans_text(cleans_text)
+             # fix: langid's recognition of short sentences is inaccurate, so splice them longer.
+             if not EOS and len(cleans_text) <= 2:
+                 lines[nextId] = f'{text}{nextText}'
+                 continue
+             language,score = self._lang_classify(cleans_text)
+             prev_language , prev_text = self._get_prev_data(words)
+             if language != LANG_ZH and all('\u4e00' <= c <= '\u9fff' for c in re.sub(r'\s','',cleans_text)):language,score = LANG_ZH,1
+             if len(cleans_text) <= 5 and self._is_chinese(cleans_text):
+                 filters_string = self._get_filters_string()
+                 if score < self.LangPriorityThreshold and len(filters_string) > 0:
+                     index_ja , index_zh = filters_string.find(LANG_JA) , filters_string.find(LANG_ZH)
+                     if index_ja != -1 and index_ja < index_zh:language = LANG_JA
+                     elif index_zh != -1 and index_zh < index_ja:language = LANG_ZH
+                 if self._is_japanese_kana(cleans_text):language = LANG_JA
+                 elif len(cleans_text) > 2 and score > 0.90:pass
+                 elif EOS and LANG_EOS:language = LANG_ZH if len(cleans_text) <= 1 else language
+                 else:
+                     LANG_UNKNOWN = LANG_ZH_JA if language == LANG_ZH or (len(cleans_text) <=2 and prev_language == LANG_ZH) else LANG_JA_ZH
+                     match_end,match_char = self._match_ending(text, -1)
+                     referen = prev_language in LANG_UNKNOWN or LANG_UNKNOWN in prev_language if prev_language else False
+                     if match_char in "。.": language = prev_language if referen and len(words) > 0 else language
+                     else:language = f"{LANG_UNKNOWN}|…"
+             text,*_ = re.subn(number_tags , self._restore_number , text )
+             self._addwords(words,language,text,score)
+
+     # ----------------------------------------------------------
+     # [SSML] Chinese number processing.
+     # The default here is Chinese, used to process the SSML Chinese tags; any language could be
+     # supported the same way, for example:
+     # Chinese telephone number: <telephone>1234567</telephone>
+     # Chinese numeric reading:  <number>1234567</number>
+     def _process_symbol_SSML(self, words,data):
+         tag , match = data
+         language = SSML = match[1]
+         text = match[2]
+         score = 1.0
+         if SSML == "telephone":
+             # Chinese - telephone number
+             language = "zh"
+             text = self.LangSSML.to_chinese_telephone(text)
+         elif SSML == "number":
+             # Chinese - numeric reading
+             language = "zh"
+             text = self.LangSSML.to_chinese_number(text)
+         elif SSML == "currency":
+             # Chinese - read as an amount of money
+             language = "zh"
+             text = self.LangSSML.to_chinese_currency(text)
+         elif SSML == "date":
+             # Chinese - read as a date
+             language = "zh"
+             text = self.LangSSML.to_chinese_date(text)
+         self._addwords(words,language,text,score,SSML)
+
+     # ----------------------------------------------------------
+     def _restore_number(self, matche):
+         value = matche.group(0)
+         text_cache = self._text_cache
+         if value in text_cache:
+             process , data = text_cache[value]
+             tag , match = data
+             value = match
+         return value
+
+     def _pattern_symbols(self, item , text):
+         if text is None:return text
+         tag , pattern , process = item
+         matches = pattern.findall(text)
+         if len(matches) == 1 and "".join(matches[0]) == text:
+             return text
+         for i , match in enumerate(matches):
+             key = f"⑥{tag}{i:06d}⑥"
+             text = re.sub(pattern , key , text , count=1)
+             self._text_cache[key] = (process , (tag , match))
+         return text
+
+     def _process_symbol(self, words,data):
+         tag , match = data
+         language = match[1]
+         text = match[2]
+         score = 1.0
+         filters = self._get_filters_string()
+         if language not in filters:
+             self._process_symbol_SSML(words,data)
+         else:
+             self._addwords(words,language,text,score,True)
+
+     def _process_english(self, words,data):
+         tag , match = data
+         text = match[0]
+         filters = self._get_filters_string()
+         priority_language = filters[:2]
+         # Preview feature: segmentation for other languages
+         enablePreview = self.EnablePreview
+         if enablePreview == True:
+             # Experimental: other language support
+             regex_pattern = re.compile(r'(.*?[。.??!!]+[\n]{,1})')
+             lines = regex_pattern.split(text)
+             for index , text in enumerate(lines):
+                 if len(text.strip()) == 0:continue
+                 cleans_text = self._cleans_text(text)
+                 language,score = self._lang_classify(cleans_text)
+                 if language not in filters:
+                     language,score = self._mean_processing(cleans_text)
+                 if language is None or score <= 0.0:continue
+                 elif language in filters:pass  # pass
+                 elif score >= 0.95:continue  # High score, but not in the filter: excluded.
+                 elif score <= 0.15 and filters[:2] == "fr":language = priority_language
+                 else:language = "en"
+                 self._addwords(words,language,text,score)
+         else:
+             # Default is English
+             language, score = "en", 1.0
+             self._addwords(words,language,text,score)
+
+     def _process_Russian(self, words,data):
+         tag , match = data
+         text = match[0]
+         language = "ru"
+         score = 1.0
+         self._addwords(words,language,text,score)
+
+     def _process_Thai(self, words,data):
+         tag , match = data
+         text = match[0]
+         language = "th"
+         score = 1.0
+         self._addwords(words,language,text,score)
+
+     def _process_korean(self, words,data):
+         tag , match = data
+         text = match[0]
+         language = "ko"
+         score = 1.0
+         self._addwords(words,language,text,score)
+
+     def _process_quotes(self, words,data):
+         tag , match = data
+         text = "".join(match)
+         childs = self.PARSE_TAG.findall(text)
+         if len(childs) > 0:
+             self._process_tags(words , text , False)
+         else:
+             cleans_text = self._cleans_text(match[1])
+             if len(cleans_text) <= 5:
+                 self._parse_language(words,text)
+             else:
+                 language,score = self._lang_classify(cleans_text)
+                 self._addwords(words,language,text,score)
+
+     def _process_pinyin(self, words,data):
+         tag , match = data
+         text = match
+         language = "zh"
+         score = 1.0
+         self._addwords(words,language,text,score)
+
+     def _process_number(self, words,data):  # "$0" process only
+         """
+         Numbers alone cannot accurately identify a language,
+         because numbers are universal in all languages.
+         So it won't be executed here; it is just for testing.
+         """
+         tag , match = data
+         language = words[0]["lang"] if len(words) > 0 else "zh"
+         text = match
+         score = 0.0
+         self._addwords(words,language,text,score)
+
+     def _process_tags(self, words , text , root_tag):
+         text_cache = self._text_cache
+         segments = re.split(self.PARSE_TAG, text)
+         segments_len = len(segments) - 1
+         for index , text in enumerate(segments):
+             if root_tag:self._lang_eos = index >= segments_len
+             if self.PARSE_TAG.match(text):
+                 process , data = text_cache[text]
+                 if process:process(words , data)
+             else:
+                 self._parse_language(words , text)
+         return words
+
+     def _merge_results(self, words):
+         new_word = []
+         for index , cur_data in enumerate(words):
+             if "symbol" in cur_data:del cur_data["symbol"]
+             if index == 0:new_word.append(cur_data)
+             else:
+                 pre_data = new_word[-1]
+                 if cur_data["lang"] == pre_data["lang"]:
+                     pre_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
+                 else:new_word.append(cur_data)
+         return new_word
+
+     def _parse_symbols(self, text):
+         TAG_NUM = "00"  # "00" => default channels , "$0" => testing channel
+         TAG_S1,TAG_S2,TAG_P1,TAG_P2,TAG_EN,TAG_KO,TAG_RU,TAG_TH = "$1" ,"$2" ,"$3" ,"$4" ,"$5" ,"$6" ,"$7","$8"
+         TAG_BASE = re.compile(fr'(([【《((“‘"\']*[LANGUAGE]+[\W\s]*)+)')
+         # Get custom language filter
+         filters = self.Langfilters
+         filters = filters if filters is not None else ""
+         # =======================================================================================================
+         # Experimental: other language support.
+         # If characters for a language are missing, contributors familiar with that language are
+         # welcome to submit the missing pronunciation symbols.
+         # -------------------------------------------------------------------------------------------------------
+         # Preview feature, other language support
+         enablePreview = self.EnablePreview
+         if "fr" in filters or \
+             "vi" in filters:enablePreview = True
+         self.EnablePreview = enablePreview
+         # Experimental: French character support
+         RE_FR = "" if not enablePreview else "àáâãäåæçèéêëìíîïðñòóôõöùúûüýþÿ"
+         # Experimental: Vietnamese character support
+         RE_VI = "" if not enablePreview else "đơưăáàảãạắằẳẵặấầẩẫậéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựôâêơưỷỹ"
+         # -------------------------------------------------------------------------------------------------------
+         # Basic options:
+         process_list = [
+             ( TAG_S1 , re.compile(self.SYMBOLS_PATTERN) , self._process_symbol ),  # Symbol Tag
+             ( TAG_KO , re.compile(re.sub(r'LANGUAGE',f'\uac00-\ud7a3',TAG_BASE.pattern)) , self._process_korean ),  # Korean words
+             ( TAG_TH , re.compile(re.sub(r'LANGUAGE',f'\u0E00-\u0E7F',TAG_BASE.pattern)) , self._process_Thai ),  # Thai words support.
+             ( TAG_RU , re.compile(re.sub(r'LANGUAGE',f'А-Яа-яЁё',TAG_BASE.pattern)) , self._process_Russian ),  # Russian words support.
+             ( TAG_NUM , re.compile(r'(\W*\d+\W+\d*\W*\d*)') , self._process_number ),  # Number words, universal in all languages; ignore them.
+             ( TAG_EN , re.compile(re.sub(r'LANGUAGE',f'a-zA-Z{RE_FR}{RE_VI}',TAG_BASE.pattern)) , self._process_english ),  # English words + other language support.
+             ( TAG_P1 , re.compile(r'(["\'])(.*?)(\1)') , self._process_quotes ),  # Regular quotes
+             ( TAG_P2 , re.compile(r'([\n]*[【《((“‘])([^【《((“‘’”))》】]{3,})([’”))》】][\W\s]*[\n]{,1})') , self._process_quotes ),  # Special paired (left/right) quotes.
+         ]
+         # Extended options: Default False
+         if self.keepPinyin == True:process_list.insert(1 ,
+             ( TAG_S2 , re.compile(r'([\(({](?:\s*\w*\d\w*\s*)+[})\)])') , self._process_pinyin ),  # Chinese Pinyin Tag.
+         )
+         # -------------------------------------------------------------------------------------------------------
+         words = []
+         lines = re.findall(r'.*\n*', re.sub(self.PARSE_TAG, '' ,text))
+         for index , text in enumerate(lines):
+             if len(text.strip()) == 0:continue
+             self._lang_eos = False
+             self._text_cache = {}
+             for item in process_list:
+                 text = self._pattern_symbols(item , text)
+             cur_word = self._process_tags([] , text , True)
+             if len(cur_word) == 0:continue
+             cur_data = cur_word[0] if len(cur_word) > 0 else None
+             pre_data = words[-1] if len(words) > 0 else None
+             if cur_data and pre_data and cur_data["lang"] == pre_data["lang"] and cur_data["symbol"] == False and pre_data["symbol"]:
+                 cur_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
+                 words.pop()
+             words += cur_word
+         if self.isLangMerge == True:words = self._merge_results(words)
+         lang_count = self._lang_count
+         if lang_count and len(lang_count) > 0:
+             lang_count = dict(sorted(lang_count.items(), key=lambda x: x[1], reverse=True))
+             lang_count = list(lang_count.items())
+             self._lang_count = lang_count
+         return words
+
+     def setfilters(self, filters):
+         # When the filter changes, clear the cache
+         if self.Langfilters != filters:
+             self._clears()
+             self.Langfilters = filters
+
+     def getfilters(self):
+         return self.Langfilters
+
+     def setPriorityThreshold(self, threshold:float):
+         self.LangPriorityThreshold = threshold
+
+     def getPriorityThreshold(self):
+         return self.LangPriorityThreshold
+
+     def getCounts(self):
+         lang_count = self._lang_count
+         if lang_count is not None:return lang_count
+         text_langs = self._text_langs
+         if text_langs is None or len(text_langs) == 0:return [("zh",0)]
+         lang_counts = defaultdict(int)
+         for d in text_langs:lang_counts[d['lang']] += int(len(d['text'])*2) if d['lang'] == "zh" else len(d['text'])
+         lang_counts = dict(sorted(lang_counts.items(), key=lambda x: x[1], reverse=True))
+         lang_counts = list(lang_counts.items())
+         self._lang_count = lang_counts
+         return lang_counts
+
+     def getTexts(self, text:str):
+         if text is None or len(text.strip()) == 0:
+             self._clears()
+             return []
+         # lasts
+         text_langs = self._text_langs
+         if self._text_lasts == text and text_langs is not None:return text_langs
+         # parse
+         self._text_waits = []
+         self._lang_count = None
+         self._text_lasts = text
+         text = self._parse_symbols(text)
+         self._text_langs = text
+         return text
+
+     def classify(self, text:str):
+         return self.getTexts(text)
+
+
+ def printList(langlist):
+     """
+     Function: print array results
+     """
+     print("\n===================【打印结果】===================")
+     if langlist is None or len(langlist) == 0:
+         print("无内容结果,No content result")
+         return
+     for line in langlist:
+         print(line)
+
+
+ def main():
+
+     # Input example 1 (Japanese, Chinese):
+     # text = "“昨日は雨が降った,音楽、映画。。。”你今天学习日语了吗?春は桜の季節です。语种分词是语音合成必不可少的环节。言語分詞は音声合成に欠かせない環節である!"
+
+     # Input example 2 (Japanese, Chinese):
+     # text = "欢迎来玩。東京,は日本の首都です。欢迎来玩. 太好了!"
+
+     # Input example 3 (Japanese, Chinese):
+     # text = "明日、私たちは海辺にバカンスに行きます。你会说日语吗:“中国語、話せますか” 你的日语真好啊!"
+
+     # Input example 4 (Japanese, Chinese, Korean, English):
+     # text = "你的名字叫<ja>佐々木?<ja>吗?韩语中的안녕 오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型和三款Apple Watch等一系列新品,这次的iPad Air采用了LCD屏幕"
+
+     # Experimental support: "fr" French, "vi" Vietnamese, "ru" Russian, "th" Thai.
+     langsegment = LangSegment()
+     langsegment.setfilters(["fr", "vi" , "ja", "zh", "ko", "en" , "ru" , "th"])
+     text = """
+     我喜欢在雨天里听音乐。
+     I enjoy listening to music on rainy days.
+     雨の日に音楽を聴くのが好きです。
+     비 오는 날에 음악을 듣는 것을 즐깁니다。
+     J'aime écouter de la musique les jours de pluie.
+     Tôi thích nghe nhạc vào những ngày mưa.
+     Мне нравится слушать музыку в дождливую погоду.
+     ฉันชอบฟังเพลงในวันที่ฝนตก
+     """
+
+     # Run the segmentation (a single call is all a TTS project needs):
+     langlist = langsegment.getTexts(text)
+     printList(langlist)
+
+     # Language statistics:
+     print("\n===================【语种统计】===================")
+     # Get the per-language results, sorted in descending order by content length.
+     langCounts = langsegment.getCounts()
+     print(langCounts , "\n")
+
+     # Get the main language of the content from the results (language, character count incl. punctuation).
+     lang , count = langCounts[0]
+     print(f"输入内容的主要语言为 = {lang} ,字数 = {count}")
+     print("==================================================\n")
+
+     # Output format: lang = language, text = content. Sample output for input example 4:
+     # ===================【打印结果】===================
+     # {'lang': 'zh', 'text': '你的名字叫'}
+     # {'lang': 'ja', 'text': '佐々木?'}
+     # {'lang': 'zh', 'text': '吗?韩语中的'}
+     # {'lang': 'ko', 'text': '안녕 오빠'}
+     # {'lang': 'zh', 'text': '读什么呢?'}
+     # {'lang': 'ja', 'text': 'あなたの体育の先生は誰ですか?'}
+     # {'lang': 'zh', 'text': ' 此次发布会带来了四款'}
+     # {'lang': 'en', 'text': 'i Phone '}
+     # {'lang': 'zh', 'text': '15系列机型和三款'}
+     # {'lang': 'en', 'text': 'Apple Watch '}
+     # {'lang': 'zh', 'text': '等一系列新品,这次的'}
+     # {'lang': 'en', 'text': 'i Pad Air '}
+     # {'lang': 'zh', 'text': '采用了'}
+     # {'lang': 'en', 'text': 'L C D '}
+     # {'lang': 'zh', 'text': '屏幕'}
+
+     # ===================【语种统计】===================
+     # [('zh', 51), ('ja', 19), ('en', 18), ('ko', 5)]
+     # The main language of the input content is zh, character count = 51.
+
+
+ if __name__ == "__main__":
+     main()
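
A short sketch of the SSML-style tags described above (an editor's illustration, not part of the commit); it assumes py3langid and its bundled model are installed (pip install py3langid==0.2.2):

from g2p.language_segmentation import LangSegment

segmenter = LangSegment()
segmenter.setfilters(["zh", "en"])

# <telephone> forces a digit-by-digit Chinese reading, with 幺 substituted for 一.
for item in segmenter.getTexts("请拨打<telephone>10086</telephone>咨询"):
    print(item["lang"], item["text"])  # expected: zh 请拨打幺零零八六咨询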
g2p/language_segmentation/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .LangSegment import LangSegment
+
+
+ # release
+ __version__ = '0.3.5'
+
+
+ # develop
+ __develop__ = 'dev-0.0.1'
g2p/language_segmentation/utils/__init__.py ADDED
File without changes
g2p/language_segmentation/utils/num.py ADDED
@@ -0,0 +1,327 @@
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # Digital processing from GPT_SoVITS num.py (thanks)
+ """
+ Rules to verbalize numbers into Chinese characters.
+ https://zh.wikipedia.org/wiki/中文数字#現代中文
+ """
+
+ import re
+ from collections import OrderedDict
+ from typing import List
+
+ DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
+ UNITS = OrderedDict({
+     1: '十',
+     2: '百',
+     3: '千',
+     4: '万',
+     8: '亿',
+ })
+
+ COM_QUANTIFIERS = '(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'
+
+ # Fraction expressions
+ RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
+
+
+ def replace_frac(match) -> str:
+     """
+     Args:
+         match (re.Match)
+     Returns:
+         str
+     """
+     sign = match.group(1)
+     nominator = match.group(2)
+     denominator = match.group(3)
+     sign: str = "负" if sign else ""
+     nominator: str = num2str(nominator)
+     denominator: str = num2str(denominator)
+     result = f"{sign}{denominator}分之{nominator}"
+     return result
+
+
+ # Percentage expressions
+ RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')
+
+
+ def replace_percentage(match) -> str:
+     """
+     Args:
+         match (re.Match)
+     Returns:
+         str
+     """
+     sign = match.group(1)
+     percent = match.group(2)
+     sign: str = "负" if sign else ""
+     percent: str = num2str(percent)
+     result = f"{sign}百分之{percent}"
+     return result
+
+
+ # Integer expressions
+ # Negative integers, e.g. -10
+ RE_INTEGER = re.compile(r'(-)' r'(\d+)')
+
+
+ def replace_negative_num(match) -> str:
+     """
+     Args:
+         match (re.Match)
+     Returns:
+         str
+     """
+     sign = match.group(1)
+     number = match.group(2)
+     sign: str = "负" if sign else ""
+     number: str = num2str(number)
+     result = f"{sign}{number}"
+     return result
+
+
+ # ID numbers - unsigned integers, e.g. 00078
+ RE_DEFAULT_NUM = re.compile(r'\d{3}\d*')
+
+
+ def replace_default_num(match):
+     """
+     Args:
+         match (re.Match)
+     Returns:
+         str
+     """
+     number = match.group(0)
+     return verbalize_digit(number, alt_one=True)
+
+
+ # Arithmetic: addition, subtraction, multiplication, division
+ # RE_ASMD = re.compile(
+ #     r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
+ RE_ASMD = re.compile(
+     r'((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))')
+
+ asmd_map = {
+     '+': '加',
+     '-': '减',
+     '×': '乘',
+     '÷': '除',
+     '=': '等于'
+ }
+
+ def replace_asmd(match) -> str:
+     """
+     Args:
+         match (re.Match)
+     Returns:
+         str
+     """
+     result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
+     return result
+
+
+ # Powers (superscript digits)
+ RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')
+
+ power_map = {
+     '⁰': '0',
+     '¹': '1',
+     '²': '2',
+     '³': '3',
+     '⁴': '4',
+     '⁵': '5',
+     '⁶': '6',
+     '⁷': '7',
+     '⁸': '8',
+     '⁹': '9',
+     'ˣ': 'x',
+     'ʸ': 'y',
+     'ⁿ': 'n'
+ }
+
+ def replace_power(match) -> str:
+     """
+     Args:
+         match (re.Match)
+     Returns:
+         str
+     """
+     power_num = ""
+     for m in match.group(0):
+         power_num += power_map[m]
+     result = "的" + power_num + "次方"
+     return result
+
+
+ # Number expressions
+ # Pure decimals
+ RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
+ # Positive integer + measure word
+ RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
+ RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')
+
+
+ def replace_positive_quantifier(match) -> str:
+     """
+     Args:
+         match (re.Match)
+     Returns:
+         str
+     """
+     number = match.group(1)
+     match_2 = match.group(2)
+     if match_2 == "+":
+         match_2 = "多"
+     match_2: str = match_2 if match_2 else ""
+     quantifiers: str = match.group(3)
+     number: str = num2str(number)
+     result = f"{number}{match_2}{quantifiers}"
+     return result
+
+
+ def replace_number(match) -> str:
+     """
+     Args:
+         match (re.Match)
+     Returns:
+         str
+     """
+     sign = match.group(1)
+     number = match.group(2)
+     pure_decimal = match.group(5)
+     if pure_decimal:
+         result = num2str(pure_decimal)
+     else:
+         sign: str = "负" if sign else ""
+         number: str = num2str(number)
+         result = f"{sign}{number}"
+     return result
+
+
+ # Range expressions
+ # match.group(1) and match.group(8) are copied from RE_NUMBER
+
+ RE_RANGE = re.compile(
+     r"""
+     (?<![\d\+\-\×÷=])      # negative lookbehind: no digits or operators before the range
+     ((-?)((\d+)(\.\d+)?))  # start of range: negative or positive number (integer or decimal)
+     [-~]                   # range separator
+     ((-?)((\d+)(\.\d+)?))  # end of range: negative or positive number (integer or decimal)
+     (?![\d\+\-\×÷=])       # negative lookahead: no digits or operators after the range
+     """, re.VERBOSE)
+
+
+ def replace_range(match) -> str:
+     """
+     Args:
+         match (re.Match)
+     Returns:
+         str
+     """
+     first, second = match.group(1), match.group(6)
+     first = RE_NUMBER.sub(replace_number, first)
+     second = RE_NUMBER.sub(replace_number, second)
+     result = f"{first}到{second}"
+     return result
+
+
+ # "~" (read as 至, "to") expressions
+ RE_TO_RANGE = re.compile(
+     r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)')
+
+ def replace_to_range(match) -> str:
+     """
+     Args:
+         match (re.Match)
+     Returns:
+         str
+     """
+     result = match.group(0).replace('~', '至')
+     return result
+
+
+ def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
+     stripped = value_string.lstrip('0')
+     if len(stripped) == 0:
+         return []
+     elif len(stripped) == 1:
+         if use_zero and len(stripped) < len(value_string):
+             return [DIGITS['0'], DIGITS[stripped]]
+         else:
+             return [DIGITS[stripped]]
+     else:
+         largest_unit = next(
+             power for power in reversed(UNITS.keys()) if power < len(stripped))
+         first_part = value_string[:-largest_unit]
+         second_part = value_string[-largest_unit:]
+         return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(
+             second_part)
+
+
+ def verbalize_cardinal(value_string: str) -> str:
+     if not value_string:
+         return ''
+
+     # 000 -> '零' , 0 -> '零'
+     value_string = value_string.lstrip('0')
+     if len(value_string) == 0:
+         return DIGITS['0']
+
+     result_symbols = _get_value(value_string)
+     # verbalized number starting with '一十*' is abbreviated as `十*`
+     if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[
+             '1'] and result_symbols[1] == UNITS[1]:
+         result_symbols = result_symbols[1:]
+     return ''.join(result_symbols)
+
+
+ def verbalize_digit(value_string: str, alt_one=False) -> str:
+     result_symbols = [DIGITS[digit] for digit in value_string]
+     result = ''.join(result_symbols)
+     if alt_one:
+         result = result.replace("一", "幺")
+     return result
+
+
+ def num2str(value_string: str) -> str:
+     integer_decimal = value_string.split('.')
+     if len(integer_decimal) == 1:
+         integer = integer_decimal[0]
+         decimal = ''
+     elif len(integer_decimal) == 2:
+         integer, decimal = integer_decimal
+     else:
+         raise ValueError(
+             f"The value string: '${value_string}' has more than one point in it."
+         )
+
+     result = verbalize_cardinal(integer)
+
+     decimal = decimal.rstrip('0')
+     if decimal:
+         # '.22' is verbalized as '零点二二'
+         # '3.20' is verbalized as '三点二'
+         result = result if result else "零"
+         result += '点' + verbalize_digit(decimal)
+     return result
+
+
+ if __name__ == "__main__":
+
+     text = ""
+     text = num2str(text)
+     print(text)
+     pass
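
A quick sketch of the verbalization helpers defined above (an editor's illustration, not part of the commit); the expected outputs follow from the rules in this file:

from g2p.language_segmentation.utils.num import RE_FRAC, num2str, replace_frac, verbalize_digit

print(num2str("10086"))                        # 一万零八十六
print(num2str("3.20"))                         # 三点二 (trailing zero dropped)
print(verbalize_digit("10086", alt_one=True))  # 幺零零八六 (phone-number style)
print(RE_FRAC.sub(replace_frac, "-1/3"))       # 负三分之一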
g2p/sources/bpmf_2_pinyin.txt ADDED
@@ -0,0 +1,41 @@
+ b ㄅ
+ p ㄆ
+ m ㄇ
+ f ㄈ
+ d ㄉ
+ t ㄊ
+ n ㄋ
+ l ㄌ
+ g ㄍ
+ k ㄎ
+ h ㄏ
+ j ㄐ
+ q ㄑ
+ x ㄒ
+ zh ㄓ
+ ch ㄔ
+ sh ㄕ
+ r ㄖ
+ z ㄗ
+ c ㄘ
+ s ㄙ
+ i ㄧ
+ u ㄨ
+ v ㄩ
+ a ㄚ
+ o ㄛ
+ e ㄜ
+ e ㄝ
+ ai ㄞ
+ ei ㄟ
+ ao ㄠ
+ ou ㄡ
+ an ㄢ
+ en ㄣ
+ ang ㄤ
+ eng ㄥ
+ er ㄦ
+ 2 ˊ
+ 3 ˇ
+ 4 ˋ
+ 0 ˙
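
A minimal sketch (an editor's illustration, not part of the commit) for loading the table above into a zhuyin-to-pinyin dict; each line maps a pinyin fragment (or tone digit) to a bopomofo symbol:

# Build a bopomofo -> pinyin lookup from the whitespace-separated table.
bpmf_to_pinyin = {}
with open("g2p/sources/bpmf_2_pinyin.txt", "r", encoding="utf-8") as f:
    for line in f:
        parts = line.split()  # [pinyin_fragment, bopomofo_symbol]
        if len(parts) == 2:
            pinyin, bpmf = parts
            bpmf_to_pinyin[bpmf] = pinyin

print(bpmf_to_pinyin["ㄓ"])  # zh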
g2p/sources/chinese_lexicon.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a3a7685d1c3e68eb2fa304bfc63e90c90c3c1a1948839a5b1b507b2131b3e2fb
+ size 14779443
g2p/sources/g2p_chinese_model/config.json ADDED
@@ -0,0 +1,819 @@
+ {
+   "_name_or_path": "/BERT-POLY-v2/pretrained_models/mini_bert",
+   "architectures": [
+     "BertPoly"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "classifier_dropout": null,
+   "directionality": "bidi",
+   "gradient_checkpointing": false,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 384,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2",
+     "3": "LABEL_3",
+     "4": "LABEL_4",
+     "5": "LABEL_5",
+     "6": "LABEL_6",
+     "7": "LABEL_7",
+     "8": "LABEL_8",
+     "9": "LABEL_9",
+     "10": "LABEL_10",
+     "11": "LABEL_11",
+     "12": "LABEL_12",
+     "13": "LABEL_13",
+     "14": "LABEL_14",
+     "15": "LABEL_15",
+     "16": "LABEL_16",
+     "17": "LABEL_17",
+     "18": "LABEL_18",
+     "19": "LABEL_19",
+     "20": "LABEL_20",
+     "21": "LABEL_21",
+     "22": "LABEL_22",
+     "23": "LABEL_23",
+     "24": "LABEL_24",
+     "25": "LABEL_25",
+     "26": "LABEL_26",
+     "27": "LABEL_27",
+     "28": "LABEL_28",
+     "29": "LABEL_29",
+     "30": "LABEL_30",
+     "31": "LABEL_31",
+     "32": "LABEL_32",
+     "33": "LABEL_33",
+     "34": "LABEL_34",
+     "35": "LABEL_35",
+     "36": "LABEL_36",
+     "37": "LABEL_37",
+     "38": "LABEL_38",
+     "39": "LABEL_39",
+     "40": "LABEL_40",
+     "41": "LABEL_41",
+     "42": "LABEL_42",
+     "43": "LABEL_43",
+     "44": "LABEL_44",
+     "45": "LABEL_45",
+     "46": "LABEL_46",
+     "47": "LABEL_47",
+     "48": "LABEL_48",
+     "49": "LABEL_49",
+     "50": "LABEL_50",
+     "51": "LABEL_51",
+     "52": "LABEL_52",
+     "53": "LABEL_53",
+     "54": "LABEL_54",
+     "55": "LABEL_55",
+     "56": "LABEL_56",
+     "57": "LABEL_57",
+     "58": "LABEL_58",
+     "59": "LABEL_59",
+     "60": "LABEL_60",
+     "61": "LABEL_61",
+     "62": "LABEL_62",
+     "63": "LABEL_63",
+     "64": "LABEL_64",
+     "65": "LABEL_65",
+     "66": "LABEL_66",
+     "67": "LABEL_67",
+     "68": "LABEL_68",
+     "69": "LABEL_69",
+     "70": "LABEL_70",
+     "71": "LABEL_71",
+     "72": "LABEL_72",
+     "73": "LABEL_73",
+     "74": "LABEL_74",
+     "75": "LABEL_75",
+     "76": "LABEL_76",
+     "77": "LABEL_77",
+     "78": "LABEL_78",
+     "79": "LABEL_79",
+     "80": "LABEL_80",
+     "81": "LABEL_81",
+     "82": "LABEL_82",
+     "83": "LABEL_83",
+     "84": "LABEL_84",
+     "85": "LABEL_85",
+     "86": "LABEL_86",
+     "87": "LABEL_87",
+     "88": "LABEL_88",
+     "89": "LABEL_89",
+     "90": "LABEL_90",
+     "91": "LABEL_91",
+     "92": "LABEL_92",
+     "93": "LABEL_93",
+     "94": "LABEL_94",
+     "95": "LABEL_95",
+     "96": "LABEL_96",
+     "97": "LABEL_97",
+     "98": "LABEL_98",
+     "99": "LABEL_99",
+     "100": "LABEL_100",
+     "101": "LABEL_101",
+     "102": "LABEL_102",
+     "103": "LABEL_103",
+     "104": "LABEL_104",
+     "105": "LABEL_105",
+     "106": "LABEL_106",
+     "107": "LABEL_107",
+     "108": "LABEL_108",
+     "109": "LABEL_109",
+     "110": "LABEL_110",
+     "111": "LABEL_111",
+     "112": "LABEL_112",
+     "113": "LABEL_113",
+     "114": "LABEL_114",
+     "115": "LABEL_115",
+     "116": "LABEL_116",
+     "117": "LABEL_117",
+     "118": "LABEL_118",
+     "119": "LABEL_119",
+     "120": "LABEL_120",
+     "121": "LABEL_121",
+     "122": "LABEL_122",
+     "123": "LABEL_123",
+     "124": "LABEL_124",
+     "125": "LABEL_125",
+     "126": "LABEL_126",
+     "127": "LABEL_127",
+     "128": "LABEL_128",
+     "129": "LABEL_129",
+     "130": "LABEL_130",
+     "131": "LABEL_131",
+     "132": "LABEL_132",
+     "133": "LABEL_133",
+     "134": "LABEL_134",
+     "135": "LABEL_135",
+     "136": "LABEL_136",
+     "137": "LABEL_137",
+     "138": "LABEL_138",
+     "139": "LABEL_139",
+     "140": "LABEL_140",
+     "141": "LABEL_141",
+     "142": "LABEL_142",
+     "143": "LABEL_143",
+     "144": "LABEL_144",
+     "145": "LABEL_145",
+     "146": "LABEL_146",
+     "147": "LABEL_147",
+     "148": "LABEL_148",
+     "149": "LABEL_149",
+     "150": "LABEL_150",
+     "151": "LABEL_151",
+     "152": "LABEL_152",
+     "153": "LABEL_153",
+     "154": "LABEL_154",
+     "155": "LABEL_155",
+     "156": "LABEL_156",
+     "157": "LABEL_157",
+     "158": "LABEL_158",
+     "159": "LABEL_159",
+     "160": "LABEL_160",
+     "161": "LABEL_161",
+     "162": "LABEL_162",
+     "163": "LABEL_163",
+     "164": "LABEL_164",
+     "165": "LABEL_165",
+     "166": "LABEL_166",
+     "167": "LABEL_167",
+     "168": "LABEL_168",
+     "169": "LABEL_169",
+     "170": "LABEL_170",
+     "171": "LABEL_171",
+     "172": "LABEL_172",
+     "173": "LABEL_173",
+     "174": "LABEL_174",
+     "175": "LABEL_175",
+     "176": "LABEL_176",
+     "177": "LABEL_177",
+     "178": "LABEL_178",
+     "179": "LABEL_179",
+     "180": "LABEL_180",
+     "181": "LABEL_181",
+     "182": "LABEL_182",
+     "183": "LABEL_183",
+     "184": "LABEL_184",
+     "185": "LABEL_185",
+     "186": "LABEL_186",
+     "187": "LABEL_187",
+     "188": "LABEL_188",
+     "189": "LABEL_189",
+     "190": "LABEL_190",
+     "191": "LABEL_191",
+     "192": "LABEL_192",
+     "193": "LABEL_193",
+     "194": "LABEL_194",
+     "195": "LABEL_195",
+     "196": "LABEL_196",
+     "197": "LABEL_197",
+     "198": "LABEL_198",
+     "199": "LABEL_199",
+     "200": "LABEL_200",
+     "201": "LABEL_201",
+     "202": "LABEL_202",
+     "203": "LABEL_203",
+     "204": "LABEL_204",
+     "205": "LABEL_205",
+     "206": "LABEL_206",
+     "207": "LABEL_207",
+     "208": "LABEL_208",
+     "209": "LABEL_209",
+     "210": "LABEL_210",
+     "211": "LABEL_211",
+     "212": "LABEL_212",
+     "213": "LABEL_213",
+     "214": "LABEL_214",
+     "215": "LABEL_215",
+     "216": "LABEL_216",
+     "217": "LABEL_217",
+     "218": "LABEL_218",
+     "219": "LABEL_219",
+     "220": "LABEL_220",
+     "221": "LABEL_221",
+     "222": "LABEL_222",
+     "223": "LABEL_223",
+     "224": "LABEL_224",
+     "225": "LABEL_225",
+     "226": "LABEL_226",
+     "227": "LABEL_227",
+     "228": "LABEL_228",
+     "229": "LABEL_229",
+     "230": "LABEL_230",
+     "231": "LABEL_231",
+     "232": "LABEL_232",
+     "233": "LABEL_233",
+     "234": "LABEL_234",
+     "235": "LABEL_235",
+     "236": "LABEL_236",
+     "237": "LABEL_237",
+     "238": "LABEL_238",
+     "239": "LABEL_239",
+     "240": "LABEL_240",
+     "241": "LABEL_241",
+     "242": "LABEL_242",
+     "243": "LABEL_243",
+     "244": "LABEL_244",
+     "245": "LABEL_245",
+     "246": "LABEL_246",
+     "247": "LABEL_247",
+     "248": "LABEL_248",
+     "249": "LABEL_249",
+     "250": "LABEL_250",
+     "251": "LABEL_251",
+     "252": "LABEL_252",
+     "253": "LABEL_253",
+     "254": "LABEL_254",
+     "255": "LABEL_255",
+     "256": "LABEL_256",
+     "257": "LABEL_257",
+     "258": "LABEL_258",
+     "259": "LABEL_259",
+     "260": "LABEL_260",
+     "261": "LABEL_261",
+     "262": "LABEL_262",
+     "263": "LABEL_263",
+     "264": "LABEL_264",
+     "265": "LABEL_265",
+     "266": "LABEL_266",
+     "267": "LABEL_267",
+     "268": "LABEL_268",
+     "269": "LABEL_269",
+     "270": "LABEL_270",
+     "271": "LABEL_271",
+     "272": "LABEL_272",
287
+ "273": "LABEL_273",
288
+ "274": "LABEL_274",
289
+ "275": "LABEL_275",
290
+ "276": "LABEL_276",
291
+ "277": "LABEL_277",
292
+ "278": "LABEL_278",
293
+ "279": "LABEL_279",
294
+ "280": "LABEL_280",
295
+ "281": "LABEL_281",
296
+ "282": "LABEL_282",
297
+ "283": "LABEL_283",
298
+ "284": "LABEL_284",
299
+ "285": "LABEL_285",
300
+ "286": "LABEL_286",
301
+ "287": "LABEL_287",
302
+ "288": "LABEL_288",
303
+ "289": "LABEL_289",
304
+ "290": "LABEL_290",
305
+ "291": "LABEL_291",
306
+ "292": "LABEL_292",
307
+ "293": "LABEL_293",
308
+ "294": "LABEL_294",
309
+ "295": "LABEL_295",
310
+ "296": "LABEL_296",
311
+ "297": "LABEL_297",
312
+ "298": "LABEL_298",
313
+ "299": "LABEL_299",
314
+ "300": "LABEL_300",
315
+ "301": "LABEL_301",
316
+ "302": "LABEL_302",
317
+ "303": "LABEL_303",
318
+ "304": "LABEL_304",
319
+ "305": "LABEL_305",
320
+ "306": "LABEL_306",
321
+ "307": "LABEL_307",
322
+ "308": "LABEL_308",
323
+ "309": "LABEL_309",
324
+ "310": "LABEL_310",
325
+ "311": "LABEL_311",
326
+ "312": "LABEL_312",
327
+ "313": "LABEL_313",
328
+ "314": "LABEL_314",
329
+ "315": "LABEL_315",
330
+ "316": "LABEL_316",
331
+ "317": "LABEL_317",
332
+ "318": "LABEL_318",
333
+ "319": "LABEL_319",
334
+ "320": "LABEL_320",
335
+ "321": "LABEL_321",
336
+ "322": "LABEL_322",
337
+ "323": "LABEL_323",
338
+ "324": "LABEL_324",
339
+ "325": "LABEL_325",
340
+ "326": "LABEL_326",
341
+ "327": "LABEL_327",
342
+ "328": "LABEL_328",
343
+ "329": "LABEL_329",
344
+ "330": "LABEL_330",
345
+ "331": "LABEL_331",
346
+ "332": "LABEL_332",
347
+ "333": "LABEL_333",
348
+ "334": "LABEL_334",
349
+ "335": "LABEL_335",
350
+ "336": "LABEL_336",
351
+ "337": "LABEL_337",
352
+ "338": "LABEL_338",
353
+ "339": "LABEL_339",
354
+ "340": "LABEL_340",
355
+ "341": "LABEL_341",
356
+ "342": "LABEL_342",
357
+ "343": "LABEL_343",
358
+ "344": "LABEL_344",
359
+ "345": "LABEL_345",
360
+ "346": "LABEL_346",
361
+ "347": "LABEL_347",
362
+ "348": "LABEL_348",
363
+ "349": "LABEL_349",
364
+ "350": "LABEL_350",
365
+ "351": "LABEL_351",
366
+ "352": "LABEL_352",
367
+ "353": "LABEL_353",
368
+ "354": "LABEL_354",
369
+ "355": "LABEL_355",
370
+ "356": "LABEL_356",
371
+ "357": "LABEL_357",
372
+ "358": "LABEL_358",
373
+ "359": "LABEL_359",
374
+ "360": "LABEL_360",
375
+ "361": "LABEL_361",
376
+ "362": "LABEL_362",
377
+ "363": "LABEL_363",
378
+ "364": "LABEL_364",
379
+ "365": "LABEL_365",
380
+ "366": "LABEL_366",
381
+ "367": "LABEL_367",
382
+ "368": "LABEL_368",
383
+ "369": "LABEL_369",
384
+ "370": "LABEL_370",
385
+ "371": "LABEL_371",
386
+ "372": "LABEL_372",
387
+ "373": "LABEL_373",
388
+ "374": "LABEL_374",
389
+ "375": "LABEL_375",
390
+ "376": "LABEL_376",
391
+ "377": "LABEL_377",
392
+ "378": "LABEL_378",
393
+ "379": "LABEL_379",
394
+ "380": "LABEL_380",
395
+ "381": "LABEL_381",
396
+ "382": "LABEL_382",
397
+ "383": "LABEL_383",
398
+ "384": "LABEL_384",
399
+ "385": "LABEL_385",
400
+ "386": "LABEL_386",
401
+ "387": "LABEL_387",
402
+ "388": "LABEL_388",
403
+ "389": "LABEL_389",
404
+ "390": "LABEL_390"
405
+ },
406
+ "initializer_range": 0.02,
407
+ "intermediate_size": 1536,
408
+ "label2id": {
409
+ "LABEL_0": 0,
410
+ "LABEL_1": 1,
411
+ "LABEL_10": 10,
412
+ "LABEL_100": 100,
413
+ "LABEL_101": 101,
414
+ "LABEL_102": 102,
415
+ "LABEL_103": 103,
416
+ "LABEL_104": 104,
417
+ "LABEL_105": 105,
418
+ "LABEL_106": 106,
419
+ "LABEL_107": 107,
420
+ "LABEL_108": 108,
421
+ "LABEL_109": 109,
422
+ "LABEL_11": 11,
423
+ "LABEL_110": 110,
424
+ "LABEL_111": 111,
425
+ "LABEL_112": 112,
426
+ "LABEL_113": 113,
427
+ "LABEL_114": 114,
428
+ "LABEL_115": 115,
429
+ "LABEL_116": 116,
430
+ "LABEL_117": 117,
431
+ "LABEL_118": 118,
432
+ "LABEL_119": 119,
433
+ "LABEL_12": 12,
434
+ "LABEL_120": 120,
435
+ "LABEL_121": 121,
436
+ "LABEL_122": 122,
437
+ "LABEL_123": 123,
438
+ "LABEL_124": 124,
439
+ "LABEL_125": 125,
440
+ "LABEL_126": 126,
441
+ "LABEL_127": 127,
442
+ "LABEL_128": 128,
443
+ "LABEL_129": 129,
444
+ "LABEL_13": 13,
445
+ "LABEL_130": 130,
446
+ "LABEL_131": 131,
447
+ "LABEL_132": 132,
448
+ "LABEL_133": 133,
449
+ "LABEL_134": 134,
450
+ "LABEL_135": 135,
451
+ "LABEL_136": 136,
452
+ "LABEL_137": 137,
453
+ "LABEL_138": 138,
454
+ "LABEL_139": 139,
455
+ "LABEL_14": 14,
456
+ "LABEL_140": 140,
457
+ "LABEL_141": 141,
458
+ "LABEL_142": 142,
459
+ "LABEL_143": 143,
460
+ "LABEL_144": 144,
461
+ "LABEL_145": 145,
462
+ "LABEL_146": 146,
463
+ "LABEL_147": 147,
464
+ "LABEL_148": 148,
465
+ "LABEL_149": 149,
466
+ "LABEL_15": 15,
467
+ "LABEL_150": 150,
468
+ "LABEL_151": 151,
469
+ "LABEL_152": 152,
470
+ "LABEL_153": 153,
471
+ "LABEL_154": 154,
472
+ "LABEL_155": 155,
473
+ "LABEL_156": 156,
474
+ "LABEL_157": 157,
475
+ "LABEL_158": 158,
476
+ "LABEL_159": 159,
477
+ "LABEL_16": 16,
478
+ "LABEL_160": 160,
479
+ "LABEL_161": 161,
480
+ "LABEL_162": 162,
481
+ "LABEL_163": 163,
482
+ "LABEL_164": 164,
483
+ "LABEL_165": 165,
484
+ "LABEL_166": 166,
485
+ "LABEL_167": 167,
486
+ "LABEL_168": 168,
487
+ "LABEL_169": 169,
488
+ "LABEL_17": 17,
489
+ "LABEL_170": 170,
490
+ "LABEL_171": 171,
491
+ "LABEL_172": 172,
492
+ "LABEL_173": 173,
493
+ "LABEL_174": 174,
494
+ "LABEL_175": 175,
495
+ "LABEL_176": 176,
496
+ "LABEL_177": 177,
497
+ "LABEL_178": 178,
498
+ "LABEL_179": 179,
499
+ "LABEL_18": 18,
500
+ "LABEL_180": 180,
501
+ "LABEL_181": 181,
502
+ "LABEL_182": 182,
503
+ "LABEL_183": 183,
504
+ "LABEL_184": 184,
505
+ "LABEL_185": 185,
506
+ "LABEL_186": 186,
507
+ "LABEL_187": 187,
508
+ "LABEL_188": 188,
509
+ "LABEL_189": 189,
510
+ "LABEL_19": 19,
511
+ "LABEL_190": 190,
512
+ "LABEL_191": 191,
513
+ "LABEL_192": 192,
514
+ "LABEL_193": 193,
515
+ "LABEL_194": 194,
516
+ "LABEL_195": 195,
517
+ "LABEL_196": 196,
518
+ "LABEL_197": 197,
519
+ "LABEL_198": 198,
520
+ "LABEL_199": 199,
521
+ "LABEL_2": 2,
522
+ "LABEL_20": 20,
523
+ "LABEL_200": 200,
524
+ "LABEL_201": 201,
525
+ "LABEL_202": 202,
526
+ "LABEL_203": 203,
527
+ "LABEL_204": 204,
528
+ "LABEL_205": 205,
529
+ "LABEL_206": 206,
530
+ "LABEL_207": 207,
531
+ "LABEL_208": 208,
532
+ "LABEL_209": 209,
533
+ "LABEL_21": 21,
534
+ "LABEL_210": 210,
535
+ "LABEL_211": 211,
536
+ "LABEL_212": 212,
537
+ "LABEL_213": 213,
538
+ "LABEL_214": 214,
539
+ "LABEL_215": 215,
540
+ "LABEL_216": 216,
541
+ "LABEL_217": 217,
542
+ "LABEL_218": 218,
543
+ "LABEL_219": 219,
544
+ "LABEL_22": 22,
545
+ "LABEL_220": 220,
546
+ "LABEL_221": 221,
547
+ "LABEL_222": 222,
548
+ "LABEL_223": 223,
549
+ "LABEL_224": 224,
550
+ "LABEL_225": 225,
551
+ "LABEL_226": 226,
552
+ "LABEL_227": 227,
553
+ "LABEL_228": 228,
554
+ "LABEL_229": 229,
555
+ "LABEL_23": 23,
556
+ "LABEL_230": 230,
557
+ "LABEL_231": 231,
558
+ "LABEL_232": 232,
559
+ "LABEL_233": 233,
560
+ "LABEL_234": 234,
561
+ "LABEL_235": 235,
562
+ "LABEL_236": 236,
563
+ "LABEL_237": 237,
564
+ "LABEL_238": 238,
565
+ "LABEL_239": 239,
566
+ "LABEL_24": 24,
567
+ "LABEL_240": 240,
568
+ "LABEL_241": 241,
569
+ "LABEL_242": 242,
570
+ "LABEL_243": 243,
571
+ "LABEL_244": 244,
572
+ "LABEL_245": 245,
573
+ "LABEL_246": 246,
574
+ "LABEL_247": 247,
575
+ "LABEL_248": 248,
576
+ "LABEL_249": 249,
577
+ "LABEL_25": 25,
578
+ "LABEL_250": 250,
579
+ "LABEL_251": 251,
580
+ "LABEL_252": 252,
581
+ "LABEL_253": 253,
582
+ "LABEL_254": 254,
583
+ "LABEL_255": 255,
584
+ "LABEL_256": 256,
585
+ "LABEL_257": 257,
586
+ "LABEL_258": 258,
587
+ "LABEL_259": 259,
588
+ "LABEL_26": 26,
589
+ "LABEL_260": 260,
590
+ "LABEL_261": 261,
591
+ "LABEL_262": 262,
592
+ "LABEL_263": 263,
593
+ "LABEL_264": 264,
594
+ "LABEL_265": 265,
595
+ "LABEL_266": 266,
596
+ "LABEL_267": 267,
597
+ "LABEL_268": 268,
598
+ "LABEL_269": 269,
599
+ "LABEL_27": 27,
600
+ "LABEL_270": 270,
601
+ "LABEL_271": 271,
602
+ "LABEL_272": 272,
603
+ "LABEL_273": 273,
604
+ "LABEL_274": 274,
605
+ "LABEL_275": 275,
606
+ "LABEL_276": 276,
607
+ "LABEL_277": 277,
608
+ "LABEL_278": 278,
609
+ "LABEL_279": 279,
610
+ "LABEL_28": 28,
611
+ "LABEL_280": 280,
612
+ "LABEL_281": 281,
613
+ "LABEL_282": 282,
614
+ "LABEL_283": 283,
615
+ "LABEL_284": 284,
616
+ "LABEL_285": 285,
617
+ "LABEL_286": 286,
618
+ "LABEL_287": 287,
619
+ "LABEL_288": 288,
620
+ "LABEL_289": 289,
621
+ "LABEL_29": 29,
622
+ "LABEL_290": 290,
623
+ "LABEL_291": 291,
624
+ "LABEL_292": 292,
625
+ "LABEL_293": 293,
626
+ "LABEL_294": 294,
627
+ "LABEL_295": 295,
628
+ "LABEL_296": 296,
629
+ "LABEL_297": 297,
630
+ "LABEL_298": 298,
631
+ "LABEL_299": 299,
632
+ "LABEL_3": 3,
633
+ "LABEL_30": 30,
634
+ "LABEL_300": 300,
635
+ "LABEL_301": 301,
636
+ "LABEL_302": 302,
637
+ "LABEL_303": 303,
638
+ "LABEL_304": 304,
639
+ "LABEL_305": 305,
640
+ "LABEL_306": 306,
641
+ "LABEL_307": 307,
642
+ "LABEL_308": 308,
643
+ "LABEL_309": 309,
644
+ "LABEL_31": 31,
645
+ "LABEL_310": 310,
646
+ "LABEL_311": 311,
647
+ "LABEL_312": 312,
648
+ "LABEL_313": 313,
649
+ "LABEL_314": 314,
650
+ "LABEL_315": 315,
651
+ "LABEL_316": 316,
652
+ "LABEL_317": 317,
653
+ "LABEL_318": 318,
654
+ "LABEL_319": 319,
655
+ "LABEL_32": 32,
656
+ "LABEL_320": 320,
657
+ "LABEL_321": 321,
658
+ "LABEL_322": 322,
659
+ "LABEL_323": 323,
660
+ "LABEL_324": 324,
661
+ "LABEL_325": 325,
662
+ "LABEL_326": 326,
663
+ "LABEL_327": 327,
664
+ "LABEL_328": 328,
665
+ "LABEL_329": 329,
666
+ "LABEL_33": 33,
667
+ "LABEL_330": 330,
668
+ "LABEL_331": 331,
669
+ "LABEL_332": 332,
670
+ "LABEL_333": 333,
671
+ "LABEL_334": 334,
672
+ "LABEL_335": 335,
673
+ "LABEL_336": 336,
674
+ "LABEL_337": 337,
675
+ "LABEL_338": 338,
676
+ "LABEL_339": 339,
677
+ "LABEL_34": 34,
678
+ "LABEL_340": 340,
679
+ "LABEL_341": 341,
680
+ "LABEL_342": 342,
681
+ "LABEL_343": 343,
682
+ "LABEL_344": 344,
683
+ "LABEL_345": 345,
684
+ "LABEL_346": 346,
685
+ "LABEL_347": 347,
686
+ "LABEL_348": 348,
687
+ "LABEL_349": 349,
688
+ "LABEL_35": 35,
689
+ "LABEL_350": 350,
690
+ "LABEL_351": 351,
691
+ "LABEL_352": 352,
692
+ "LABEL_353": 353,
693
+ "LABEL_354": 354,
694
+ "LABEL_355": 355,
695
+ "LABEL_356": 356,
696
+ "LABEL_357": 357,
697
+ "LABEL_358": 358,
698
+ "LABEL_359": 359,
699
+ "LABEL_36": 36,
700
+ "LABEL_360": 360,
701
+ "LABEL_361": 361,
702
+ "LABEL_362": 362,
703
+ "LABEL_363": 363,
704
+ "LABEL_364": 364,
705
+ "LABEL_365": 365,
706
+ "LABEL_366": 366,
707
+ "LABEL_367": 367,
708
+ "LABEL_368": 368,
709
+ "LABEL_369": 369,
710
+ "LABEL_37": 37,
711
+ "LABEL_370": 370,
712
+ "LABEL_371": 371,
713
+ "LABEL_372": 372,
714
+ "LABEL_373": 373,
715
+ "LABEL_374": 374,
716
+ "LABEL_375": 375,
717
+ "LABEL_376": 376,
718
+ "LABEL_377": 377,
719
+ "LABEL_378": 378,
720
+ "LABEL_379": 379,
721
+ "LABEL_38": 38,
722
+ "LABEL_380": 380,
723
+ "LABEL_381": 381,
724
+ "LABEL_382": 382,
725
+ "LABEL_383": 383,
726
+ "LABEL_384": 384,
727
+ "LABEL_385": 385,
728
+ "LABEL_386": 386,
729
+ "LABEL_387": 387,
730
+ "LABEL_388": 388,
731
+ "LABEL_389": 389,
732
+ "LABEL_39": 39,
733
+ "LABEL_390": 390,
734
+ "LABEL_4": 4,
735
+ "LABEL_40": 40,
736
+ "LABEL_41": 41,
737
+ "LABEL_42": 42,
738
+ "LABEL_43": 43,
739
+ "LABEL_44": 44,
740
+ "LABEL_45": 45,
741
+ "LABEL_46": 46,
742
+ "LABEL_47": 47,
743
+ "LABEL_48": 48,
744
+ "LABEL_49": 49,
745
+ "LABEL_5": 5,
746
+ "LABEL_50": 50,
747
+ "LABEL_51": 51,
748
+ "LABEL_52": 52,
749
+ "LABEL_53": 53,
750
+ "LABEL_54": 54,
751
+ "LABEL_55": 55,
752
+ "LABEL_56": 56,
753
+ "LABEL_57": 57,
754
+ "LABEL_58": 58,
755
+ "LABEL_59": 59,
756
+ "LABEL_6": 6,
757
+ "LABEL_60": 60,
758
+ "LABEL_61": 61,
759
+ "LABEL_62": 62,
760
+ "LABEL_63": 63,
761
+ "LABEL_64": 64,
762
+ "LABEL_65": 65,
763
+ "LABEL_66": 66,
764
+ "LABEL_67": 67,
765
+ "LABEL_68": 68,
766
+ "LABEL_69": 69,
767
+ "LABEL_7": 7,
768
+ "LABEL_70": 70,
769
+ "LABEL_71": 71,
770
+ "LABEL_72": 72,
771
+ "LABEL_73": 73,
772
+ "LABEL_74": 74,
773
+ "LABEL_75": 75,
774
+ "LABEL_76": 76,
775
+ "LABEL_77": 77,
776
+ "LABEL_78": 78,
777
+ "LABEL_79": 79,
778
+ "LABEL_8": 8,
779
+ "LABEL_80": 80,
780
+ "LABEL_81": 81,
781
+ "LABEL_82": 82,
782
+ "LABEL_83": 83,
783
+ "LABEL_84": 84,
784
+ "LABEL_85": 85,
785
+ "LABEL_86": 86,
786
+ "LABEL_87": 87,
787
+ "LABEL_88": 88,
788
+ "LABEL_89": 89,
789
+ "LABEL_9": 9,
790
+ "LABEL_90": 90,
791
+ "LABEL_91": 91,
792
+ "LABEL_92": 92,
793
+ "LABEL_93": 93,
794
+ "LABEL_94": 94,
795
+ "LABEL_95": 95,
796
+ "LABEL_96": 96,
797
+ "LABEL_97": 97,
798
+ "LABEL_98": 98,
799
+ "LABEL_99": 99
800
+ },
801
+ "layer_norm_eps": 1e-12,
802
+ "max_position_embeddings": 512,
803
+ "model_type": "bert",
804
+ "num_attention_heads": 12,
805
+ "num_hidden_layers": 6,
806
+ "num_relation_heads": 32,
807
+ "pad_token_id": 0,
808
+ "pooler_fc_size": 768,
809
+ "pooler_num_attention_heads": 12,
810
+ "pooler_num_fc_layers": 3,
811
+ "pooler_size_per_head": 128,
812
+ "pooler_type": "first_token_transform",
813
+ "position_embedding_type": "absolute",
814
+ "torch_dtype": "float32",
815
+ "transformers_version": "4.44.1",
816
+ "type_vocab_size": 2,
817
+ "use_cache": true,
818
+ "vocab_size": 21128
819
+ }
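The 391 `LABEL_*` classes declared in this config line up in count with the pronunciation entries of `polydict.json` added further down (keys `"1"`–`"391"`), which suggests an off-by-one or reserved-index convention between the two. The authoritative decoding lives in `g2p/g2p/chinese_model_g2p.py`; the following is only a minimal sketch of how the two files relate, assuming they are read from this repo's layout:

```python
# Minimal sketch, assuming the repo layout in this commit; the exact
# label-to-entry offset used at inference time is an assumption here
# (the authoritative decoding is in g2p/g2p/chinese_model_g2p.py).
import json

with open("g2p/sources/g2p_chinese_model/config.json") as f:
    config = json.load(f)
with open("g2p/sources/g2p_chinese_model/polydict.json") as f:
    polydict = json.load(f)  # keys "1".."391" -> "char{pinyin}"

print(len(config["id2label"]), len(polydict))  # 391 391

# Hypothetical decoding of a predicted class id k into a pronunciation:
k = 36
print(config["id2label"][str(k)], "->", polydict.get(str(k)))  # LABEL_36 -> 便{bian4}
```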
g2p/sources/g2p_chinese_model/poly_bert_model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8765d835ffdf9811c832d4dc7b6a552757aa8615c01d1184db716a50c20aebbc
+ size 76583333
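Note that the diff records only a Git LFS pointer for the ONNX model, not the weights themselves: three `key value` lines giving the spec version, the sha256 checksum, and the byte size (~76 MB). A sketch of parsing such a pointer:

```python
# Sketch: parse the Git LFS pointer stored in place of poly_bert_model.onnx.
pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:8765d835ffdf9811c832d4dc7b6a552757aa8615c01d1184db716a50c20aebbc
size 76583333"""

meta = dict(line.split(" ", 1) for line in pointer.splitlines())
print(meta["oid"])        # sha256 checksum of the real file
print(int(meta["size"]))  # 76583333 bytes (~76 MB)
```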
g2p/sources/g2p_chinese_model/polychar.txt ADDED
@@ -0,0 +1,159 @@
+ 丧
+ 中
+ 为
+ 乌
+ 乐
+ 了
+ 什
+ 仔
+ 令
+ 任
+ 会
+ 传
+ 佛
+ 供
+ 便
+ 倒
+ 假
+ 兴
+ 冠
+ 冲
+ 几
+ 分
+ 切
+ 划
+ 创
+ 剥
+ 勒
+ 区
+ 华
+ 单
+ 卜
+ 占
+ 卡
+ 卷
+ 厦
+ 参
+ 发
+ 只
+ 号
+ 同
+ 吐
+ 和
+ 喝
+ 圈
+ 地
+ 塞
+ 壳
+ 处
+ 奇
+ 奔
+ 好
+ 宁
+ 宿
+ 将
+ 少
+ 尽
+ 岗
+ 差
+ 巷
+ 帖
+ 干
+ 应
+ 度
+ 弹
+ 强
+ 当
+ 待
+ 得
+ 恶
+ 扁
+ 扇
+ 扎
+ 扫
+ 担
+ 挑
+ 据
+ 撒
+ 教
+ 散
+ 数
+ 斗
+ 晃
+ 曝
+ 曲
+ 更
+ 曾
+ 朝
+ 朴
+ 杆
+ 查
+ 校
+ 模
+ 横
+ 没
+ 泡
+ 济
+ 混
+ 漂
+ 炸
+ 熟
+ 燕
+ 片
+ 率
+ 畜
+ 的
+ 盛
+ 相
+ 省
+ 看
+ 着
+ 矫
+ 禁
+ 种
+ 称
+ 空
+ 答
+ 粘
+ 糊
+ 系
+ 累
+ 纤
+ 结
+ 给
+ 缝
+ 肖
+ 背
+ 脏
+ 舍
+ 色
+ 落
+ 蒙
+ 薄
+ 藏
+ 血
+ 行
+ 要
+ 观
+ 觉
+ 角
+ 解
+ 说
+ 调
+ 踏
+ 车
+ 转
+ 载
+ 还
+ 遂
+ 都
+ 重
+ 量
+ 钻
+ 铺
+ 长
+ 间
+ 降
+ 难
+ 露
+ 鲜
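`polychar.txt` enumerates the 159 polyphonic characters the classifier covers, one per line, in the same order in which they first appear in `polydict.json` below. A sketch of that invariant, assuming the repo layout in this commit:

```python
# Sketch (assumption: polychar.txt mirrors polydict.json's
# first-appearance order of distinct characters).
import json

with open("g2p/sources/g2p_chinese_model/polydict.json") as f:
    polydict = json.load(f)

chars = []
for i in range(1, len(polydict) + 1):
    ch = polydict[str(i)].split("{")[0]  # "便{bian4}" -> "便"
    if ch not in chars:
        chars.append(ch)

with open("g2p/sources/g2p_chinese_model/polychar.txt", encoding="utf-8") as f:
    assert chars == f.read().split()     # 159 characters, one per line
```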
g2p/sources/g2p_chinese_model/polydict.json ADDED
@@ -0,0 +1,393 @@
+ {
+ "1": "丧{sang1}",
+ "2": "丧{sang4}",
+ "3": "中{zhong1}",
+ "4": "中{zhong4}",
+ "5": "为{wei2}",
+ "6": "为{wei4}",
+ "7": "乌{wu1}",
+ "8": "乌{wu4}",
+ "9": "乐{lao4}",
+ "10": "乐{le4}",
+ "11": "乐{le5}",
+ "12": "乐{yao4}",
+ "13": "乐{yve4}",
+ "14": "了{le5}",
+ "15": "了{liao3}",
+ "16": "了{liao5}",
+ "17": "什{shen2}",
+ "18": "什{shi2}",
+ "19": "仔{zai3}",
+ "20": "仔{zai5}",
+ "21": "仔{zi3}",
+ "22": "仔{zi5}",
+ "23": "令{ling2}",
+ "24": "令{ling4}",
+ "25": "任{ren2}",
+ "26": "任{ren4}",
+ "27": "会{hui4}",
+ "28": "会{hui5}",
+ "29": "会{kuai4}",
+ "30": "传{chuan2}",
+ "31": "传{zhuan4}",
+ "32": "佛{fo2}",
+ "33": "佛{fu2}",
+ "34": "供{gong1}",
+ "35": "供{gong4}",
+ "36": "便{bian4}",
+ "37": "便{pian2}",
+ "38": "倒{dao3}",
+ "39": "倒{dao4}",
+ "40": "假{jia3}",
+ "41": "假{jia4}",
+ "42": "兴{xing1}",
+ "43": "兴{xing4}",
+ "44": "冠{guan1}",
+ "45": "冠{guan4}",
+ "46": "冲{chong1}",
+ "47": "冲{chong4}",
+ "48": "几{ji1}",
+ "49": "几{ji2}",
+ "50": "几{ji3}",
+ "51": "分{fen1}",
+ "52": "分{fen4}",
+ "53": "分{fen5}",
+ "54": "切{qie1}",
+ "55": "切{qie4}",
+ "56": "划{hua2}",
+ "57": "划{hua4}",
+ "58": "划{hua5}",
+ "59": "创{chuang1}",
+ "60": "创{chuang4}",
+ "61": "剥{bao1}",
+ "62": "剥{bo1}",
+ "63": "勒{le4}",
+ "64": "勒{le5}",
+ "65": "勒{lei1}",
+ "66": "区{ou1}",
+ "67": "区{qu1}",
+ "68": "华{hua2}",
+ "69": "华{hua4}",
+ "70": "单{chan2}",
+ "71": "单{dan1}",
+ "72": "单{shan4}",
+ "73": "卜{bo5}",
+ "74": "卜{bu3}",
+ "75": "占{zhan1}",
+ "76": "占{zhan4}",
+ "77": "卡{ka2}",
+ "78": "卡{ka3}",
+ "79": "卡{qia3}",
+ "80": "卷{jvan3}",
+ "81": "卷{jvan4}",
+ "82": "厦{sha4}",
+ "83": "厦{xia4}",
+ "84": "参{can1}",
+ "85": "参{cen1}",
+ "86": "参{shen1}",
+ "87": "发{fa1}",
+ "88": "发{fa4}",
+ "89": "发{fa5}",
+ "90": "只{zhi1}",
+ "91": "只{zhi3}",
+ "92": "号{hao2}",
+ "93": "号{hao4}",
+ "94": "号{hao5}",
+ "95": "同{tong2}",
+ "96": "同{tong4}",
+ "97": "同{tong5}",
+ "98": "吐{tu2}",
+ "99": "吐{tu3}",
+ "100": "吐{tu4}",
+ "101": "和{he2}",
+ "102": "和{he4}",
+ "103": "和{he5}",
+ "104": "和{huo2}",
+ "105": "和{huo4}",
+ "106": "和{huo5}",
+ "107": "喝{he1}",
+ "108": "喝{he4}",
+ "109": "圈{jvan4}",
+ "110": "圈{qvan1}",
+ "111": "圈{qvan5}",
+ "112": "地{de5}",
+ "113": "地{di4}",
+ "114": "地{di5}",
+ "115": "塞{sai1}",
+ "116": "塞{sai2}",
+ "117": "塞{sai4}",
+ "118": "塞{se4}",
+ "119": "壳{ke2}",
+ "120": "壳{qiao4}",
+ "121": "处{chu3}",
+ "122": "处{chu4}",
+ "123": "奇{ji1}",
+ "124": "奇{qi2}",
+ "125": "奔{ben1}",
+ "126": "奔{ben4}",
+ "127": "好{hao3}",
+ "128": "好{hao4}",
+ "129": "好{hao5}",
+ "130": "宁{ning2}",
+ "131": "宁{ning4}",
+ "132": "宁{ning5}",
+ "133": "宿{su4}",
+ "134": "宿{xiu3}",
+ "135": "宿{xiu4}",
+ "136": "将{jiang1}",
+ "137": "将{jiang4}",
+ "138": "少{shao3}",
+ "139": "少{shao4}",
+ "140": "尽{jin3}",
+ "141": "尽{jin4}",
+ "142": "岗{gang1}",
+ "143": "岗{gang3}",
+ "144": "差{cha1}",
+ "145": "差{cha4}",
+ "146": "差{chai1}",
+ "147": "差{ci1}",
+ "148": "巷{hang4}",
+ "149": "巷{xiang4}",
+ "150": "帖{tie1}",
+ "151": "帖{tie3}",
+ "152": "帖{tie4}",
+ "153": "干{gan1}",
+ "154": "干{gan4}",
+ "155": "应{ying1}",
+ "156": "应{ying4}",
+ "157": "应{ying5}",
+ "158": "度{du4}",
+ "159": "度{du5}",
+ "160": "度{duo2}",
+ "161": "弹{dan4}",
+ "162": "弹{tan2}",
+ "163": "弹{tan5}",
+ "164": "强{jiang4}",
+ "165": "强{qiang2}",
+ "166": "强{qiang3}",
+ "167": "当{dang1}",
+ "168": "当{dang4}",
+ "169": "当{dang5}",
+ "170": "待{dai1}",
+ "171": "待{dai4}",
+ "172": "得{de2}",
+ "173": "得{de5}",
+ "174": "得{dei3}",
+ "175": "得{dei5}",
+ "176": "恶{e3}",
+ "177": "恶{e4}",
+ "178": "恶{wu4}",
+ "179": "扁{bian3}",
+ "180": "扁{pian1}",
+ "181": "扇{shan1}",
+ "182": "扇{shan4}",
+ "183": "扎{za1}",
+ "184": "扎{zha1}",
+ "185": "扎{zha2}",
+ "186": "扫{sao3}",
+ "187": "扫{sao4}",
+ "188": "担{dan1}",
+ "189": "担{dan4}",
+ "190": "担{dan5}",
+ "191": "挑{tiao1}",
+ "192": "挑{tiao3}",
+ "193": "据{jv1}",
+ "194": "据{jv4}",
+ "195": "撒{sa1}",
+ "196": "撒{sa3}",
+ "197": "撒{sa5}",
+ "198": "教{jiao1}",
+ "199": "教{jiao4}",
+ "200": "散{san3}",
+ "201": "散{san4}",
+ "202": "散{san5}",
+ "203": "数{shu3}",
+ "204": "数{shu4}",
+ "205": "数{shu5}",
+ "206": "斗{dou3}",
+ "207": "斗{dou4}",
+ "208": "晃{huang3}",
+ "209": "曝{bao4}",
+ "210": "曲{qu1}",
+ "211": "曲{qu3}",
+ "212": "更{geng1}",
+ "213": "更{geng4}",
+ "214": "曾{ceng1}",
+ "215": "曾{ceng2}",
+ "216": "曾{zeng1}",
+ "217": "朝{chao2}",
+ "218": "朝{zhao1}",
+ "219": "朴{piao2}",
+ "220": "朴{pu2}",
+ "221": "朴{pu3}",
+ "222": "杆{gan1}",
+ "223": "杆{gan3}",
+ "224": "查{cha2}",
+ "225": "查{zha1}",
+ "226": "校{jiao4}",
+ "227": "校{xiao4}",
+ "228": "模{mo2}",
+ "229": "模{mu2}",
+ "230": "横{heng2}",
+ "231": "横{heng4}",
+ "232": "没{mei2}",
+ "233": "没{mo4}",
+ "234": "泡{pao1}",
+ "235": "泡{pao4}",
+ "236": "泡{pao5}",
+ "237": "济{ji3}",
+ "238": "济{ji4}",
+ "239": "混{hun2}",
+ "240": "混{hun3}",
+ "241": "混{hun4}",
+ "242": "混{hun5}",
+ "243": "漂{piao1}",
+ "244": "漂{piao3}",
+ "245": "漂{piao4}",
+ "246": "炸{zha2}",
+ "247": "炸{zha4}",
+ "248": "熟{shou2}",
+ "249": "熟{shu2}",
+ "250": "燕{yan1}",
+ "251": "燕{yan4}",
+ "252": "片{pian1}",
+ "253": "片{pian4}",
+ "254": "率{lv4}",
+ "255": "率{shuai4}",
+ "256": "畜{chu4}",
+ "257": "畜{xu4}",
+ "258": "的{de5}",
+ "259": "的{di1}",
+ "260": "的{di2}",
+ "261": "的{di4}",
+ "262": "的{di5}",
+ "263": "盛{cheng2}",
+ "264": "盛{sheng4}",
+ "265": "相{xiang1}",
+ "266": "相{xiang4}",
+ "267": "相{xiang5}",
+ "268": "省{sheng3}",
+ "269": "省{xing3}",
+ "270": "看{kan1}",
+ "271": "看{kan4}",
+ "272": "看{kan5}",
+ "273": "着{zhao1}",
+ "274": "着{zhao2}",
+ "275": "着{zhao5}",
+ "276": "着{zhe5}",
+ "277": "着{zhuo2}",
+ "278": "着{zhuo5}",
+ "279": "矫{jiao3}",
+ "280": "禁{jin1}",
+ "281": "禁{jin4}",
+ "282": "种{zhong3}",
+ "283": "种{zhong4}",
+ "284": "称{chen4}",
+ "285": "称{cheng1}",
+ "286": "空{kong1}",
+ "287": "空{kong4}",
+ "288": "答{da1}",
+ "289": "答{da2}",
+ "290": "粘{nian2}",
+ "291": "粘{zhan1}",
+ "292": "糊{hu2}",
+ "293": "糊{hu5}",
+ "294": "系{ji4}",
+ "295": "系{xi4}",
+ "296": "系{xi5}",
+ "297": "累{lei2}",
+ "298": "累{lei3}",
+ "299": "累{lei4}",
+ "300": "累{lei5}",
+ "301": "纤{qian4}",
+ "302": "纤{xian1}",
+ "303": "结{jie1}",
+ "304": "结{jie2}",
+ "305": "结{jie5}",
+ "306": "给{gei3}",
+ "307": "给{gei5}",
+ "308": "给{ji3}",
+ "309": "缝{feng2}",
+ "310": "缝{feng4}",
+ "311": "缝{feng5}",
+ "312": "肖{xiao1}",
+ "313": "肖{xiao4}",
+ "314": "背{bei1}",
+ "315": "背{bei4}",
+ "316": "脏{zang1}",
+ "317": "脏{zang4}",
+ "318": "舍{she3}",
+ "319": "舍{she4}",
+ "320": "色{se4}",
+ "321": "色{shai3}",
+ "322": "落{lao4}",
+ "323": "落{luo4}",
+ "324": "蒙{meng1}",
+ "325": "蒙{meng2}",
+ "326": "蒙{meng3}",
+ "327": "薄{bao2}",
+ "328": "薄{bo2}",
+ "329": "薄{bo4}",
+ "330": "藏{cang2}",
+ "331": "藏{zang4}",
+ "332": "血{xie3}",
+ "333": "血{xue4}",
+ "334": "行{hang2}",
+ "335": "行{hang5}",
+ "336": "行{heng5}",
+ "337": "行{xing2}",
+ "338": "行{xing4}",
+ "339": "要{yao1}",
+ "340": "要{yao4}",
+ "341": "观{guan1}",
+ "342": "观{guan4}",
+ "343": "觉{jiao4}",
+ "344": "觉{jiao5}",
+ "345": "觉{jve2}",
+ "346": "角{jiao3}",
+ "347": "角{jve2}",
+ "348": "解{jie3}",
+ "349": "解{jie4}",
+ "350": "解{xie4}",
+ "351": "说{shui4}",
+ "352": "说{shuo1}",
+ "353": "调{diao4}",
+ "354": "调{tiao2}",
+ "355": "踏{ta1}",
+ "356": "踏{ta4}",
+ "357": "车{che1}",
+ "358": "车{jv1}",
+ "359": "转{zhuan3}",
+ "360": "转{zhuan4}",
+ "361": "载{zai3}",
+ "362": "载{zai4}",
+ "363": "还{hai2}",
+ "364": "还{huan2}",
+ "365": "遂{sui2}",
+ "366": "遂{sui4}",
+ "367": "都{dou1}",
+ "368": "都{du1}",
+ "369": "重{chong2}",
+ "370": "重{zhong4}",
+ "371": "量{liang2}",
+ "372": "量{liang4}",
+ "373": "量{liang5}",
+ "374": "钻{zuan1}",
+ "375": "钻{zuan4}",
+ "376": "铺{pu1}",
+ "377": "铺{pu4}",
+ "378": "长{chang2}",
+ "379": "长{chang3}",
+ "380": "长{zhang3}",
+ "381": "间{jian1}",
+ "382": "间{jian4}",
+ "383": "降{jiang4}",
+ "384": "降{xiang2}",
+ "385": "难{nan2}",
+ "386": "难{nan4}",
+ "387": "难{nan5}",
+ "388": "露{lou4}",
+ "389": "露{lu4}",
+ "390": "鲜{xian1}",
+ "391": "鲜{xian3}"
+ }
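Each value encodes one character plus one candidate pronunciation in numbered-tone pinyin, with `v` standing in for `ü` (e.g. `率{lv4}`, `觉{jve2}`) and tone `5` marking the neutral tone. A tiny parser sketch:

```python
# Sketch: split a polydict value such as "率{lv4}" into its parts.
import re

def parse_entry(entry: str):
    char, pron = re.match(r"(.)\{(.+)\}", entry).groups()
    return char, pron[:-1], int(pron[-1])  # tone 5 = neutral tone; "v" = ü

print(parse_entry("率{lv4}"))   # ('率', 'lv', 4)
print(parse_entry("觉{jve2}"))  # ('觉', 'jve', 2)
```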
g2p/sources/g2p_chinese_model/polydict_r.json ADDED
@@ -0,0 +1,393 @@
+ {
+ "丧{sang1}": 1,
+ "丧{sang4}": 2,
+ "中{zhong1}": 3,
+ "中{zhong4}": 4,
+ "为{wei2}": 5,
+ "为{wei4}": 6,
+ "乌{wu1}": 7,
+ "乌{wu4}": 8,
+ "乐{lao4}": 9,
+ "乐{le4}": 10,
+ "乐{le5}": 11,
+ "乐{yao4}": 12,
+ "乐{yve4}": 13,
+ "了{le5}": 14,
+ "了{liao3}": 15,
+ "了{liao5}": 16,
+ "什{shen2}": 17,
+ "什{shi2}": 18,
+ "仔{zai3}": 19,
+ "仔{zai5}": 20,
+ "仔{zi3}": 21,
+ "仔{zi5}": 22,
+ "令{ling2}": 23,
+ "令{ling4}": 24,
+ "任{ren2}": 25,
+ "任{ren4}": 26,
+ "会{hui4}": 27,
+ "会{hui5}": 28,
+ "会{kuai4}": 29,
+ "传{chuan2}": 30,
+ "传{zhuan4}": 31,
+ "佛{fo2}": 32,
+ "佛{fu2}": 33,
+ "供{gong1}": 34,
+ "供{gong4}": 35,
+ "便{bian4}": 36,
+ "便{pian2}": 37,
+ "倒{dao3}": 38,
+ "倒{dao4}": 39,
+ "假{jia3}": 40,
+ "假{jia4}": 41,
+ "兴{xing1}": 42,
+ "兴{xing4}": 43,
+ "冠{guan1}": 44,
+ "冠{guan4}": 45,
+ "冲{chong1}": 46,
+ "冲{chong4}": 47,
+ "几{ji1}": 48,
+ "几{ji2}": 49,
+ "几{ji3}": 50,
+ "分{fen1}": 51,
+ "分{fen4}": 52,
+ "分{fen5}": 53,
+ "切{qie1}": 54,
+ "切{qie4}": 55,
+ "划{hua2}": 56,
+ "划{hua4}": 57,
+ "划{hua5}": 58,
+ "创{chuang1}": 59,
+ "创{chuang4}": 60,
+ "剥{bao1}": 61,
+ "剥{bo1}": 62,
+ "勒{le4}": 63,
+ "勒{le5}": 64,
+ "勒{lei1}": 65,
+ "区{ou1}": 66,
+ "区{qu1}": 67,
+ "华{hua2}": 68,
+ "华{hua4}": 69,
+ "单{chan2}": 70,
+ "单{dan1}": 71,
+ "单{shan4}": 72,
+ "卜{bo5}": 73,
+ "卜{bu3}": 74,
+ "占{zhan1}": 75,
+ "占{zhan4}": 76,
+ "卡{ka2}": 77,
+ "卡{ka3}": 78,
+ "卡{qia3}": 79,
+ "卷{jvan3}": 80,
+ "卷{jvan4}": 81,
+ "厦{sha4}": 82,
+ "厦{xia4}": 83,
+ "参{can1}": 84,
+ "参{cen1}": 85,
+ "参{shen1}": 86,
+ "发{fa1}": 87,
+ "发{fa4}": 88,
+ "发{fa5}": 89,
+ "只{zhi1}": 90,
+ "只{zhi3}": 91,
+ "号{hao2}": 92,
+ "号{hao4}": 93,
+ "号{hao5}": 94,
+ "同{tong2}": 95,
+ "同{tong4}": 96,
+ "同{tong5}": 97,
+ "吐{tu2}": 98,
+ "吐{tu3}": 99,
+ "吐{tu4}": 100,
+ "和{he2}": 101,
+ "和{he4}": 102,
+ "和{he5}": 103,
+ "和{huo2}": 104,
+ "和{huo4}": 105,
+ "和{huo5}": 106,
+ "喝{he1}": 107,
+ "喝{he4}": 108,
+ "圈{jvan4}": 109,
+ "圈{qvan1}": 110,
+ "圈{qvan5}": 111,
+ "地{de5}": 112,
+ "地{di4}": 113,
+ "地{di5}": 114,
+ "塞{sai1}": 115,
+ "塞{sai2}": 116,
+ "塞{sai4}": 117,
+ "塞{se4}": 118,
+ "壳{ke2}": 119,
+ "壳{qiao4}": 120,
+ "处{chu3}": 121,
+ "处{chu4}": 122,
+ "奇{ji1}": 123,
+ "奇{qi2}": 124,
+ "奔{ben1}": 125,
+ "奔{ben4}": 126,
+ "好{hao3}": 127,
+ "好{hao4}": 128,
+ "好{hao5}": 129,
+ "宁{ning2}": 130,
+ "宁{ning4}": 131,
+ "宁{ning5}": 132,
+ "宿{su4}": 133,
+ "宿{xiu3}": 134,
+ "宿{xiu4}": 135,
+ "将{jiang1}": 136,
+ "将{jiang4}": 137,
+ "少{shao3}": 138,
+ "少{shao4}": 139,
+ "尽{jin3}": 140,
+ "尽{jin4}": 141,
+ "岗{gang1}": 142,
+ "岗{gang3}": 143,
+ "差{cha1}": 144,
+ "差{cha4}": 145,
+ "差{chai1}": 146,
+ "差{ci1}": 147,
+ "巷{hang4}": 148,
+ "巷{xiang4}": 149,
+ "帖{tie1}": 150,
+ "帖{tie3}": 151,
+ "帖{tie4}": 152,
+ "干{gan1}": 153,
+ "干{gan4}": 154,
+ "应{ying1}": 155,
+ "应{ying4}": 156,
+ "应{ying5}": 157,
+ "度{du4}": 158,
+ "度{du5}": 159,
+ "度{duo2}": 160,
+ "弹{dan4}": 161,
+ "弹{tan2}": 162,
+ "弹{tan5}": 163,
+ "强{jiang4}": 164,
+ "强{qiang2}": 165,
+ "强{qiang3}": 166,
+ "当{dang1}": 167,
+ "当{dang4}": 168,
+ "当{dang5}": 169,
+ "待{dai1}": 170,
+ "待{dai4}": 171,
+ "得{de2}": 172,
+ "得{de5}": 173,
+ "得{dei3}": 174,
+ "得{dei5}": 175,
+ "恶{e3}": 176,
+ "恶{e4}": 177,
+ "恶{wu4}": 178,
+ "扁{bian3}": 179,
+ "扁{pian1}": 180,
+ "扇{shan1}": 181,
+ "扇{shan4}": 182,
+ "扎{za1}": 183,
+ "扎{zha1}": 184,
+ "扎{zha2}": 185,
+ "扫{sao3}": 186,
+ "扫{sao4}": 187,
+ "担{dan1}": 188,
+ "担{dan4}": 189,
+ "担{dan5}": 190,
+ "挑{tiao1}": 191,
+ "挑{tiao3}": 192,
+ "据{jv1}": 193,
+ "据{jv4}": 194,
+ "撒{sa1}": 195,
+ "撒{sa3}": 196,
+ "撒{sa5}": 197,
+ "教{jiao1}": 198,
+ "教{jiao4}": 199,
+ "散{san3}": 200,
+ "散{san4}": 201,
+ "散{san5}": 202,
+ "数{shu3}": 203,
+ "数{shu4}": 204,
+ "数{shu5}": 205,
+ "斗{dou3}": 206,
+ "斗{dou4}": 207,
+ "晃{huang3}": 208,
+ "曝{bao4}": 209,
+ "曲{qu1}": 210,
+ "曲{qu3}": 211,
+ "更{geng1}": 212,
+ "更{geng4}": 213,
+ "曾{ceng1}": 214,
+ "曾{ceng2}": 215,
+ "曾{zeng1}": 216,
+ "朝{chao2}": 217,
+ "朝{zhao1}": 218,
+ "朴{piao2}": 219,
+ "朴{pu2}": 220,
+ "朴{pu3}": 221,
+ "杆{gan1}": 222,
+ "杆{gan3}": 223,
+ "查{cha2}": 224,
+ "查{zha1}": 225,
+ "校{jiao4}": 226,
+ "校{xiao4}": 227,
+ "模{mo2}": 228,
+ "模{mu2}": 229,
+ "横{heng2}": 230,
+ "横{heng4}": 231,
+ "没{mei2}": 232,
+ "没{mo4}": 233,
+ "泡{pao1}": 234,
+ "泡{pao4}": 235,
+ "泡{pao5}": 236,
+ "济{ji3}": 237,
+ "济{ji4}": 238,
+ "混{hun2}": 239,
+ "混{hun3}": 240,
+ "混{hun4}": 241,
+ "混{hun5}": 242,
+ "漂{piao1}": 243,
+ "漂{piao3}": 244,
+ "漂{piao4}": 245,
+ "炸{zha2}": 246,
+ "炸{zha4}": 247,
+ "熟{shou2}": 248,
+ "熟{shu2}": 249,
+ "燕{yan1}": 250,
+ "燕{yan4}": 251,
+ "片{pian1}": 252,
+ "片{pian4}": 253,
+ "率{lv4}": 254,
+ "率{shuai4}": 255,
+ "畜{chu4}": 256,
+ "畜{xu4}": 257,
+ "的{de5}": 258,
+ "的{di1}": 259,
+ "的{di2}": 260,
+ "的{di4}": 261,
+ "的{di5}": 262,
+ "盛{cheng2}": 263,
+ "盛{sheng4}": 264,
+ "相{xiang1}": 265,
+ "相{xiang4}": 266,
+ "相{xiang5}": 267,
+ "省{sheng3}": 268,
+ "省{xing3}": 269,
+ "看{kan1}": 270,
+ "看{kan4}": 271,
+ "看{kan5}": 272,
+ "着{zhao1}": 273,
+ "着{zhao2}": 274,
+ "着{zhao5}": 275,
+ "着{zhe5}": 276,
+ "着{zhuo2}": 277,
+ "着{zhuo5}": 278,
+ "矫{jiao3}": 279,
+ "禁{jin1}": 280,
+ "禁{jin4}": 281,
+ "种{zhong3}": 282,
+ "种{zhong4}": 283,
+ "称{chen4}": 284,
+ "称{cheng1}": 285,
+ "空{kong1}": 286,
+ "空{kong4}": 287,
+ "答{da1}": 288,
+ "答{da2}": 289,
+ "粘{nian2}": 290,
+ "粘{zhan1}": 291,
+ "糊{hu2}": 292,
+ "糊{hu5}": 293,
+ "系{ji4}": 294,
+ "系{xi4}": 295,
+ "系{xi5}": 296,
+ "累{lei2}": 297,
+ "累{lei3}": 298,
+ "累{lei4}": 299,
+ "累{lei5}": 300,
+ "纤{qian4}": 301,
+ "纤{xian1}": 302,
+ "结{jie1}": 303,
+ "结{jie2}": 304,
+ "结{jie5}": 305,
+ "给{gei3}": 306,
+ "给{gei5}": 307,
+ "给{ji3}": 308,
+ "缝{feng2}": 309,
+ "缝{feng4}": 310,
+ "缝{feng5}": 311,
+ "肖{xiao1}": 312,
+ "肖{xiao4}": 313,
+ "背{bei1}": 314,
+ "背{bei4}": 315,
+ "脏{zang1}": 316,
+ "脏{zang4}": 317,
+ "舍{she3}": 318,
+ "舍{she4}": 319,
+ "色{se4}": 320,
+ "色{shai3}": 321,
+ "落{lao4}": 322,
+ "落{luo4}": 323,
+ "蒙{meng1}": 324,
+ "蒙{meng2}": 325,
+ "蒙{meng3}": 326,
+ "薄{bao2}": 327,
+ "薄{bo2}": 328,
+ "薄{bo4}": 329,
+ "藏{cang2}": 330,
+ "藏{zang4}": 331,
+ "血{xie3}": 332,
+ "血{xue4}": 333,
+ "行{hang2}": 334,
+ "行{hang5}": 335,
+ "行{heng5}": 336,
+ "行{xing2}": 337,
+ "行{xing4}": 338,
+ "要{yao1}": 339,
+ "要{yao4}": 340,
+ "观{guan1}": 341,
+ "观{guan4}": 342,
+ "觉{jiao4}": 343,
+ "觉{jiao5}": 344,
+ "觉{jve2}": 345,
+ "角{jiao3}": 346,
+ "角{jve2}": 347,
+ "解{jie3}": 348,
+ "解{jie4}": 349,
+ "解{xie4}": 350,
+ "说{shui4}": 351,
+ "说{shuo1}": 352,
+ "调{diao4}": 353,
+ "调{tiao2}": 354,
+ "踏{ta1}": 355,
+ "踏{ta4}": 356,
+ "车{che1}": 357,
+ "车{jv1}": 358,
+ "转{zhuan3}": 359,
+ "转{zhuan4}": 360,
+ "载{zai3}": 361,
+ "载{zai4}": 362,
+ "还{hai2}": 363,
+ "还{huan2}": 364,
+ "遂{sui2}": 365,
+ "遂{sui4}": 366,
+ "都{dou1}": 367,
+ "都{du1}": 368,
+ "重{chong2}": 369,
+ "重{zhong4}": 370,
+ "量{liang2}": 371,
+ "量{liang4}": 372,
+ "量{liang5}": 373,
+ "钻{zuan1}": 374,
+ "钻{zuan4}": 375,
+ "铺{pu1}": 376,
+ "铺{pu4}": 377,
+ "长{chang2}": 378,
+ "长{chang3}": 379,
+ "长{zhang3}": 380,
+ "间{jian1}": 381,
+ "间{jian4}": 382,
+ "降{jiang4}": 383,
+ "降{xiang2}": 384,
+ "难{nan2}": 385,
+ "难{nan4}": 386,
+ "难{nan5}": 387,
+ "露{lou4}": 388,
+ "露{lu4}": 389,
+ "鲜{xian1}": 390,
+ "鲜{xian3}": 391
+ }
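`polydict_r.json` is the exact inverse of `polydict.json`, mapping each `char{pinyin}` string back to its class index, so that training targets can be encoded and model predictions decoded with the same pair of files. A quick consistency check, assuming both files are read from this repo's layout:

```python
# Sanity-check sketch: polydict.json and polydict_r.json must be inverses.
import json

with open("g2p/sources/g2p_chinese_model/polydict.json") as f:
    idx_to_pron = json.load(f)   # "1" -> "丧{sang1}"
with open("g2p/sources/g2p_chinese_model/polydict_r.json") as f:
    pron_to_idx = json.load(f)   # "丧{sang1}" -> 1

assert len(idx_to_pron) == len(pron_to_idx) == 391
assert all(pron_to_idx[v] == int(k) for k, v in idx_to_pron.items())
print("391 pronunciation classes; inverse mapping is consistent")
```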