from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union

import torch

from nemo.collections.asr.models import EncDecHybridRNNTCTCModel
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig
from nemo.utils import logging

import spaces

from viterbi_decoding import viterbi_decoding

BLANK_TOKEN = "<b>"
SPACE_TOKEN = "<space>"
V_NEGATIVE_NUM = -3.4e38
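# Note: -3.4e38 is approximately the most negative finite float32 value; it is used below
# as a stand-in for log(0) when padding the log-probability tensor for Viterbi decoding.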
@dataclass
class Token:
    text: str = None
    text_cased: str = None
    s_start: int = None
    s_end: int = None
    t_start: float = None
    t_end: float = None


@dataclass
class Word:
    text: str = None
    s_start: int = None
    s_end: int = None
    t_start: float = None
    t_end: float = None
    tokens: List[Token] = field(default_factory=list)


@dataclass
class Segment:
    text: str = None
    s_start: int = None
    s_end: int = None
    t_start: float = None
    t_end: float = None
    words_and_tokens: List[Union[Word, Token]] = field(default_factory=list)


@dataclass
class Utterance:
    token_ids_with_blanks: List[int] = field(default_factory=list)
    segments_and_tokens: List[Union[Segment, Token]] = field(default_factory=list)
    text: str = None
    pred_text: str = None
    audio_filepath: str = None
    utt_id: str = None
    saved_output_files: dict = field(default_factory=dict)
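# Throughout these dataclasses, s_start/s_end are indices into Utterance.token_ids_with_blanks
# (the CTC state sequence, including blanks), while t_start/t_end are times in seconds,
# filled in later from the Viterbi alignment and the model's output timestep duration.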
def is_sub_or_superscript_pair(ref_text, text):
    """returns True if ref_text is a subscript or superscript version of text"""
    sub_or_superscript_to_num = {
        "⁰": "0",
        "¹": "1",
        "²": "2",
        "³": "3",
        "⁴": "4",
        "⁵": "5",
        "⁶": "6",
        "⁷": "7",
        "⁸": "8",
        "⁹": "9",
        "₀": "0",
        "₁": "1",
        "₂": "2",
        "₃": "3",
        "₄": "4",
        "₅": "5",
        "₆": "6",
        "₇": "7",
        "₈": "8",
        "₉": "9",
    }

    if text in sub_or_superscript_to_num:
        if sub_or_superscript_to_num[text] == ref_text:
            return True
    return False
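# e.g. is_sub_or_superscript_pair("2", "²") -> True, since "²" maps to "2".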
def restore_token_case(word, word_tokens):

    # remove repeated "▁" and "_" from word as that is what the tokenizer will do
    while "▁▁" in word:
        word = word.replace("▁▁", "▁")
    while "__" in word:
        word = word.replace("__", "_")

    word_tokens_cased = []
    word_char_pointer = 0

    for token in word_tokens:
        token_cased = ""

        for token_char in token:
            if token_char == word[word_char_pointer]:
                token_cased += token_char
                word_char_pointer += 1
            else:
                if token_char.upper() == word[word_char_pointer] or is_sub_or_superscript_pair(
                    token_char, word[word_char_pointer]
                ):
                    token_cased += token_char.upper()
                    word_char_pointer += 1
                else:
                    if token_char == "▁" or token_char == "_":
                        if word[word_char_pointer] == "▁" or word[word_char_pointer] == "_":
                            token_cased += token_char
                            word_char_pointer += 1
                        elif word_char_pointer == 0:
                            token_cased += token_char
                    else:
                        raise RuntimeError(
                            f"Unexpected error - failed to recover capitalization of tokens for word {word}"
                        )

        word_tokens_cased.append(token_cased)

    return word_tokens_cased
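# Illustrative example (the tokenization here is hypothetical, shown for clarity only):
# the tokenizer lowercases its input, so for the word "Hello" tokenized as ["▁he", "llo"],
# restore_token_case("Hello", ["▁he", "llo"]) walks both character streams in lockstep
# and returns ["▁He", "llo"], re-casing each token to match the original word.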
def get_utt_obj(
    text, model, separator, T, audio_filepath, utt_id,
):
    """
    Function to create an Utterance object and add all necessary information to it except
    for timings of the segments / words / tokens according to the alignment - that will
    be done later in a different function, after the alignment is done.

    The Utterance object has a list segments_and_tokens which contains Segment objects and
    Token objects (for blank tokens in between segments).
    Within the Segment objects, there is a list words_and_tokens which contains Word objects and
    Token objects (for blank tokens in between words).
    Within the Word objects, there is a list tokens which contains Token objects for
    blank and non-blank tokens.
    We will be building up these lists in this function. This data structure will then be useful for
    generating the various output files that we wish to save.
    """

    if not separator:  # if separator is not defined - treat the whole text as one segment
        segments = [text]
    else:
        segments = text.split(separator)

    # remove any spaces at start and end of segments
    segments = [seg.strip() for seg in segments]
    # remove any empty segments
    segments = [seg for seg in segments if len(seg) > 0]

    utt = Utterance(text=text, audio_filepath=audio_filepath, utt_id=utt_id,)
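    # Sketch of the structure being built (hypothetical utterance "hi there" as one segment):
    #   utt.segments_and_tokens = [
    #       Token(<b>),
    #       Segment("hi there", words_and_tokens=[Word("hi"), Token(<b>), Word("there")]),
    #       Token(<b>),
    #   ]
    # with each Word in turn holding its own Token objects (non-blank tokens, and blanks in between).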
    # build up lists: token_ids_with_blanks, segments_and_tokens.
    # The code for these is different depending on whether we use char-based tokens or not
    if hasattr(model, 'tokenizer'):
        if hasattr(model, 'blank_id'):
            BLANK_ID = model.blank_id
        else:
            BLANK_ID = len(model.tokenizer.vocab)  # TODO: check

        utt.token_ids_with_blanks = [BLANK_ID]

        # check for text being 0 length
        if len(text) == 0:
            return utt

        # check for # tokens + token repetitions being > T.
        # A CTC alignment needs at least one output timestep per non-blank token, plus one
        # blank timestep between each pair of repeated tokens, so alignment is infeasible
        # if this count exceeds T.
        all_tokens = model.tokenizer.text_to_ids(text)
        n_token_repetitions = 0
        for i_tok in range(1, len(all_tokens)):
            if all_tokens[i_tok] == all_tokens[i_tok - 1]:
                n_token_repetitions += 1

        if len(all_tokens) + n_token_repetitions > T:
            logging.info(
                f"Utterance {utt_id} has too many tokens compared to the audio file duration."
                " Will not generate output alignment files for this utterance."
            )
            return utt
        # build up data structures containing segments/words/tokens
        utt.segments_and_tokens.append(Token(text=BLANK_TOKEN, text_cased=BLANK_TOKEN, s_start=0, s_end=0,))
        segment_s_pointer = 1  # first segment will start at s=1 because s=0 is a blank
        word_s_pointer = 1  # first word will start at s=1 because s=0 is a blank

        for segment in segments:
            # add the segment to segment_info and increment the segment_s_pointer
            segment_tokens = model.tokenizer.text_to_tokens(segment)
            utt.segments_and_tokens.append(
                Segment(
                    text=segment,
                    s_start=segment_s_pointer,
                    # segment_tokens do not contain blanks => need to multiply by 2
                    # s_end needs to be the index of the final token (including blanks) of the current segment:
                    # segment_s_pointer + len(segment_tokens) * 2 is the index of the first token of the next segment =>
                    # => need to subtract 2
                    s_end=segment_s_pointer + len(segment_tokens) * 2 - 2,
                )
            )
            segment_s_pointer += (
                len(segment_tokens) * 2
            )  # multiply by 2 to account for blanks (which are not present in segment_tokens)

            words = segment.split(" ")  # we define words to be space-separated sub-strings
            for word_i, word in enumerate(words):

                word_tokens = model.tokenizer.text_to_tokens(word)
                word_token_ids = model.tokenizer.text_to_ids(word)
                word_tokens_cased = restore_token_case(word, word_tokens)

                # add the word to word_info and increment the word_s_pointer
                utt.segments_and_tokens[-1].words_and_tokens.append(
                    # word_tokens do not contain blanks => need to multiply by 2
                    # s_end needs to be the index of the final token (including blanks) of the current word:
                    # word_s_pointer + len(word_tokens) * 2 is the index of the first token of the next word =>
                    # => need to subtract 2
                    Word(text=word, s_start=word_s_pointer, s_end=word_s_pointer + len(word_tokens) * 2 - 2)
                )
                word_s_pointer += (
                    len(word_tokens) * 2
                )  # multiply by 2 to account for blanks (which are not present in word_tokens)

                for token_i, (token, token_id, token_cased) in enumerate(
                    zip(word_tokens, word_token_ids, word_tokens_cased)
                ):
                    # add the text tokens and the blanks in between them
                    # to our token-based variables
                    utt.token_ids_with_blanks.extend([token_id, BLANK_ID])

                    # adding Token object for non-blank token
                    utt.segments_and_tokens[-1].words_and_tokens[-1].tokens.append(
                        Token(
                            text=token,
                            text_cased=token_cased,
                            # utt.token_ids_with_blanks has the form [...., <this non-blank token>, <blank>] =>
                            # => if you do len(utt.token_ids_with_blanks) - 1 you get the index of the final <blank>
                            # => we want to do len(utt.token_ids_with_blanks) - 2 to get the index of <this non-blank token>
                            s_start=len(utt.token_ids_with_blanks) - 2,
                            # s_end is same as s_start since the token only occupies one element in the list
                            s_end=len(utt.token_ids_with_blanks) - 2,
                        )
                    )

                    # adding Token object for blank tokens in between the tokens of the word
                    # (ie do not add another blank if you have reached the end)
                    if token_i < len(word_tokens) - 1:
                        utt.segments_and_tokens[-1].words_and_tokens[-1].tokens.append(
                            Token(
                                text=BLANK_TOKEN,
                                text_cased=BLANK_TOKEN,
                                # utt.token_ids_with_blanks has the form [...., <this blank token>] =>
                                # => if you do len(utt.token_ids_with_blanks) - 1 you get the index of this <blank>
                                s_start=len(utt.token_ids_with_blanks) - 1,
                                # s_end is same as s_start since the token only occupies one element in the list
                                s_end=len(utt.token_ids_with_blanks) - 1,
                            )
                        )

                # add a Token object for blanks in between words in this segment
                # (but only *in between* - do not add the token if it is after the final word)
                if word_i < len(words) - 1:
                    utt.segments_and_tokens[-1].words_and_tokens.append(
                        Token(
                            text=BLANK_TOKEN,
                            text_cased=BLANK_TOKEN,
                            # utt.token_ids_with_blanks has the form [...., <this blank token>] =>
                            # => if you do len(utt.token_ids_with_blanks) - 1 you get the index of this <blank>
                            s_start=len(utt.token_ids_with_blanks) - 1,
                            # s_end is same as s_start since the token only occupies one element in the list
                            s_end=len(utt.token_ids_with_blanks) - 1,
                        )
                    )

            # add the blank token in between segments/after the final segment
            utt.segments_and_tokens.append(
                Token(
                    text=BLANK_TOKEN,
                    text_cased=BLANK_TOKEN,
                    # utt.token_ids_with_blanks has the form [...., <this blank token>] =>
                    # => if you do len(utt.token_ids_with_blanks) - 1 you get the index of this <blank>
                    s_start=len(utt.token_ids_with_blanks) - 1,
                    # s_end is same as s_start since the token only occupies one element in the list
                    s_end=len(utt.token_ids_with_blanks) - 1,
                )
            )

    return utt
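# Worked example of the s-index arithmetic above (token ids are hypothetical):
# for a single word whose tokens have ids [7, 9], token_ids_with_blanks becomes
#   [blank, 7, blank, 9, blank]   (s indices 0..4)
# so the word spans s_start=1 to s_end=3, i.e. word_s_pointer + len(word_tokens) * 2 - 2.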
def _get_utt_id(audio_filepath, audio_filepath_parts_in_utt_id):
    fp_parts = Path(audio_filepath).parts[-audio_filepath_parts_in_utt_id:]
    utt_id = Path("_".join(fp_parts)).stem
    utt_id = utt_id.replace(" ", "-")  # replace any spaces in the filepath with dashes
    return utt_id
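# e.g. _get_utt_id("/data/audio/file 1.wav", 2) -> "audio_file-1"
# (the last two path parts are joined with "_", the extension is dropped, spaces become dashes)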
def add_t_start_end_to_utt_obj(utt_obj, alignment_utt, output_timestep_duration):
    """
    Function to add t_start and t_end (representing time in seconds) to the Utterance object utt_obj.
    Args:
        utt_obj: Utterance object to which we will add t_start and t_end for its
            constituent segments/words/tokens.
        alignment_utt: a list of ints indicating which token the alignment passes through at each
            timestep (will take the form [0, 0, 1, 1, ..., <num of tokens including blanks in utterance>]).
        output_timestep_duration: a float indicating the duration of a single output timestep from
            the ASR Model.
    Returns:
        utt_obj: updated Utterance object.
    """

    # General idea for the algorithm of how we add t_start and t_end:
    # - the timestep where a token s starts is the location of the first appearance of s_start in alignment_utt
    # - the timestep where a token s ends is the location of the final appearance of s_end in alignment_utt
    # We will make dictionaries num_to_first_alignment_appearance and
    # num_to_last_alignment_appearance and use them to update all of
    # the t_start and t_end values in utt_obj.
    # We will put t_start = t_end = -1 for tokens that are skipped (should only be blanks)

    num_to_first_alignment_appearance = dict()
    num_to_last_alignment_appearance = dict()

    prev_s = -1  # use prev_s to keep track of when the s changes
    for t, s in enumerate(alignment_utt):
        if s > prev_s:
            num_to_first_alignment_appearance[s] = t
            if prev_s >= 0:  # don't record prev_s = -1
                num_to_last_alignment_appearance[prev_s] = t - 1
        prev_s = s
    # add last appearance of the final s
    num_to_last_alignment_appearance[prev_s] = len(alignment_utt) - 1
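    # Worked example (hypothetical alignment): for alignment_utt = [0, 0, 1, 1, 1, 2],
    # num_to_first_alignment_appearance = {0: 0, 1: 2, 2: 5} and
    # num_to_last_alignment_appearance  = {0: 1, 1: 4, 2: 5};
    # an s value that never appears (a skipped blank) is absent from both dicts.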
    # update all the t_start and t_end in utt_obj
    for segment_or_token in utt_obj.segments_and_tokens:
        if type(segment_or_token) is Segment:
            segment = segment_or_token
            segment.t_start = num_to_first_alignment_appearance[segment.s_start] * output_timestep_duration
            segment.t_end = (num_to_last_alignment_appearance[segment.s_end] + 1) * output_timestep_duration

            for word_or_token in segment.words_and_tokens:
                if type(word_or_token) is Word:
                    word = word_or_token
                    word.t_start = num_to_first_alignment_appearance[word.s_start] * output_timestep_duration
                    word.t_end = (num_to_last_alignment_appearance[word.s_end] + 1) * output_timestep_duration

                    for token in word.tokens:
                        if token.s_start in num_to_first_alignment_appearance:
                            token.t_start = num_to_first_alignment_appearance[token.s_start] * output_timestep_duration
                        else:
                            token.t_start = -1

                        if token.s_end in num_to_last_alignment_appearance:
                            token.t_end = (
                                num_to_last_alignment_appearance[token.s_end] + 1
                            ) * output_timestep_duration
                        else:
                            token.t_end = -1
                else:
                    token = word_or_token
                    if token.s_start in num_to_first_alignment_appearance:
                        token.t_start = num_to_first_alignment_appearance[token.s_start] * output_timestep_duration
                    else:
                        token.t_start = -1

                    if token.s_end in num_to_last_alignment_appearance:
                        token.t_end = (num_to_last_alignment_appearance[token.s_end] + 1) * output_timestep_duration
                    else:
                        token.t_end = -1
        else:
            token = segment_or_token
            if token.s_start in num_to_first_alignment_appearance:
                token.t_start = num_to_first_alignment_appearance[token.s_start] * output_timestep_duration
            else:
                token.t_start = -1

            if token.s_end in num_to_last_alignment_appearance:
                token.t_end = (num_to_last_alignment_appearance[token.s_end] + 1) * output_timestep_duration
            else:
                token.t_end = -1

    return utt_obj
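# Note that t_end above is computed as (last_appearance + 1) * output_timestep_duration,
# so a token occupying a single timestep still gets a non-zero duration of one timestep.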
def get_word_timings(
    alignment_level, utt_obj,
):
    boundary_info_utt = []
    for segment_or_token in utt_obj.segments_and_tokens:
        if type(segment_or_token) is Segment:
            segment = segment_or_token
            for word_or_token in segment.words_and_tokens:
                if type(word_or_token) is Word:
                    word = word_or_token
                    if alignment_level == "words":
                        boundary_info_utt.append(word)

    word_timestamps = []
    for boundary_info_ in boundary_info_utt:  # loop over every word
        # skip if t_start = t_end = negative number because we used it as a marker to skip some blank tokens
        if not (boundary_info_.t_start < 0 or boundary_info_.t_end < 0):
            text = boundary_info_.text
            start_time = boundary_info_.t_start
            end_time = boundary_info_.t_end
            text = text.replace(" ", SPACE_TOKEN)
            word_timestamps.append((text, start_time, end_time))
    return word_timestamps
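# get_word_timings returns a list of (word, start_seconds, end_seconds) tuples, e.g.
# [("hello", 0.4, 0.72), ("world.", 0.8, 1.2)] (values here are hypothetical).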
def get_start_end_for_segments(word_timestamps):
    segment_timestamps = []
    word_list = []
    beginning = None
    for word, start, end in word_timestamps:
        if beginning is None:
            beginning = start
            word = word.capitalize()  # capitalize the first word of each segment
        word_list.append(word)
        if word.endswith(('.', '?', '!')):
            segment = ' '.join(word_list)
            segment_timestamps.append((segment, beginning, end))
            beginning = None
            word_list = []

    if word_list:  # Only append if there are remaining words
        segment = ' '.join(word_list)
        segment_timestamps.append((segment, beginning, end))

    return segment_timestamps
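# Worked example (hypothetical timestamps): word timestamps
#   [("hello", 0.4, 0.72), ("world.", 0.8, 1.2), ("how", 1.5, 1.7)]
# are grouped on sentence-final punctuation into segment timestamps
#   [("Hello world.", 0.4, 1.2), ("How", 1.5, 1.7)].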
def align_tdt_to_ctc_timestamps(tdt_txt, model, audio_filepath):
    tdt_txt = tdt_txt[0].text if tdt_txt is not None else tdt_txt

    if isinstance(model, EncDecHybridRNNTCTCModel):
        # switch the hybrid model's decoding branch to greedy CTC so that we get
        # CTC log-probs to align against
        ctc_cfg = CTCDecodingConfig()
        ctc_cfg.decoding = "greedy_batch"
        model.change_decoding_strategy(decoding_cfg=ctc_cfg, decoder_type="ctc")
    else:
        raise ValueError("Only hybrid TDT-CTC models (EncDecHybridRNNTCTCModel) are currently supported.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    viterbi_device = torch.device(device)

    with torch.amp.autocast(device_type=device, dtype=torch.bfloat16, enabled=True):
        with torch.inference_mode():
            hypotheses = model.transcribe([audio_filepath], return_hypotheses=True, batch_size=1)
            if isinstance(hypotheses, tuple) and len(hypotheses) == 2:
                hypotheses = hypotheses[0]

    log_probs_list_batch = [hypotheses[0].y_sequence]
    T_list_batch = [hypotheses[0].y_sequence.shape[0]]
    ctc_pred_text = hypotheses[0].text if tdt_txt is None else tdt_txt
    utt_obj = get_utt_obj(
        ctc_pred_text,
        model,
        None,
        T_list_batch[0],
        audio_filepath,
        _get_utt_id(audio_filepath, 1),
    )
    utt_obj.pred_text = ctc_pred_text

    y_list_batch = [utt_obj.token_ids_with_blanks]
    U_list_batch = [len(utt_obj.token_ids_with_blanks)]

    # turn log_probs, y, T, U into dense tensors for fast computation during Viterbi decoding
    T_max = max(T_list_batch)
    U_max = max(U_list_batch)

    # V = the number of tokens in the vocabulary + 1 for the blank token.
    if hasattr(model, 'tokenizer'):
        V = len(model.tokenizer.vocab) + 1
    else:
        V = len(model.decoder.vocabulary) + 1
    T_batch = torch.tensor(T_list_batch)
    U_batch = torch.tensor(U_list_batch)

    # make log_probs_batch tensor of shape (B x T_max x V), padded with a very large
    # negative number standing in for log(0)
    log_probs_batch = V_NEGATIVE_NUM * torch.ones((1, T_max, V))
    for b, log_probs_utt in enumerate(log_probs_list_batch):
        t = log_probs_utt.shape[0]
        log_probs_batch[b, :t, :] = log_probs_utt

    # make y_batch tensor of shape (B x U_max), padded with the out-of-vocabulary id V
    y_batch = V * torch.ones((1, U_max), dtype=torch.int64)
    for b, y_utt in enumerate(y_list_batch):
        U_utt = U_batch[b]
        y_batch[b, :U_utt] = torch.tensor(y_utt)

    # duration in seconds of one output timestep: the feature extractor's hop length times
    # the encoder's downsampling factor (8x for FastConformer), divided by the sample rate
    model_downsample_factor = 8
    output_timestep_duration = (
        model.preprocessor.featurizer.hop_length * model_downsample_factor / model.cfg.preprocessor.sample_rate
    )

    alignments_batch = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device)

    utt_obj = add_t_start_end_to_utt_obj(utt_obj, alignments_batch[0], output_timestep_duration)
    word_timestamps = get_word_timings("words", utt_obj=utt_obj)
    segment_timestamps = get_start_end_for_segments(word_timestamps)
    return segment_timestamps
# Example standalone usage (commented out - this module is normally driven by the Space app):
# def main():
#     from nemo.collections.asr.models import ASRModel
#     asr_model = ASRModel.from_pretrained('nvidia/parakeet-tdt_ctc-1.1b').to('cuda')
#     asr_model.eval()
#     segment_timestamps = align_tdt_to_ctc_timestamps(None, asr_model, 'processed_file.flac')
#
# if __name__ == '__main__':
#     main()