import safetensors.torch
import torch
import sys

# Usage: python z_image_convert_original_to_comfy.py output.safetensors diffusion_model*.safetensors
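# Example invocation (filenames are hypothetical; the dtype suffix in the output
# name selects the optional cast below):
#   python z_image_convert_original_to_comfy.py z_image_bf16.safetensors diffusion_model-00001-of-00002.safetensors diffusion_model-00002-of-00002.safetensors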

# Optionally cast weights to a lower-precision dtype hinted by the output filename.
cast_to = None
if "fp8_e4m3fn" in sys.argv[1]:
    cast_to = torch.float8_e4m3fn
elif "fp16" in sys.argv[1]:
    cast_to = torch.float16
elif "bf16" in sys.argv[1]:
    cast_to = torch.bfloat16

# Map the original Z-Image checkpoint key names to the ComfyUI naming scheme.
replace_keys = {"all_final_layer.2-1.": "final_layer.",
                "all_x_embedder.2-1.": "x_embedder.",
                ".attention.to_out.0.bias": ".attention.out.bias",
                ".attention.norm_k.weight": ".attention.k_norm.weight",
                ".attention.norm_q.weight": ".attention.q_norm.weight",
                ".attention.to_out.0.weight": ".attention.out.weight"
                }

# Merge every input checkpoint file into a single state dict, renaming keys as we go.
out_sd = {}
for f in sys.argv[2:]:
    sd = safetensors.torch.load_file(f)
    cc = None  # accumulates the q/k/v projection weights of the current attention block
    for k in sd:
        w = sd[k]

        if cast_to is not None:
            w = w.to(cast_to)
        k_out = k
        if k_out.endswith(".attention.to_out.0.bias"):
            continue
        if k_out.endswith(".attention.to_k.weight"):
            cc = [w]
            continue
        if k_out.endswith(".attention.to_q.weight"):
            cc = [w] + cc
            continue
        if k_out.endswith(".attention.to_v.weight"):
            cc = cc + [w]
            w = torch.cat(cc, dim=0)
            k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")

        for r, rr in replace_keys.items():
            k_out = k_out.replace(r, rr)
        out_sd[k_out] = w

# Write the converted state dict to the output file.
safetensors.torch.save_file(out_sd, sys.argv[1])
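
# Optional sanity check (not part of the conversion, purely illustrative):
# reload the output and confirm the fused qkv keys are present.
# converted = safetensors.torch.load_file(sys.argv[1])
# print(sorted(k for k in converted if k.endswith(".attention.qkv.weight"))[:3])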