#!/usr/bin/env python3
"""Convert an original Z-Image checkpoint into the ComfyUI key layout.

Usage:
    python z_image_convert_original_to_comfy.py output.safetensors diffusion_model*.safetensors

The target dtype is inferred from the output filename: a name containing
"fp8_e4m3fn", "fp16" or "bf16" selects the matching torch dtype; otherwise
tensors are saved in their original dtype.
"""
import sys

import torch

# Plain key-name substitutions, applied after the qkv fusion below.
# NOTE(review): the ".attention.to_out.0.bias" entry is effectively dead —
# those keys are dropped (see convert_state_dict) before renaming runs.
# Kept for parity with the original mapping table.
REPLACE_KEYS = {
    "all_final_layer.2-1.": "final_layer.",
    "all_x_embedder.2-1.": "x_embedder.",
    ".attention.to_out.0.bias": ".attention.out.bias",
    ".attention.norm_k.weight": ".attention.k_norm.weight",
    ".attention.norm_q.weight": ".attention.q_norm.weight",
    ".attention.to_out.0.weight": ".attention.out.weight",
}


def dtype_from_name(name):
    """Return the torch dtype encoded in *name*, or None for "keep as-is"."""
    if "fp8_e4m3fn" in name:
        return torch.float8_e4m3fn
    if "fp16" in name:
        return torch.float16
    if "bf16" in name:
        return torch.bfloat16
    return None


def convert_state_dict(sd, cast_to=None, out_sd=None):
    """Convert one loaded shard's tensors into the ComfyUI layout.

    Fuses each attention block's separate to_q/to_k/to_v projection weights
    into a single qkv weight (rows ordered q, k, v), drops the
    to_out.0.bias tensors, and renames the remaining keys per REPLACE_KEYS.

    Assumes the shard yields to_k before to_q before to_v for every
    attention block — which holds for alphabetically ordered safetensors
    keys (to_k < to_q < to_v).

    Args:
        sd: mapping of checkpoint key -> tensor.
        cast_to: optional torch dtype to cast every tensor to.
        out_sd: optional dict to accumulate results into; a new dict is
            created when omitted.

    Returns:
        The dict holding the converted tensors.
    """
    if out_sd is None:
        out_sd = {}
    pending_qkv = None  # collects [q, k, v] weights for the current block
    for key, w in sd.items():
        if cast_to is not None:
            w = w.to(cast_to)
        k_out = key
        if k_out.endswith(".attention.to_out.0.bias"):
            # This bias is intentionally not carried over to the output.
            continue
        if k_out.endswith(".attention.to_k.weight"):
            pending_qkv = [w]
            continue
        if k_out.endswith(".attention.to_q.weight"):
            pending_qkv = [w] + pending_qkv  # q goes first in the fused tensor
            continue
        if k_out.endswith(".attention.to_v.weight"):
            pending_qkv = pending_qkv + [w]
            w = torch.cat(pending_qkv, dim=0)
            k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")
        for old, new in REPLACE_KEYS.items():
            k_out = k_out.replace(old, new)
        out_sd[k_out] = w
    return out_sd


def main(argv):
    """CLI entry point: load every input shard, convert, save one file.

    Raises SystemExit with a usage message when fewer than two arguments
    (output path + at least one input shard) are supplied.
    """
    if len(argv) < 3:
        raise SystemExit(
            "Usage: python z_image_convert_original_to_comfy.py "
            "output.safetensors diffusion_model*.safetensors"
        )
    # Local import: safetensors is only needed for file I/O, so the pure
    # conversion helpers above stay importable without it.
    import safetensors.torch

    cast_to = dtype_from_name(argv[1])
    out_sd = {}
    for path in argv[2:]:
        convert_state_dict(safetensors.torch.load_file(path), cast_to, out_sd)
    safetensors.torch.save_file(out_sd, argv[1])


if __name__ == "__main__":
    main(sys.argv)