nithinraok committed
Commit d378233 · verified · 1 Parent(s): 06aadb5

Update to support zero GPU (#5)

- update model space (a7d8cee5fdad58415a5ac619a8666e61ac73d5ee)
- remove yt support (e6218d687975e14299f0c7a0560f668283b5cabc)

Files changed (3):
  1. app.py +14 -78
  2. nemo_align.py +7 -8
  3. requirements.txt +1 -1
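
This commit ports the Space to ZeroGPU: hardware is no longer held permanently, so CUDA work has to happen inside functions decorated with `@spaces.GPU`, which attach a pooled GPU only for the duration of the call. A minimal sketch of the pattern (the function and workload below are illustrative, not code from this commit):

```python
import spaces
import torch

@spaces.GPU  # a pooled GPU is attached only while this call runs
def gpu_work(x: torch.Tensor) -> torch.Tensor:
    # On a ZeroGPU Space, CUDA is available inside the decorated function.
    return (x.to("cuda") * 2).cpu()
```

Longer jobs can reserve extra time via the decorator's argument, e.g. `@spaces.GPU(duration=120)`.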
app.py CHANGED
@@ -2,7 +2,7 @@
 import subprocess
 import torch
 import gradio as gr
-import yt_dlp
+import spaces
 import pandas as pd
 from nemo.collections.asr.models import ASRModel
 from nemo_align import align_tdt_to_ctc_timestamps
@@ -12,6 +12,7 @@ import os
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 def process_audio(input_file, output_file):
+    gr.Info("Processing audio to single channel and sampling rate to 16000")
     command = [
         'sox', input_file,
         output_file,
@@ -20,6 +21,7 @@ def process_audio(input_file, output_file):
     ]
     try:
         subprocess.run(command, check=True)
+        gr.Info("Audio processed successfully")
         return output_file
     except:
         raise gr.Error("Failed to convert audio to single channel and sampling rate to 16000")
@@ -38,55 +40,8 @@ def get_dataframe_segments(segments):
     return df
 
 
-def get_video_info(url):
-    ydl_opts = {
-        'quiet': True,
-        'skip-download': True,
-    }
-
-    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
-        try:
-            info = ydl.extract_info(url, download=False)
-        except:
-            raise gr.Error("Failed to extract video info from Youtube")
-    return info
-
-def download_audio(url):
-    ydl_opts = {
-        'format': 'bestaudio/best,channels:1',
-        'quiet': True,
-        'outtmpl': 'audio_file',
-        'postprocessors': [{
-            'key': 'FFmpegExtractAudio',
-            'preferredcodec': 'flac',
-            'preferredquality': '192',
-        }],
-    }
-
-    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
-        try:
-            ydl.download([url])
-        except yt_dlp.utils.DownloadError as err:
-            raise gr.Error(str(err))
-
-    return process_audio('audio_file.flac', 'processed_file.flac')
-
-
-def get_audio_from_youtube(url):
-    info = get_video_info(url)
-    duration = info.get('duration', 0)  # Duration in seconds
-    video_id = info.get('id', None)
-
-    html = f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
-
-    if duration > 2*60*60:  # 2 hrs change later based on GPU
-        return gr.Error(str("For GPU {}, single pass maximum audio can be 2hrs"))
-    else:
-        return download_audio(url), html
-
-
 def get_transcripts(audio_path, model):
-    with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
+    with torch.amp.autocast(device, dtype=torch.bfloat16, enabled=True):
         with torch.inference_mode():
             text = model.transcribe(audio=[audio_path], )
     return text
@@ -101,31 +56,20 @@ def pick_asr_model():
 
 asr_model = pick_asr_model()
 
-def run_nemo_models(url, microphone, audio_path):
-    html = None
-    if url is None or len(url)<2:
-        path1 = microphone if microphone else audio_path
-    else:
-        gr.Info("Downloading and processing audio from Youtube")
-        path1, html = get_audio_from_youtube(url)
+@spaces.GPU
+def run_nemo_models(microphone, audio_path):
+    path1 = microphone if microphone else audio_path
+
+    new_path = process_audio(path1, "processed_audio.flac")
 
     gr.Info("Running NeMo Model")
-    text = get_transcripts(path1, asr_model)
+    text = get_transcripts(new_path, asr_model)
 
-    segments = align_tdt_to_ctc_timestamps(text, asr_model, path1)
+    segments = align_tdt_to_ctc_timestamps(text, asr_model, new_path)
 
     df = get_dataframe_segments(segments)
 
-    return df, html
-
-def clear_youtube_link():
-    # Remove .flac files in current directory
-    file_list = os.listdir()
-    for file in file_list:
-        if file.endswith(".flac"):
-            os.remove(file)
-
-    return None
+    return df
 
 
 # def run_speaker_diarization()
@@ -143,17 +87,12 @@ with gr.Blocks(
 ) as demo:
     gr.HTML("<h1 style='text-align: center'>Transcription with timestamps using Parakeet TDT-CTC</h1>")
     gr.Markdown('''
-    Choose between different sources of audio (Microphone, Audio File, Youtube Video) to transcribe along with timestamps.
+    Choose between different sources of audio (Microphone, Audio File) to transcribe along with timestamps.
 
    Parakeet models with limited attention are quite fast due to their limited attention mechanism. The current model with 1.1B parameters can transcribe very long audios upto 11 hrs on A6000 GPU in a single pass.
 
    Model used: [nvidia/parakeet-tdt_ctc-1.1b](https://huggingface.co/nvidia/parakeet-tdt_ctc-1.1b).
    ''')
-    # This block is for reading audio from MIC
-    with gr.Tab('Audio from Youtube'):
-        with gr.Row():
-            yt_link = gr.Textbox(value=None, label='Enter Youtube Link', type='text')
-            yt_render = gr.HTML()
 
     with gr.Tab('Audio From File'):
         file_input = gr.Audio(sources='upload', label='Upload Audio', type='filepath')
@@ -173,10 +112,7 @@ with gr.Blocks(
     time_stamp = gr.DataFrame(wrap=True, label='Speech Recognition with TimeStamps',
                               row_count=(1, "dynamic"), headers=['start_time', 'end_time', 'text'])
 
-    # b1.click(run_nemo_models, inputs=[file_input, mic_input, yt_link], outputs=[text_output, yt_render])
-
-    b2.click(run_nemo_models, inputs=[yt_link, file_input, mic_input], outputs=[time_stamp, yt_render]).then(
-        clear_youtube_link, None, yt_link, queue=False) #here clean up passing None to audio.
+    b2.click(run_nemo_models, inputs=[file_input, mic_input], outputs=[time_stamp])
 
 demo.queue(True)
 demo.launch(share=True, debug=True)
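
The hunks above cut off the middle of the `sox` argument list, so the exact flags are not visible here. A minimal sketch of a mono, 16 kHz conversion consistent with the function's error message (the effect chain is an assumption, not the commit's literal command):

```python
import subprocess

def to_mono_16k(input_file: str, output_file: str) -> str:
    # Assumed sox effects: downmix to one channel, resample to 16000 Hz.
    command = ['sox', input_file, output_file, 'channels', '1', 'rate', '16000']
    subprocess.run(command, check=True)  # raises CalledProcessError on failure
    return output_file
```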
nemo_align.py CHANGED
@@ -6,7 +6,7 @@ from nemo.utils import logging
 from pathlib import Path
 from viterbi_decoding import viterbi_decoding
 from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig
-
+import spaces
 BLANK_TOKEN = "<b>"
 
 SPACE_TOKEN = "<space>"
@@ -435,9 +435,9 @@ def get_start_end_for_segments(word_timestamps):
 
     return segment_timestamps
 
-
+@spaces.GPU
 def align_tdt_to_ctc_timestamps(tdt_txt, model, audio_filepath):
-    tdt_txt = tdt_txt[0][0] if tdt_txt is not None else tdt_txt
+    tdt_txt = tdt_txt[0].text if tdt_txt is not None else tdt_txt
     if isinstance(model, EncDecHybridRNNTCTCModel):
         ctc_cfg = CTCDecodingConfig()
         ctc_cfg.decoding = "greedy_batch"
@@ -445,12 +445,11 @@ def align_tdt_to_ctc_timestamps(tdt_txt, model, audio_filepath):
     else:
         raise ValueError("Currently supporting hybrid models")
 
-    if torch.cuda.is_available():
-        viterbi_device = torch.device('cuda')
-    else:
-        viterbi_device = torch.device('cpu')
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    viterbi_device = torch.device(device)
 
-    with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
+    with torch.amp.autocast(device_type=device, dtype=torch.bfloat16, enabled=True):
         with torch.inference_mode():
             hypotheses = model.transcribe([audio_filepath], return_hypotheses=True, batch_size=1)
 
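Two API shifts drive these hunks: `torch.cuda.amp.autocast` is deprecated in favor of the device-agnostic `torch.amp.autocast(device_type=...)`, and newer NeMo `transcribe()` calls return `Hypothesis` objects, hence `tdt_txt[0][0]` becoming `tdt_txt[0].text`. Note the autocast flag also flips from `enabled=False` to `enabled=True`, so bfloat16 mixed precision is actually active after this change. A minimal sketch of the new spelling (the matmul is a stand-in workload, not the commit's model call):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Device-agnostic replacement for the deprecated torch.cuda.amp.autocast
with torch.amp.autocast(device_type=device, dtype=torch.bfloat16, enabled=True):
    with torch.inference_mode():
        y = torch.ones(8, 8, device=device) @ torch.ones(8, 8, device=device)

print(y.dtype)  # torch.bfloat16: matmuls run in reduced precision under autocast
```
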
requirements.txt CHANGED
@@ -1,4 +1,4 @@
 Cython
 packaging
-git+https://github.com/NVIDIA/NeMo.git@r2.0.0#egg=nemo_toolkit[asr]
+git+https://github.com/NVIDIA/NeMo.git@r2.5.3#egg=nemo_toolkit[asr]
 yt_dlp
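
A quick sanity check that the bumped pin resolved in the Space's environment (assuming the requirement installed cleanly; `__version__` is NeMo's standard version attribute):

```python
import nemo
print(nemo.__version__)  # expected to report a 2.5.x build for the r2.5.3 pin
```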