recoilme committed on
Commit
a0e4570
·
1 Parent(s): 094392b
samples/unet_384x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 7fc181e5c4fecac605f17c984bcdbab86207414fc65c8581b9c2f9f5cf61aeef
  • Pointer size: 130 Bytes
  • Size of remote file: 58.8 kB

Git LFS Details

  • SHA256: e70ade1b92ace6d2ad414916a18d19711eaa39b55a917352f7ff5c8f17287de8
  • Pointer size: 130 Bytes
  • Size of remote file: 59.8 kB
samples/unet_416x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 3bd258a27dbf3f197d4577420ef18ac0ddba8d6287ee833b482b17640f4fc303
  • Pointer size: 130 Bytes
  • Size of remote file: 75 kB

Git LFS Details

  • SHA256: 59c7783a6f42494622b17a31a51dfc946ef2ae16b59a0459a46dec3e21f5ab96
  • Pointer size: 130 Bytes
  • Size of remote file: 73.2 kB
samples/unet_448x768_0.jpg CHANGED

Git LFS Details

  • SHA256: e8cd57d3e01196c19187d58ad3aae5219abba5835d8eb846077196384b5284eb
  • Pointer size: 130 Bytes
  • Size of remote file: 75.5 kB

Git LFS Details

  • SHA256: fe336f2dc48eb536aa77a9bfab6e9f0da82e0f0cc67a91eaa9df13136a3f91bb
  • Pointer size: 130 Bytes
  • Size of remote file: 71.3 kB
samples/unet_480x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 30db9e128bde481d9e67177bacac20e742692d5e5ecc81f9348cea1c3de0137f
  • Pointer size: 130 Bytes
  • Size of remote file: 62 kB

Git LFS Details

  • SHA256: e3e01a060739be15480e0fe1f1ee6f8ae8169c1f2577df9377b8f994f5dde23b
  • Pointer size: 130 Bytes
  • Size of remote file: 64.8 kB
samples/unet_512x768_0.jpg CHANGED

Git LFS Details

  • SHA256: b9feb8d9f67e98564c39ae407844042775b150cc59fc94b4c3d7f762511eb9f7
  • Pointer size: 130 Bytes
  • Size of remote file: 91.2 kB

Git LFS Details

  • SHA256: 3a6de955c86b74e60827cac0f8eaa0abe08f874b4474dc917db81140e0a919a9
  • Pointer size: 130 Bytes
  • Size of remote file: 97.7 kB
samples/unet_544x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 38aa3bbec8c024175c19cc255a49d0e6a27233ae4b292deb514d61d3828c4ff9
  • Pointer size: 130 Bytes
  • Size of remote file: 70 kB

Git LFS Details

  • SHA256: c68a811f698ccf47b94cef94d9fc67fe0c1e638927fd19598c7a00c2e6ddab8a
  • Pointer size: 130 Bytes
  • Size of remote file: 96.6 kB
samples/unet_576x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 8a79e259559ac11675ba2b1a7789f5450e3560404772d1d3563f5e2b2328ea36
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB

Git LFS Details

  • SHA256: d1ace4579dcec3044c46afa68f1ef011f82c9cb45327186bf1786d2dbbfa1dc4
  • Pointer size: 131 Bytes
  • Size of remote file: 126 kB
samples/unet_608x768_0.jpg CHANGED

Git LFS Details

  • SHA256: b5ec17905db4eba48052279fe23f2c1b0538b493c892557dc2ccce16159ef7b3
  • Pointer size: 131 Bytes
  • Size of remote file: 180 kB

Git LFS Details

  • SHA256: a041599ebda2802fb2a1c8b8ef846f1d75fe7be3d45bf32472ab5f2ab99c5921
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
samples/unet_640x768_0.jpg CHANGED

Git LFS Details

  • SHA256: eb375bad4831c8b4d47b6779ed8c81466cc48ff2cbc67fd2e1ebc9eab414c8fc
  • Pointer size: 131 Bytes
  • Size of remote file: 128 kB

Git LFS Details

  • SHA256: 0e06db0136aa148488fa7a303ac05581a2833f84606cf5c0e4ff9aae6fb01fe3
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB
samples/unet_672x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 0319604e804d2007681815536f10a36f7df7cdf4402bd8e009f822e6711fd2a7
  • Pointer size: 131 Bytes
  • Size of remote file: 156 kB

Git LFS Details

  • SHA256: 83415c9000944f145eb688955e28bbc75a4745f040da456c90769c5feff60db4
  • Pointer size: 131 Bytes
  • Size of remote file: 181 kB
samples/unet_704x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 5bcaaec2cc322fdc07735fc2c63e6196acc82c90f6d1ceb1c183bdd69d72df30
  • Pointer size: 131 Bytes
  • Size of remote file: 260 kB

Git LFS Details

  • SHA256: b08b35604cb7e1b0386ad5e9bd7f97b70516dc178e16eb5a193dc1f5a4557d70
  • Pointer size: 131 Bytes
  • Size of remote file: 147 kB
samples/unet_736x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 454116f7fcb332d4dbd33cccebc771617f5b249e32bd6b031cff980eb87ab267
  • Pointer size: 131 Bytes
  • Size of remote file: 190 kB

Git LFS Details

  • SHA256: 9c7255646a04984c54640016d1aed3d9bcca82868eb8207c57b9e7a16a2110de
  • Pointer size: 131 Bytes
  • Size of remote file: 178 kB
samples/unet_768x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 5bcc85ecdb19b48fb202d52d340c79ea9b86e615ab7b0b8d27d7153613599708
  • Pointer size: 131 Bytes
  • Size of remote file: 108 kB

Git LFS Details

  • SHA256: e5410049ee01bb6a25f98127e4060cdaf142f45760b3f42f66ac479e0c270949
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
samples/unet_768x416_0.jpg CHANGED

Git LFS Details

  • SHA256: 03485c0756694da37f70be1a3d90682945cb3e43774d3d99e69f96fc8f339ab5
  • Pointer size: 130 Bytes
  • Size of remote file: 91.4 kB

Git LFS Details

  • SHA256: d41a63cc782523af8037da1665e9e71901f13dad881aad2deb7a3a45bd9907c4
  • Pointer size: 130 Bytes
  • Size of remote file: 67.9 kB
samples/unet_768x448_0.jpg CHANGED

Git LFS Details

  • SHA256: 0c17130e34823c6014ea30b8d963ad2497c5f97830da621f6cd9a44f12600084
  • Pointer size: 131 Bytes
  • Size of remote file: 177 kB

Git LFS Details

  • SHA256: fe3f4cf09d7f5ef96f2a8103c87c7c0de365813c0ec04c37131bc1c13196a4a0
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
samples/unet_768x480_0.jpg CHANGED

Git LFS Details

  • SHA256: f973947dcac1f3583d729d43d238dd51954146d4f68f0d0e5f58efe64bc53f18
  • Pointer size: 131 Bytes
  • Size of remote file: 149 kB

Git LFS Details

  • SHA256: 709aaa93348a10356b790796192916c52fb9c5cd2e7152b364ee601328526541
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
samples/unet_768x512_0.jpg CHANGED

Git LFS Details

  • SHA256: 5a7abd26d74f1720a07c368c133b4f84616bdce6334708d726de7b839aa37ed3
  • Pointer size: 131 Bytes
  • Size of remote file: 208 kB

Git LFS Details

  • SHA256: 6a52a4f44ed0198ceaa549112ec38ddb06440f65c94c63c559ec11a9c4636583
  • Pointer size: 131 Bytes
  • Size of remote file: 163 kB
samples/unet_768x544_0.jpg CHANGED

Git LFS Details

  • SHA256: ccc9eb87637b9f66ea069627a213943d42f4d56f98ccfb565b1363bc75337af8
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB

Git LFS Details

  • SHA256: 86e0a99f8c9c22a6f50e611c321fc386da7cf2d38251dea9ff0882ab291d32ba
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
samples/unet_768x576_0.jpg CHANGED

Git LFS Details

  • SHA256: 0cfce3036726a150977a4cc88e6aec54abf8d1fc3a9f80e2aa2d7f6d51accbc9
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB

Git LFS Details

  • SHA256: 19bc65b0ea1b45e9236b3b16735c8cdc517c3818765978a522764d34a584b146
  • Pointer size: 131 Bytes
  • Size of remote file: 178 kB
samples/unet_768x608_0.jpg CHANGED

Git LFS Details

  • SHA256: 20834875e7fa508e18f9b6171243b5cb9ea16f66aa24b99f17c3f352046b80a7
  • Pointer size: 130 Bytes
  • Size of remote file: 88 kB

Git LFS Details

  • SHA256: 9a104bc2cc9cfd99cd9355665cc0e8d0f59ab597a307fabb2ccc399135300d5d
  • Pointer size: 130 Bytes
  • Size of remote file: 98.2 kB
samples/unet_768x640_0.jpg CHANGED

Git LFS Details

  • SHA256: cef1f1269c42764684bffef38bf31100585d6dbed81fb04519c53005f94c9bf5
  • Pointer size: 130 Bytes
  • Size of remote file: 97.9 kB

Git LFS Details

  • SHA256: 5caf6717bae8961ff63a4732331320338d67c453b2e5a6066ab9a1dc64c2bbf1
  • Pointer size: 131 Bytes
  • Size of remote file: 117 kB
samples/unet_768x672_0.jpg CHANGED

Git LFS Details

  • SHA256: 5758217f5baf8a75f7dc9dbe66d4d8f685baf57e29883d7150ef9cb50f76d72b
  • Pointer size: 130 Bytes
  • Size of remote file: 47.7 kB

Git LFS Details

  • SHA256: 20c180d66aa3f95d7209c57c27b687b491ab99aa9617f6d90832a5bb8feebbb6
  • Pointer size: 130 Bytes
  • Size of remote file: 59.1 kB
samples/unet_768x704_0.jpg CHANGED

Git LFS Details

  • SHA256: f7eff6ac08eb09479efd01ceb49ad991c33d2a6800cf83069265e582f9d8b3ca
  • Pointer size: 131 Bytes
  • Size of remote file: 201 kB

Git LFS Details

  • SHA256: 1704f6cf1720873b7942f9e206daca50a4a3b0bc55b405f79a2abf9d928a38fe
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
samples/unet_768x736_0.jpg CHANGED

Git LFS Details

  • SHA256: 5fef210b7f7ca1daca122f53448fbea06da503837d630d79dd94333e9f8ec4aa
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB

Git LFS Details

  • SHA256: db91ee98ea9562006f4974343041095550220afe45d995a4a28e5f6dd6a57422
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB
samples/unet_768x768_0.jpg CHANGED

Git LFS Details

  • SHA256: 55d61512a51e993f63f6fb02b0e8aa20d095aa4a00aaa7641bb0540df1489bca
  • Pointer size: 131 Bytes
  • Size of remote file: 178 kB

Git LFS Details

  • SHA256: db53ebb7531d20f7e4b8eebb69e81671a3ed57f4611d1dcc3fe9eb350a2bc3ee
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB
train-Copy1.py ADDED
@@ -0,0 +1,720 @@
#from comet_ml import Experiment
import os
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Sampler
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import LambdaLR
from collections import defaultdict
from diffusers import UNet2DConditionModel, AutoencoderKL
from accelerate import Accelerator
from datasets import load_from_disk
from tqdm import tqdm
from PIL import Image, ImageOps
import wandb
import random
import gc
from accelerate.state import DistributedType
from torch.distributed import broadcast_object_list
from torch.utils.checkpoint import checkpoint
from diffusers.models.attention_processor import AttnProcessor2_0
from datetime import datetime
import bitsandbytes as bnb
import torch.nn.functional as F
from collections import deque
from transformers import AutoTokenizer, AutoModel

# --------------------------- Parameters ---------------------------
ds_path = "/workspace/sdxs/datasets/768"
project = "unet"
batch_size = 36
base_learning_rate = 3e-5  # 4e-5
min_learning_rate = 1e-5  # 2.7e-5
num_epochs = 100
sample_interval_share = 5
max_length = 192
use_wandb = True
use_comet_ml = False
save_model = True
use_decay = True
fbp = False
optimizer_type = "adam8bit"
torch_compile = False
unet_gradient = True
fixed_seed = False
shuffle = True
comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
comet_ml_workspace = "recoilme"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
#torch.backends.cuda.enable_mem_efficient_sdp(False)
dtype = torch.float32
save_barrier = 1.01
warmup_percent = 0.01
percentile_clipping = 96  # 97
betta2 = 0.999
eps = 1e-7
clip_grad_norm = 1.0
limit = 0
checkpoints_folder = ""
mixed_precision = "no"
gradient_accumulation_steps = 1

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps
)
device = accelerator.device

# Diffusion parameters
n_diffusion_steps = 40
samples_to_generate = 12
guidance_scale = 4

# Folders for saving results
generated_folder = "samples"
os.makedirs(generated_folder, exist_ok=True)

# Seed setup
current_date = datetime.now()
seed = int(current_date.strftime("%Y%m%d"))
if fixed_seed:
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# --------------------------- LoRA parameters ---------------------------
lora_name = ""
lora_rank = 32
lora_alpha = 64

print("init")

# --------------------------- WandB initialization ---------------------------
if accelerator.is_main_process:
    if use_wandb:
        wandb.init(project=project+lora_name, config={
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
            "optimizer_type": optimizer_type,
        })
    if use_comet_ml:
        from comet_ml import Experiment
        comet_experiment = Experiment(
            api_key=comet_ml_api_key,
            project_name=project,
            workspace=comet_ml_workspace
        )
        hyper_params = {
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
        }
        comet_experiment.log_parameters(hyper_params)

# Enable Flash Attention 2 / SDPA
torch.backends.cuda.enable_flash_sdp(True)

# --------------------------- Model loading ---------------------------
vae = AutoencoderKL.from_pretrained("vae1x", torch_dtype=dtype).to("cpu").eval()
tokenizer = AutoTokenizer.from_pretrained("tokenizer")
text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()

# --- [UPDATED] Text encoding function (with mask and pooling) ---
def encode_texts(texts, max_length=max_length):
    # If texts is None (for unconditional), stubs would be needed here;
    # that case is handled in get_negative_embedding, so callers are
    # expected to pass a list of strings.
    if texts is None:
        pass

    with torch.no_grad():
        if isinstance(texts, str):
            texts = [texts]

        for i, prompt_item in enumerate(texts):
            messages = [
                {"role": "user", "content": prompt_item},
            ]
            prompt_item = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                #enable_thinking=True,
            )
            #print(prompt_item+"\n")
            texts[i] = prompt_item

        toks = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length
        ).to(device)

        outs = text_model(**toks, output_hidden_states=True, return_dict=True)

        # Use last_hidden_state or hidden_states[-1]; the second-to-last layer is used here instead
        hidden = outs.hidden_states[-2]

        # 2. Attention mask
        attention_mask = toks["attention_mask"]

        # 3. Pooled embedding (last non-padding token)
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = hidden.shape[0]
        pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]

        #return hidden, attention_mask
        # --- NEW LOGIC: CONCATENATION FOR CROSS-ATTENTION ---
        # 1. Expand the pooled vector into a sequence [B, 1, emb]
        pooled_expanded = pooled.unsqueeze(1)

        # 2. Concatenate the token sequence and the pooled vector.
        # !!! CHANGE HERE !!!: the pooled vector goes FIRST.
        # Result: [B, 1 + L, emb]; the pooled embedding becomes the leading token.
        new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)

        # 3. Extend the attention mask for the new token:
        # mask becomes [B, 1 + L], with a 1 prepended for the pooled token.
        new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)

        return new_encoder_hidden_states, new_attention_mask

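The shape bookkeeping in `encode_texts` (pool the last non-padding token, prepend it as an extra leading token, extend the mask) can be checked in isolation. A minimal NumPy sketch of just that logic, with illustrative tensor sizes; `prepend_pooled` is a hypothetical helper, not part of the script:

```python
import numpy as np

def prepend_pooled(hidden, attention_mask):
    """Prepend a last-token pooled vector to the token sequence,
    mirroring the concatenation in encode_texts (NumPy sketch)."""
    B, L, D = hidden.shape
    # Index of the last non-padding token per sequence
    seq_lengths = attention_mask.sum(axis=1) - 1
    pooled = hidden[np.arange(B), seq_lengths]                      # [B, D]
    pooled_expanded = pooled[:, None, :]                            # [B, 1, D]
    new_hidden = np.concatenate([pooled_expanded, hidden], axis=1)  # [B, 1+L, D]
    new_mask = np.concatenate(
        [np.ones((B, 1), dtype=attention_mask.dtype), attention_mask], axis=1
    )                                                               # [B, 1+L]
    return new_hidden, new_mask

hidden = np.random.rand(2, 4, 8)                  # B=2, L=4, D=8
mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])     # padded sequences
h2, m2 = prepend_pooled(hidden, mask)
```

Token 0 of each output row is a copy of that row's last real token, and the mask gains a matching leading 1, so cross-attention always sees the pooled summary first.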
shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None: shift_factor = 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
if scaling_factor is None: scaling_factor = 1.0

from diffusers import FlowMatchEulerDiscreteScheduler
num_train_timesteps = 1000
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=num_train_timesteps)

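The FlowMatchEulerDiscreteScheduler pairs with the rectified-flow objective the training loop uses later: latents are noised by linear interpolation, `x_t = (1-t)·x0 + t·noise`, and the model regresses the constant velocity `noise - x0`. A minimal NumPy sketch of that parameterization (shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.standard_normal((1, 4, 8, 8))     # clean VAE latents
noise = rng.standard_normal((1, 4, 8, 8))  # Gaussian noise sample
t = 0.3                                    # training time in [0, 1]

# Linear interpolation between data and noise (flow-matching forward process)
x_t = (1.0 - t) * x0 + t * noise
# Constant velocity target along that straight path: d(x_t)/dt = noise - x0
target = noise - x0
```

Because the path is a straight line, moving from `x_t` along `target` recovers either endpoint exactly: `x_t + (1-t)·target = noise` and `x_t - t·target = x0`, which is what the Euler steps in sampling exploit.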
class DistributedResolutionBatchSampler(Sampler):
    def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
        self.dataset = dataset
        self.batch_size = max(1, batch_size // num_replicas)
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0

        try:
            widths = np.array(dataset["width"])
            heights = np.array(dataset["height"])
        except KeyError:
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))

        self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
        self.size_groups = {}
        for w, h in self.size_keys:
            mask = (widths == w) & (heights == h)
            self.size_groups[(w, h)] = np.where(mask)[0]

        self.group_num_batches = {}
        total_batches = 0
        for size, indices in self.size_groups.items():
            num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
            self.group_num_batches[size] = num_full_batches
            total_batches += num_full_batches

        self.num_batches = (total_batches // self.num_replicas) * self.num_replicas

    def __iter__(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        all_batches = []
        rng = np.random.RandomState(self.epoch)

        for size, indices in self.size_groups.items():
            indices = indices.copy()
            if self.shuffle:
                rng.shuffle(indices)
            num_full_batches = self.group_num_batches[size]
            if num_full_batches == 0:
                continue
            valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
            batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
            start_idx = self.rank * self.batch_size
            end_idx = start_idx + self.batch_size
            gpu_batches = batches[:, start_idx:end_idx]
            all_batches.extend(gpu_batches)

        if self.shuffle:
            rng.shuffle(all_batches)
        accelerator.wait_for_everyone()
        return iter(all_batches)

    def __len__(self):
        return self.num_batches

    def set_epoch(self, epoch):
        self.epoch = epoch

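The sampler's core idea is resolution bucketing: indices are grouped by (width, height), only full global batches are kept per bucket, and each rank slices out its own contiguous share of every global batch. A simplified, self-contained sketch of that slicing (the function name and the toy sizes are illustrative, not from the script):

```python
import numpy as np

def per_rank_batches(sizes, batch_size, num_replicas, rank):
    """Group sample indices by (width, height) and return this rank's
    per-GPU batches, dropping incomplete global batches per bucket."""
    sizes = np.asarray(sizes)
    per_gpu = max(1, batch_size // num_replicas)
    out = []
    for key in np.unique(sizes, axis=0):
        idx = np.where((sizes == key).all(axis=1))[0]
        n_full = len(idx) // (per_gpu * num_replicas)
        if n_full == 0:
            continue  # bucket too small for even one global batch
        valid = idx[:n_full * per_gpu * num_replicas].reshape(-1, per_gpu * num_replicas)
        out.extend(valid[:, rank * per_gpu:(rank + 1) * per_gpu])
    return out

# 6 samples at 512x512 and 3 at 768x512; global batch of 4 over 2 replicas
sizes = [(512, 512)] * 6 + [(768, 512)] * 3
batches = per_rank_batches(sizes, batch_size=4, num_replicas=2, rank=0)
```

Every emitted batch contains a single resolution, so latents in a batch always share spatial dimensions; the 768x512 bucket here is dropped entirely because it cannot fill one global batch.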
# --- [UPDATED] Fixed-samples helper ---
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    size_groups = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for i, (w, h) in enumerate(zip(widths, heights)):
        size = (w, h)
        size_groups[size].append(i)

    fixed_samples = {}
    for size, indices in size_groups.items():
        n_samples = min(samples_per_group, len(indices))
        if len(size_groups) == 1:
            n_samples = samples_to_generate
        if n_samples == 0:
            continue
        sample_indices = random.sample(indices, n_samples)
        samples_data = [dataset[idx] for idx in sample_indices]

        latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
        texts = [item["text"] for item in samples_data]

        # Encode texts on the fly to obtain masks and pooling
        embeddings, masks = encode_texts(texts)

        fixed_samples[size] = (latents, embeddings, masks, texts)

    print(f"Created {len(fixed_samples)} groups of fixed samples by resolution")
    return fixed_samples

if limit > 0:
    dataset = load_from_disk(ds_path).select(range(limit))
else:
    dataset = load_from_disk(ds_path)

# --- [UPDATED] Collate function ---
def collate_fn_simple(batch):
    # 1. Latents (VAE)
    latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device, dtype=dtype)

    # 2. Raw text from the dataset
    raw_texts = [item["text"] for item in batch]
    texts = [
        "" if t.lower().startswith("zero")
        else "" if random.random() < 0.05
        else t[1:].lstrip() if t.startswith(".")
        else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ", "").strip()
        for t in raw_texts
    ]

    # 3. Encode on the fly.
    # Returns: hidden (B, L, D), mask (B, L)
    embeddings, attention_mask = encode_texts(texts)

    # The tokenizer's attention_mask already has the right shape; cast to long just in case
    attention_mask = attention_mask.to(dtype=torch.int64)

    return latents, embeddings, attention_mask

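The nested conditional expression in the collate function packs four caption rules into one line. The same logic written out as a plain function may be easier to follow; `clean_caption` is a hypothetical name introduced here for illustration, and the seeded `random.Random(0)` replaces the script's global RNG only to make the sketch deterministic:

```python
import random

def clean_caption(t, drop_prob=0.05, rng=random.Random(0)):
    """Caption preprocessing as in collate_fn_simple: 'zero'-prefixed captions
    and a small random share become empty (unconditional training signal),
    a leading '.' marks a caption to keep verbatim minus the dot, and
    common boilerplate lead-ins are stripped otherwise."""
    if t.lower().startswith("zero"):
        return ""
    if rng.random() < drop_prob:
        return ""  # random caption dropout for classifier-free guidance
    if t.startswith("."):
        return t[1:].lstrip()
    for prefix in ("The image shows ", "The image is ", "This image captures "):
        t = t.replace(prefix, "")
    return t.strip()
```

The ~5% random dropout is what later makes classifier-free guidance possible at sampling time, since the model also learns an unconditional distribution.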
batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    num_replicas=accelerator.num_processes,
    rank=accelerator.process_index,
    shuffle=shuffle
)

dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
print("Total batches", len(dataloader))
dataloader = accelerator.prepare(dataloader)

start_epoch = 0
global_step = 0
total_training_steps = (len(dataloader) * num_epochs)
world_size = accelerator.state.num_processes

# Load UNet
latest_checkpoint = os.path.join(checkpoints_folder, project)
if os.path.isdir(latest_checkpoint):
    print("Loading UNet from checkpoint:", latest_checkpoint)
    unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
    if unet_gradient:
        unet.enable_gradient_checkpointing()
    unet.set_use_memory_efficient_attention_xformers(False)
    try:
        unet.set_attn_processor(AttnProcessor2_0())
    except Exception as e:
        print(f"Failed to enable SDPA: {e}")
        unet.set_use_memory_efficient_attention_xformers(True)
else:
    raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")

if lora_name:
    # ... (LoRA code unchanged; omitted for brevity — restore the original block if needed) ...
    pass

# Optimizer
if lora_name:
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
else:
    if fbp:
        trainable_params = list(unet.parameters())

def create_optimizer(name, params):
    if name == "adam8bit":
        return bnb.optim.AdamW8bit(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
            percentile_clipping=percentile_clipping
        )
    elif name == "adam":
        return torch.optim.AdamW(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=1e-8, weight_decay=0.01
        )
    elif name == "muon":
        from muon import MuonWithAuxAdam
        trainable_params = [p for p in params if p.requires_grad]
        hidden_weights = [p for p in trainable_params if p.ndim >= 2]
        hidden_gains_biases = [p for p in trainable_params if p.ndim < 2]

        param_groups = [
            dict(params=hidden_weights, use_muon=True,
                 lr=1e-3, weight_decay=1e-4),
            dict(params=hidden_gains_biases, use_muon=False,
                 lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-4),
        ]
        optimizer = MuonWithAuxAdam(param_groups)
        from snooc import SnooC
        return SnooC(optimizer)
    else:
        raise ValueError(f"Unknown optimizer: {name}")

if fbp:
    optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
    def optimizer_hook(param):
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad(set_to_none=True)
    for param in trainable_params:
        param.register_post_accumulate_grad_hook(optimizer_hook)
    unet, optimizer = accelerator.prepare(unet, optimizer_dict)
else:
    optimizer = create_optimizer(optimizer_type, unet.parameters())
    def lr_schedule(step):
        x = step / (total_training_steps * world_size)
        warmup = warmup_percent
        if not use_decay:
            return base_learning_rate
        if x < warmup:
            return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
        decay_ratio = (x - warmup) / (1 - warmup)
        return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
            (1 + math.cos(math.pi * decay_ratio))
    lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)

if torch_compile:
    print("compiling")
    unet = torch.compile(unet)
    print("compiling - ok")

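`lr_schedule` is a standard linear-warmup-then-cosine-decay curve between `min_learning_rate` and `base_learning_rate`. Isolating it makes the endpoints easy to verify; the constants below mirror the script's defaults but are restated here as illustrative values:

```python
import math

# Illustrative constants (the script reads these from its config section)
base_lr, min_lr, warmup = 3e-5, 1e-5, 0.01

def lr_at(x):
    """Learning rate at training progress x in [0, 1]: linear warmup from
    min_lr up to base_lr over the first `warmup` fraction, then cosine
    decay back down to min_lr."""
    if x < warmup:
        return min_lr + (base_lr - min_lr) * (x / warmup)
    decay_ratio = (x - warmup) / (1 - warmup)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * decay_ratio))
```

The curve starts at `min_lr`, peaks at exactly `base_lr` when warmup ends (cos(0) = 1), and returns to `min_lr` at the end of training (cos(pi) = -1); the `LambdaLR` wrapper then divides by `base_lr` because LambdaLR expects a multiplicative factor.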
# Fixed samples
fixed_samples = get_fixed_samples_by_resolution(dataset)

# --- [UPDATED] Negative-embedding helper (returns 2 elements) ---
def get_negative_embedding(neg_prompt="", batch_size=1):
    if not neg_prompt:
        hidden_dim = 2048
        seq_len = max_length
        empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
        empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
        return empty_emb, empty_mask

    uncond_emb, uncond_mask = encode_texts([neg_prompt])
    uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
    uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)

    return uncond_emb, uncond_mask

# Negative (empty) conditions for validation
uncond_emb, uncond_mask = get_negative_embedding("low quality")

# --- Sample-generation function ---
@torch.compiler.disable()
@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
    uncond_emb, uncond_mask = uncond_data

    original_model = None
    try:
        if not torch_compile:
            original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
        else:
            original_model = unet.eval()

        vae.to(device=device).eval()

        all_generated_images = []
        all_captions = []

        # Unpack 4 elements per size group (mask added alongside latents, embeddings, texts)
        for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
            width, height = size
            sample_latents = sample_latents.to(dtype=dtype, device=device)
            sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
            sample_mask = sample_mask.to(device=device)

            latents = torch.randn(
                sample_latents.shape,
                device=device,
                dtype=sample_latents.dtype,
                generator=torch.Generator(device=device).manual_seed(seed)
            )

            scheduler.set_timesteps(n_diffusion_steps, device=device)

            for t in scheduler.timesteps:
                if guidance_scale != 1:
                    latent_model_input = torch.cat([latents, latents], dim=0)

                    # Build CFG batches (negative + positive)
                    # 1. Embeddings
                    curr_batch_size = sample_text_embeddings.shape[0]
                    seq_len = sample_text_embeddings.shape[1]
                    hidden_dim = sample_text_embeddings.shape[2]

                    neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
                    text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)

                    # 2. Masks
                    neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
                    attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)

                else:
                    latent_model_input = latents
                    text_embeddings_batch = sample_text_embeddings
                    attention_mask_batch = sample_mask

                # Prediction with all conditions passed in
                model_out = original_model(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings_batch,
                    encoder_attention_mask=attention_mask_batch,
                )
                flow = getattr(model_out, "sample", model_out)

                if guidance_scale != 1:
                    flow_uncond, flow_cond = flow.chunk(2)
                    flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

                latents = scheduler.step(flow, t, latents).prev_sample

            current_latents = latents

            latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
            decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
            decoded_fp32 = decoded.to(torch.float32)

            for img_idx, img_tensor in enumerate(decoded_fp32):
                img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
                img = img.transpose(1, 2, 0)

                if np.isnan(img).any():
                    print("NaNs found, saving stopped! Step:", step)
                pil_img = Image.fromarray((img * 255).astype("uint8"))

                max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
                max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
                max_w_overall = max(255, max_w_overall)
                max_h_overall = max(255, max_h_overall)

                padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
                all_generated_images.append(padded_img)

                caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
                all_captions.append(caption_text)

                sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
                pil_img.save(sample_path, "JPEG", quality=96)

        if use_wandb and accelerator.is_main_process:
            wandb_images = [
                wandb.Image(img, caption=f"{all_captions[i]}")
                for i, img in enumerate(all_generated_images)
            ]
            wandb.log({"generated_images": wandb_images})
        if use_comet_ml and accelerator.is_main_process:
            for i, img in enumerate(all_generated_images):
                comet_experiment.log_image(
                    image_data=img,
                    name=f"step_{step}_img_{i}",
                    step=step,
                    metadata={"caption": all_captions[i]}
                )
    finally:
        vae.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()

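The classifier-free guidance step inside the sampling loop reduces to one line of arithmetic: extrapolate from the unconditional prediction toward, and past, the conditional one. A minimal NumPy sketch; `cfg_combine` and the toy vectors are illustrative, not from the script:

```python
import numpy as np

def cfg_combine(flow_uncond, flow_cond, guidance_scale):
    """Classifier-free guidance, as in the sampling loop above:
    uncond + scale * (cond - uncond)."""
    return flow_uncond + guidance_scale * (flow_cond - flow_uncond)

u = np.zeros(4)  # hypothetical unconditional flow prediction
c = np.ones(4)   # hypothetical conditional flow prediction
```

At scale 1 this returns the conditional prediction unchanged, which is exactly why the loop skips the batch doubling when `guidance_scale != 1` is false; larger scales push the result further from the unconditional baseline.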
# --------------------------- Sample generation before training ---------------------------
if accelerator.is_main_process:
    if save_model:
        print("Generating samples before training starts...")
        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
accelerator.wait_for_everyone()

def save_checkpoint(unet, variant=""):
    if accelerator.is_main_process:
        if lora_name:
            save_lora_checkpoint(unet)
        else:
            model_to_save = None
            if not torch_compile:
                model_to_save = accelerator.unwrap_model(unet)
            else:
                model_to_save = unet

            if variant != "":
                model_to_save.to(dtype=torch.float16).save_pretrained(
                    os.path.join(checkpoints_folder, f"{project}"), variant=variant
                )
            else:
                model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))

            unet = unet.to(dtype=dtype)

# --------------------------- Training loop ---------------------------
if accelerator.is_main_process:
    print(f"Total steps per GPU: {total_training_steps}")

epoch_loss_points = []
progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")

steps_per_epoch = len(dataloader)
sample_interval = max(1, steps_per_epoch // sample_interval_share)
min_loss = 2.

for epoch in range(start_epoch, start_epoch + num_epochs):
    batch_losses = []
    batch_grads = []
    batch_sampler.set_epoch(epoch)
    accelerator.wait_for_everyone()
    unet.train()

    for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
        with accelerator.accumulate(unet):
            if save_model == False and step == 5:
                used_gb = torch.cuda.max_memory_allocated() / 1024**3
                print(f"Step {step}: {used_gb:.2f} GB")

            # Noise
            noise = torch.randn_like(latents, dtype=latents.dtype)
            # Draw t from [0, 1]
            t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
            # Interpolate between x0 and noise
            noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
            # Integer timesteps for the UNet
            timesteps = (t * scheduler.config.num_train_timesteps).long()

            # --- UNet call with mask ---
            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=embeddings,
                encoder_attention_mask=attention_mask
            ).sample

            target = noise - latents
            mse_loss = F.mse_loss(model_pred.float(), target.float())
            batch_losses.append(mse_loss.detach().item())

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            accelerator.backward(mse_loss)

            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                accelerator.wait_for_everyone()

            grad = 0.0
            if not fbp:
+ if accelerator.sync_gradients:
648
+ #with torch.amp.autocast('cuda', enabled=False):
649
+ grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
650
+ grad = float(grad_val)
651
+ optimizer.step()
652
+ lr_scheduler.step()
653
+ optimizer.zero_grad(set_to_none=True)
654
+
655
+ if accelerator.sync_gradients:
656
+ global_step += 1
657
+ progress_bar.update(1)
658
+ if accelerator.is_main_process:
659
+ if fbp:
660
+ current_lr = base_learning_rate
661
+ else:
662
+ current_lr = lr_scheduler.get_last_lr()[0]
663
+ batch_grads.append(grad)
664
+
665
+ log_data = {}
666
+ log_data["loss"] = mse_loss.detach().item()
667
+ log_data["lr"] = current_lr
668
+ log_data["grad"] = grad
669
+ if accelerator.sync_gradients:
670
+ if use_wandb:
671
+ wandb.log(log_data, step=global_step)
672
+ if use_comet_ml:
673
+ comet_experiment.log_metrics(log_data, step=global_step)
674
+
675
+ if global_step % sample_interval == 0:
676
+ # Передаем tuple (emb, mask) для негатива
677
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
678
+ last_n = sample_interval
679
+
680
+ if save_model:
681
+ has_losses = len(batch_losses) > 0
682
+ avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
683
+ last_loss = batch_losses[-1] if has_losses else 0.0
684
+ max_loss = max(avg_sample_loss, last_loss)
685
+ should_save = max_loss < min_loss * save_barrier
686
+ print(
687
+ f"Saving: {should_save} | Max: {max_loss:.4f} | "
688
+ f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
689
+ )
690
+ # 6. Сохранение и обновление
691
+ if should_save:
692
+ min_loss = max_loss
693
+ save_checkpoint(unet)
694
+
695
+ if accelerator.is_main_process:
696
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
697
+ avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
698
+
699
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
700
+ log_data_ep = {
701
+ "epoch_loss": avg_epoch_loss,
702
+ "epoch_grad": avg_epoch_grad,
703
+ "epoch": epoch + 1,
704
+ }
705
+ if use_wandb:
706
+ wandb.log(log_data_ep)
707
+ if use_comet_ml:
708
+ comet_experiment.log_metrics(log_data_ep)
709
+
710
+ if accelerator.is_main_process:
711
+ print("Обучение завершено! Сохраняем финальную модель...")
712
+ if save_model:
713
+ save_checkpoint(unet,"fp16")
714
+ if use_comet_ml:
715
+ comet_experiment.end()
716
+ accelerator.free_memory()
717
+ if torch.distributed.is_initialized():
718
+ torch.distributed.destroy_process_group()
719
+
720
+ print("Готово!")
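
The training step in the hunk above uses a linear-interpolation (rectified-flow style) objective: clean latents are blended with Gaussian noise at a random time `t`, and the UNet regresses the velocity `noise - latents`. A minimal standalone sketch of that math, using numpy in place of torch tensors (function and variable names here are illustrative, not the repo's API):

```python
import numpy as np

def flow_matching_pair(latents, noise, t):
    """Blend clean latents with noise at time t in [0, 1] and return the
    regression target v = noise - latents, mirroring the training step above."""
    t = t.reshape(-1, 1, 1, 1)              # broadcast per-sample t over C, H, W
    noisy = (1.0 - t) * latents + t * noise  # linear interpolation x_t
    target = noise - latents                 # velocity the model must predict
    return noisy, target

latents = np.zeros((2, 4, 8, 8))   # pretend "clean" latents
noise = np.ones((2, 4, 8, 8))      # pretend noise sample
t = np.array([0.0, 1.0])
noisy, target = flow_matching_pair(latents, noise, t)
# at t=0 the blend is the clean latent; at t=1 it is pure noise
```

The integer `timesteps = (t * num_train_timesteps).long()` in the diff only re-scales this continuous `t` into the discrete range the UNet's timestep embedding expects.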
train.py CHANGED
@@ -30,9 +30,9 @@ from transformers import AutoTokenizer, AutoModel
 ds_path = "/workspace/sdxs/datasets/768"
 project = "unet"
 batch_size = 36
-base_learning_rate = 3e-5 #4e-5
+base_learning_rate = 2.7e-5 #4e-5
 min_learning_rate = 1e-5 #2.7e-5
-num_epochs = 100
+num_epochs = 80
 sample_interval_share = 5
 max_length = 192
 use_wandb = True
@@ -94,6 +94,44 @@ lora_alpha = 64
 
 print("init")
 
+loss_ratios = {
+    "mse": 1.,
+}
+median_coeff_steps = 128
+
+# Median-based loss normalization: we compute COEFFICIENTS
+class MedianLossNormalizer:
+    def __init__(self, desired_ratios: dict, window_steps: int):
+        # normalize the shares in case they do not sum to 1
+        s = sum(desired_ratios.values())
+        self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
+        self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
+        self.window = window_steps
+
+    def update_and_total(self, losses: dict):
+        """
+        losses: dict key -> tensor (loss values)
+        Behavior:
+        - buffer ABS(l) only for active (ratio > 0) losses
+        - coeff = ratio / median(abs(loss))
+        - total = sum(coeff * loss) over the active losses
+        CHANGED: we buffer abs() so the median stays positive and does not break the division.
+        """
+        # buffer only the active losses
+        for k, v in losses.items():
+            if k in self.buffers and self.ratios.get(k, 0) > 0:
+                self.buffers[k].append(float(v.detach().abs().cpu()))
+
+        meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
+        coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
+
+        # sum only over the active (ratio > 0) losses
+        total = sum(coeffs[k] * losses[k] for k in coeffs if self.ratios.get(k, 0) > 0)
+        return total, coeffs, meds
+
+# create the normalizer after loss_ratios is defined
+normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
+
 # --------------------------- WandB initialization ---------------------------
 if accelerator.is_main_process:
     if use_wandb:
@@ -634,10 +672,20 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
             mse_loss = F.mse_loss(model_pred.float(), target.float())
             batch_losses.append(mse_loss.detach().item())
 
+            if (global_step % 100 == 0) or (global_step % sample_interval == 0):
+                accelerator.wait_for_everyone()
+
+            losses_dict = {}
+            losses_dict["mse"] = mse_loss
+
+            # === Normalize all losses ===
+            abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
+            total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)
+
             if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                 accelerator.wait_for_everyone()
 
-            accelerator.backward(mse_loss)
+            accelerator.backward(total_loss)
 
             if (global_step % 100 == 0) or (global_step % sample_interval == 0):
                 accelerator.wait_for_everyone()
@@ -666,6 +714,9 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
                     log_data["loss"] = mse_loss.detach().item()
                     log_data["lr"] = current_lr
                     log_data["grad"] = grad
+                    log_data["loss_total"] = float(total_loss.item())
+                    for k, c in coeffs.items():
+                        log_data[f"coeff_{k}"] = float(c)
                     if accelerator.sync_gradients:
                         if use_wandb:
                             wandb.log(log_data, step=global_step)
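
The new `MedianLossNormalizer` in this diff rescales each loss by `ratio / median(|loss|)` over a sliding window, so every active loss contributes roughly its desired share to the total regardless of its raw magnitude. A small plain-numpy re-derivation of that coefficient rule (a sketch of the arithmetic only, not the class itself):

```python
import numpy as np
from collections import deque

# Desired share for each loss term; with a single "mse" term the share is 1.0.
ratios = {"mse": 1.0}

# Sliding window of recent |loss| values (the diff uses median_coeff_steps=128).
window = deque(maxlen=4)
for loss in [2.0, 4.0, 8.0]:
    window.append(abs(loss))

median = float(np.median(window))           # median of [2, 4, 8] -> 4.0
coeff = ratios["mse"] / max(median, 1e-12)  # -> 0.25 (1e-12 guards division)
total = coeff * 8.0                         # latest loss rescaled -> 2.0
```

With one loss at ratio 1.0 this only divides by a running median, but with several terms it keeps each term's normalized contribution proportional to its configured share.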
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7bdbcdc27f48b45eb0c5d1ea0f316987e13e86df63910bb9c4a773c98c14fffc
+oid sha256:7f2b0091ca3b914bcbd7a8aa359ca00e70aab6b02982ee041535c83bcb69574a
 size 7444321360