jaeikkim
Reinit Space without binary assets
7bfbdc3
# coding=utf-8
# Copyright 2025 AIDAS Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
import torch
import wandb
from models import MMadaModelLM
from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer
from training.prompting_utils import UniversalPrompting
from training.utils import get_config, flatten_omega_conf
from transformers import AutoTokenizer
import argparse
# from models.modeling_speech_tokenizer import EMOVASpeechTokenizer
def resize_vocab(model, config):
    """Resize the token-embedding table of ``model`` to the configured size.

    Args:
        model: Object exposing ``resize_token_embeddings(new_size)``; its
            return value is discarded.
        config: Config object; the target size is read from
            ``config.model.mmada.new_vocab_size``.
    """
    target_size = config.model.mmada.new_vocab_size
    print(f"Resizing token embeddings to {target_size}")
    model.resize_token_embeddings(target_size)
def get_vq_model_class(model_type):
    """Return the tokenizer / VQ model *class* registered for ``model_type``.

    The caller instantiates the result itself, e.g.
    ``get_vq_model_class(t).from_pretrained(path)``.

    Args:
        model_type: ``"magvitv2"`` or ``"emova"``.

    Returns:
        The model class (not an instance).

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    if model_type == "magvitv2":
        # MAGVITv2 is not among this file's top-level imports, so it must be
        # brought into scope here; imported lazily so the speech-only path
        # does not depend on it.
        from models import MAGVITv2
        return MAGVITv2
    if model_type == "emova":
        # Return the class, not a loaded instance: callers construct it with
        # ``.from_pretrained(...)``, so loading a checkpoint here as well
        # would load the tokenizer weights twice.
        return EMOVASpeechTokenizer
    raise ValueError(f"model_type {model_type} not supported.")
if __name__ == '__main__':
    # Speech-to-text (STT) inference: encode each audio file found in
    # ``config.audio_dir`` into speech tokens, prompt the MMaDA model with them
    # through ``mmu_generate``, decode the generated text and log it to wandb.
    config = get_config()

    # --- wandb run setup -----------------------------------------------------
    resume_wandb_run = config.wandb.resume
    run_id = config.wandb.get("run_id", None)
    if run_id is None:
        resume_wandb_run = False
        run_id = wandb.util.generate_id()
        config.wandb.run_id = run_id
    # NOTE(review): run_id / resume_wandb_run are not passed to wandb.init
    # below (run_id is only recorded via the logged config) - confirm intent.
    wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
    wandb.init(
        project="demo",
        name=config.experiment.name + '_stt',
        config=wandb_config,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- text tokenizer and prompting helper ---------------------------------
    text_tokenizer = AutoTokenizer.from_pretrained(config.model.mmada.pretrained_model_path, padding_side="left")
    uni_prompting = UniversalPrompting(
        text_tokenizer,
        max_text_len=config.dataset.preprocessing.max_seq_length,
        special_tokens=("<|s2t|>", "<|soa|>", "<|eoa|>", "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
        ignore_id=-100,
        cond_dropout_prob=config.training.cond_dropout_prob,
        use_reserved_token=True,
    )

    # --- frozen speech tokenizer ---------------------------------------------
    vq_model = get_vq_model_class(config.model.speech_model.type)
    vq_model = vq_model.from_pretrained(config.model.speech_model.speech_model_name).to(device)
    vq_model.requires_grad_(False)
    vq_model.eval()

    # Report the speech quantizer's codebook size (diagnostic output only).
    # The debugging ``sys.exit()`` that used to follow this report made the
    # whole inference section below unreachable, so it is gone.
    quantizer = vq_model.encoder.quantizer
    if hasattr(quantizer, 'codebook_size'):
        print("Codebook size:", quantizer.codebook_size)
    elif hasattr(quantizer, 'codebook'):
        # Derive the size from the codebook embedding matrix
        # (possibly an nn.Embedding - only ``.weight`` is relied on).
        cb = quantizer.codebook
        print("Codebook size:", cb.weight.shape[0])
    elif hasattr(quantizer, 'levels'):
        # FSQ-style quantizer: report the quantization levels per group.
        levels = quantizer.levels
        print("Quantization levels per group:", levels)
        print("Total scalar bins:", sum(levels))
    else:
        # Message (Korean, mis-encoded in this copy): "Quantizer has no codebook information."
        raise RuntimeError("Quantizer์— codebook ์ •๋ณด๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")

    # --- main MMaDA model ----------------------------------------------------
    # NOTE(review): both paths are hard-coded to one machine, and the config
    # directory is spelled "ommda-..." while the checkpoint uses "omada-..." -
    # confirm that is intentional.
    trained_checkpoint_path = "/home/work/AIDAS/omada-training-stage1/checkpoint-10000/unwrapped_model/"
    print(f"Loading trained model from: {trained_checkpoint_path}")
    model = MMadaModelLM.from_pretrained(
        trained_checkpoint_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        config='/home/work/AIDAS/ommda-training-s2t-mmada/config.json',
    )
    print("โœ… Trained model loaded successfully!")
    # resize_vocab(model, config) is intentionally not called here.

    # Speech token ids live after the text and image vocabularies.
    image_vocab_size = config.model.mmada.codebook_size  # 8192
    text_vocab_size = len(uni_prompting.text_tokenizer)
    speech_token_offset = text_vocab_size + image_vocab_size

    model.to(device)
    mask_token_id = model.config.mask_token_id
    # NOTE(review): these two knobs are not passed to mmu_generate below, so
    # they currently have no effect on decoding.
    temperature = 0.8  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
    top_k = 1  # retain only the top_k most likely tokens, clamp others to have 0 probability

    # The chat prompt does not depend on the audio file, so tokenize it once.
    # NOTE(review): only its batch dimension is used - the question tokens are
    # not part of the model input (see the commented-out entries in the
    # concatenation below). "<eot_id>" also differs from the "<|...|>" form of
    # the surrounding header markers; confirm before re-enabling the question.
    prompt_ids = text_tokenizer(
        ['<|start_header_id|>user<|end_header_id|>\n' + config.question + '<eot_id><|start_header_id|>assistant<|end_header_id|>\n'],
        return_tensors="pt",
    ).input_ids.to(device)
    batch_size = prompt_ids.shape[0]

    def special_token_column(token):
        """Return a (batch_size, 1) long tensor on ``device`` holding ``token``'s id."""
        # Multiply on CPU first (as before) so a CPU-tensor id entry in
        # sptids_dict never meets a CUDA tensor; ids are built directly as
        # integers instead of round-tripping through float32.
        column = torch.ones(batch_size, 1, dtype=torch.long) * uni_prompting.sptids_dict[token]
        return column.to(device)

    # Sorted so files are processed (and logged) in a deterministic order.
    audio_file_list = sorted(
        f for f in os.listdir(config.audio_dir)
        if f.lower().endswith(('.wav', '.flac', '.mp3'))
    )
    results_table = wandb.Table(columns=["Audio File", "Response"])
    for file_name in tqdm(audio_file_list, desc="Processing Audio"):
        audio_path = os.path.join(config.audio_dir, file_name)
        with torch.no_grad():
            speech_token_ids = vq_model.encode(audio_path).to(device)
            print(speech_token_ids)
            # Shift codec ids into the speech region of the joint vocabulary.
            speech_token_ids = speech_token_ids + speech_token_offset

            # Layout: <|s2t|> <|soa|> speech tokens <|eoa|>
            input_ids = torch.cat([
                special_token_column('<|s2t|>'),
                special_token_column('<|soa|>'),
                speech_token_ids,
                special_token_column('<|eoa|>'),
                # special_token_column('<|sot|>'),
                # prompt_ids,
            ], dim=1).long()

            output_ids = model.mmu_generate(input_ids, max_new_tokens=512, steps=512, block_length=512)

        # Decode only the newly generated positions after the prompt.
        text = uni_prompting.text_tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)
        print(f"\nFile: {file_name}\nResponse: {text}")
        results_table.add_data(file_name, text)

    wandb.log({"Speech-to-Text Response": results_table})