import argparse
import os
from typing import Optional, Set, Type

import torch
from transformers import AutoTokenizer

# Modify this to match the path of your 'models.py'
from models import OMadaModelLM, MMadaModelLM


def merge_models_safetensors(
    hf_model_name: str,
    local_model_directory: str,
    output_dir: str,
    alpha: float = 0.5,
    merge_vocab: bool = True,
    hf_for_common: bool = False,
    task_vector_scale: Optional[float] = None,
):
    """
    Merges a Hugging Face model with a local sharded safetensors checkpoint.

    Args:
        hf_model_name (str): The name or path of the base model on Hugging Face.
        local_model_directory (str): The directory containing the local sharded
            safetensors files.
        output_dir (str): The directory to save the merged model and tokenizer.
        alpha (float): The weighting factor for the Hugging Face model.
        merge_vocab (bool): If True, merges common token embeddings.
        hf_for_common (bool): If True, uses HF model weights for common token
            embeddings.
        task_vector_scale (Optional[float]): If set, merges overlapping
            parameters as HF + scale * (local - HF) instead of interpolating.
    """
    print("--- Starting model merge process ---")
    print(f"Base Hugging Face Model: {hf_model_name}")
    print(f"Local Model Directory (Safetensors shards): {local_model_directory}")
    print(f"Merge Alpha: {alpha}")
    print(f"Merge Vocabulary (for common tokens): {merge_vocab}")
    print(f"Use HF for common tokens: {hf_for_common}")
    if task_vector_scale is not None:
        print(f"Task vector scale (overlapping parameters): {task_vector_scale}")

    # Both checkpoints are loaded and merged in bfloat16.
    torch_dtype = torch.bfloat16

    model_class_hf: Type[torch.nn.Module] = MMadaModelLM
    model_class_local: Type[torch.nn.Module] = OMadaModelLM

    try:
        hf_model = model_class_hf.from_pretrained(
            hf_model_name,
            trust_remote_code=True,
            torch_dtype=torch_dtype,
        ).eval()
        hf_state_dict = hf_model.state_dict()
        print(f"Successfully loaded Hugging Face model '{hf_model_name}'.")
    except Exception as e:
        print(f"Error loading Hugging Face model '{hf_model_name}': {e}")
        return

    try:
        local_model = model_class_local.from_pretrained(
            local_model_directory,
            trust_remote_code=True,
            torch_dtype=torch_dtype,
            config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json",
        ).eval()
        local_state_dict = local_model.state_dict()
        print(f"Successfully loaded local model from directory: '{local_model_directory}'.")
    except Exception as e:
        print(f"Error loading local model from directory '{local_model_directory}': {e}")
        print("Please ensure the directory contains all safetensors shards and a valid config.json.")
        return

    merged_state_dict = {}
    all_keys = set(hf_state_dict.keys()).union(set(local_state_dict.keys()))

    print("Merging model weights...")
    for key in all_keys:
        if key in ["model.transformer.wte.weight", "model.transformer.ff_out.weight"]:
            # --- Merge logic for embedding and final projection layers ---
            # 1. Use HF for common tokens, local for expanded portion (--hf-for-common)
            if hf_for_common:
                hf_weights = hf_state_dict[key]
                local_weights = local_state_dict[key]
                if hf_weights.shape[0] < local_weights.shape[0]:
                    print(f"Detected embedding size mismatch for key '{key}'. Merging by tensor indices.")
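                    # Copy scheme (sizes hypothetical): for an HF embedding of
                    # shape (V_hf, d) and a local embedding of shape (V_local, d),
                    #   rows [0, V_hf)        <- HF weights, verbatim
                    #   rows [V_hf, V_local)  <- locally trained weights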
Merging by tensor indices.") merged_weights = torch.zeros_like(local_weights) common_size = hf_weights.shape[0] # Copy HF weights for the common vocabulary portion (no interpolation) merged_weights[:common_size, :] = hf_weights # Copy local weights for the expanded embedding space merged_weights[common_size:, :] = local_weights[common_size:, :] merged_state_dict[key] = merged_weights print(f"Successfully merged {common_size} old and {local_weights.shape[0] - common_size} new embeddings.") continue else: print(f"Embedding sizes are identical for key '{key}'. Proceeding with standard vocab merge logic.") # 2. Use local model weights for all tokens (--no-merge-vocab) elif not merge_vocab: print(f"Vocab merge disabled. Using local model weights for key '{key}'.") merged_state_dict[key] = local_state_dict[key] continue # 3. Average common tokens, use local for new tokens (default) print(f"Merging common tokens for key '{key}'.") try: hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_name, trust_remote_code=True) local_tokenizer = AutoTokenizer.from_pretrained(local_model_directory, trust_remote_code=True) except Exception as e: print(f"Error loading tokenizers for special merge handling: {e}") if key in hf_state_dict and key in local_state_dict: if hf_state_dict[key].shape == local_state_dict[key].shape: merged_weights = (alpha * hf_state_dict[key]) + ((1 - alpha) * local_state_dict[key]) merged_state_dict[key] = merged_weights else: merged_state_dict[key] = local_state_dict[key] continue hf_vocab_ids: Set[int] = set(hf_tokenizer.get_vocab().values()) local_vocab_ids: Set[int] = set(local_tokenizer.get_vocab().values()) common_vocab_ids = hf_vocab_ids.intersection(local_vocab_ids) new_vocab_ids = local_vocab_ids.difference(hf_vocab_ids) print(f"Found {len(common_vocab_ids)} common tokens and {len(new_vocab_ids)} new tokens.") hf_weights = hf_state_dict[key] local_weights = local_state_dict[key] merged_weights = local_weights.clone() common_indices = torch.tensor(sorted(common_vocab_ids), dtype=torch.long) if task_vector_scale is not None: merged_weights[common_indices] = hf_weights[common_indices] + task_vector_scale * (local_weights[common_indices] - hf_weights[common_indices]) else: merged_weights[common_indices] = ( alpha * hf_weights[common_indices] + (1 - alpha) * local_weights[common_indices] ) new_indices = torch.tensor(sorted(new_vocab_ids), dtype=torch.long) merged_weights[new_indices] = local_weights[new_indices] merged_state_dict[key] = merged_weights elif key in hf_state_dict and key in local_state_dict: if hf_state_dict[key].shape == local_state_dict[key].shape: if task_vector_scale is not None: merged_weights = hf_state_dict[key] + task_vector_scale * (local_state_dict[key] - hf_state_dict[key]) else: merged_weights = (alpha * hf_state_dict[key]) + ((1 - alpha) * local_state_dict[key]) merged_state_dict[key] = merged_weights else: hf_shape = hf_state_dict[key].shape local_shape = local_state_dict[key].shape if len(hf_shape) > 0 and hf_shape[0] <= local_shape[0] and hf_shape[1:] == local_shape[1:]: print(f"Key '{key}' has expanded leading dimension. Merging common portion and keeping local extras.") merged_weights = local_state_dict[key].clone() common_size = hf_shape[0] merged_weights[:common_size] = hf_state_dict[key][:common_size] merged_state_dict[key] = merged_weights else: print(f"Warning: Key '{key}' has mismatched shapes (hf={hf_shape}, local={local_shape}). 
Using local model's weights.") merged_state_dict[key] = local_state_dict[key] elif key in hf_state_dict: merged_state_dict[key] = hf_state_dict[key] elif key in local_state_dict: merged_state_dict[key] = local_state_dict[key] print(f"Successfully processed {len(all_keys)} parameters for merging.") print("Loading merged weights into a new model instance...") try: new_model = model_class_local.from_pretrained( local_model_directory, trust_remote_code=True, torch_dtype=torch_dtype, config="/home/work/AIDAS/ckpts/omada/omada-training-stage1/config.json" ).eval() new_model.load_state_dict(merged_state_dict, strict=False) print("Merged weights successfully loaded into new model.") except Exception as e: print(f"Error loading merged state dict into new model: {e}") return try: os.makedirs(output_dir, exist_ok=True) new_model.save_pretrained(output_dir, safe_serialization=True) local_tokenizer = AutoTokenizer.from_pretrained(local_model_directory, trust_remote_code=True) local_tokenizer.save_pretrained(output_dir) print(f"--- Model successfully merged and saved to '{output_dir}' ---") except Exception as e: print(f"Error saving the merged model and tokenizer: {e}") # --- Example Usage --- if __name__ == '__main__': parser = argparse.ArgumentParser(description="Merge a Hugging Face model with a local checkpoint.") parser.add_argument( "--alpha", type=float, default=999, help="The weighting factor for the Hugging Face model. 0.5 for 50/50 average." ) # Existing '--no-merge-vocab' flag parser.add_argument( "--no-merge-vocab", action="store_false", dest="merge_vocab", help="Do not merge overlapping token embeddings; use local model's embeddings as-is." ) # New '--hf-for-common' flag added parser.add_argument( "--hf-for-common", action="store_true", help="Use HF model weights for common tokens and local weights for expanded space." ) parser.add_argument( "--task-vector-scale", type=float, default=None, help="Scale factor for the task vector (local - HF) over overlapping parameters.", ) args = parser.parse_args() MERGE_ALPHA = args.alpha MERGE_VOCAB = args.merge_vocab HF_FOR_COMMON = args.hf_for_common TASK_VECTOR_SCALE = args.task_vector_scale if not MERGE_VOCAB and HF_FOR_COMMON: print("Error: You cannot use --no-merge-vocab and --hf-for-common at the same time.") exit(1) HF_MODEL_NAME = "Gen-Verse/MMaDA-8B-MixCoT" LOCAL_MODEL_DIRECTORY = "/home/work/AIDAS/ckpts/omada/omada-training-stage1_7th/checkpoint-315000/unwrapped_model/" # Dynamic output directory based on flags if HF_FOR_COMMON: SUB_DIR = "hf_common_merge" elif not MERGE_VOCAB: SUB_DIR = "no_vocab_merge" else: SUB_DIR = "average_merge" scale_suffix = "" if TASK_VECTOR_SCALE is not None: scale_str = str(TASK_VECTOR_SCALE).replace(".", "p") scale_suffix = f"_scale_{scale_str}" OUTPUT_DIRECTORY = f"/home/work/AIDAS/ckpts/merged_model/{SUB_DIR}_alpha_{MERGE_ALPHA}{scale_suffix}" merge_models_safetensors( hf_model_name=HF_MODEL_NAME, local_model_directory=LOCAL_MODEL_DIRECTORY, output_dir=OUTPUT_DIRECTORY, alpha=MERGE_ALPHA, merge_vocab=MERGE_VOCAB, hf_for_common=HF_FOR_COMMON, task_vector_scale=TASK_VECTOR_SCALE, )