'''
Use a pretrained InstructPix2Pix model, but add extra input channels for reference-based modification.
'''
import os
import torch
from .diffusion import DDIMLDMTextTraining
from einops import rearrange
from modules.video_unet_temporal.resnet import InflatedConv3d
from safetensors.torch import load_file
import torch.nn.functional as F
from torch import nn
import cv2
from torch.hub import download_url_to_file
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(3072, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 4096)
        self.fc4 = nn.Linear(4096, 2304)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.01)  # negative slope of the Leaky ReLU

    def forward(self, x):
        x = self.leaky_relu(self.fc1(x))
        x = self.leaky_relu(self.fc2(x))
        x = self.leaky_relu(self.fc3(x))
        x = self.fc4(x)
        return x
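
# Usage sketch (assumption for illustration only: 16 frames, each flattened to a 3072-dim feature
# as in HDRCtrlModel.forward below):
#     mlp = MLP()
#     feats = torch.randn(16, 3072)
#     tokens = mlp(feats)                      # -> (16, 2304)
#     tokens = tokens.reshape(16, 3, 768)      # 3 pseudo text tokens per frame for cross-attention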
# class CombineMLP(nn.Module):
#     def __init__(self, input_dim=128, output_dim=64, hidden_dim=128):
#         """
#         Build a 5-layer MLP.
#         :param input_dim: input feature dimension, default 128
#         :param output_dim: output feature dimension, default 64
#         :param hidden_dim: hidden layer dimension, default 128
#         """
#         super(CombineMLP, self).__init__()
#         # define the 5 linear layers
#         self.fc1 = nn.Linear(input_dim, hidden_dim)
#         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
#         self.fc3 = nn.Linear(hidden_dim, hidden_dim)
#         self.fc4 = nn.Linear(hidden_dim, hidden_dim)
#         self.fc5 = nn.Linear(hidden_dim, output_dim)  # final layer maps to 64
#         # activation function
#         # self.activation = nn.ReLU()
#         self.activation = nn.LeakyReLU(negative_slope=0.01)  # default negative slope 0.01
#     def forward(self, x1, x2):
#         """
#         Forward pass with two inputs x1 and x2.
#         :param x1: first input, shape (B, 64)
#         :param x2: second input, shape (B, 64)
#         :return: output features, shape (B, 64)
#         """
#         # concatenate the two inputs
#         x = torch.cat([x1, x2], dim=-1)  # concatenated shape (B, 128)
#         # run through the 5 layers with activations
#         x = self.activation(self.fc1(x))
#         x = self.activation(self.fc2(x))
#         x = self.activation(self.fc3(x))
#         x = self.activation(self.fc4(x))
#         x = self.fc5(x)  # no activation on the last layer
#         return x
class CombineMLP(nn.Module):
    def __init__(self, input_dim=4*64*64*2, output_dim=4*64*64, hidden_dim=128, num_layers=5):
        """
        Build an MLP with num_layers layers (5 by default).
        :param input_dim: input feature dimension, default 4*64*64*2 (two concatenated latents)
        :param output_dim: output feature dimension, default 4*64*64
        :param hidden_dim: hidden layer dimension, default 128
        """
        super(CombineMLP, self).__init__()
        # build the hidden layers
        layers = []
        for i in range(num_layers - 1):  # num_layers - 1 hidden layers
            layers.append(nn.Linear(input_dim if i == 0 else hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
        # output layer
        layers.append(nn.Linear(hidden_dim, output_dim))
        # pack the layers into a single module
        self.mlp = nn.Sequential(*layers)

    def forward(self, x1, x2):
        """
        Forward pass with two inputs x1 and x2.
        :param x1: first input, shape (1, 16, 4, 64, 64)
        :param x2: second input, shape (1, 16, 4, 64, 64)
        :return: output features, shape (1, 16, 4, 64, 64)
        """
        # concatenate the two inputs along the channel dimension
        x = torch.cat([x1, x2], dim=2)     # concatenated shape (1, 16, 8, 64, 64)
        x = torch.flatten(x, start_dim=2)  # flatten to (batch_size, 16, 8*64*64)
        x = self.mlp(x)                    # apply the MLP -> (1, 16, 4*64*64)
        x = x.reshape(x.size(0), x.size(1), 4, 64, 64)  # reshape back to (1, 16, 4, 64, 64)
        return x
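
# Usage sketch (shapes follow the docstring above; values are random for illustration):
#     combine = CombineMLP()
#     pred1 = torch.randn(1, 16, 4, 64, 64)
#     pred2 = torch.randn(1, 16, 4, 64, 64)
#     fused = combine(pred1, pred2)            # -> (1, 16, 4, 64, 64)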
class HDRCtrlModeltmp(nn.Module):
    def __init__(self):
        super(HDRCtrlModeltmp, self).__init__()
        # convolutional downsampling layers
        self.conv_layer1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=4, padding=1)
        self.conv_layer2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1)
        # MLP head
        self.mlp = MLP()

    def decompose_hdr(self, hdr_latents):
        batch_size, channels, height, width = hdr_latents.shape
        device = hdr_latents.device
        # generate a 4x4 mask (batch_size, 1, 4, 4) drawn from a uniform distribution
        mask_small = torch.rand(batch_size, 1, 4, 4, device=device)
        # resize the mask to the input resolution (batch_size, 1, height, width)
        mask = torch.nn.functional.interpolate(mask_small, size=(height, width), mode='bilinear', align_corners=False)
        # keep continuous values, no binarization here
        # NOTE: worth visualizing the random mask and the masked results
        mask = mask.expand(-1, channels, -1, -1)  # expand the mask channels to match hdr_latents
        # apply the mask to obtain L1 and L2
        hdr_latents_1 = hdr_latents * mask        # L1 = masked part
        hdr_latents_2 = hdr_latents * (1 - mask)  # L2 = unmasked part
        return hdr_latents_1, hdr_latents_2

    def forward(self, hdr_latents):
        # TODO: use the mask to get hdr1, hdr2; the input hdr_latents is actually LDR for now
        # input shape is (1, 16, 3, 256, 512); drop the leading batch dimension
        hdr_latents = hdr_latents.squeeze(0)  # -> (16, 3, 256, 512)
        batch_size = hdr_latents.shape[0]
        # input is already NCHW; numpy2tensor permuted it before this point
        # convolutional downsampling
        conv_output = self.conv_layer1(hdr_latents)
        conv_output = self.conv_layer2(conv_output)  # (16, 3, 32, 64)
        # keep the first 32 columns, final shape (16, 3, 32, 32)
        hdr_latents = conv_output[:, :, :, :32]
        # decompose into masked / unmasked parts
        hdr_latents_1, hdr_latents_2 = self.decompose_hdr(hdr_latents)  # [16, 3, 32, 32], [16, 3, 32, 32]
        # flatten before feeding the MLP
        hdr_latents = hdr_latents.reshape(hdr_latents.size(0), -1)        # [16, 3072]
        hdr_latents_1 = hdr_latents_1.reshape(hdr_latents_1.size(0), -1)  # [16, 3072]
        hdr_latents_2 = hdr_latents_2.reshape(hdr_latents_2.size(0), -1)
        # run the MLP
        hdr_latents = self.mlp(hdr_latents)  # (16, 2304), 3072 -> 2304
        hdr_latents_1 = self.mlp(hdr_latents_1)
        hdr_latents_2 = self.mlp(hdr_latents_2)
        # reshape the outputs
        hdr_latents = hdr_latents.reshape(batch_size, 3, 768)  # (16, 3, 768)
        hdr_latents_1 = hdr_latents_1.reshape(batch_size, 3, 768)
        hdr_latents_2 = hdr_latents_2.reshape(batch_size, 3, 768)
        return hdr_latents, hdr_latents_1, hdr_latents_2
class HDRCtrlModel(nn.Module):
    def __init__(self):
        super(HDRCtrlModel, self).__init__()
        # convolutional downsampling layers
        self.conv_layer1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=4, padding=1)
        self.conv_layer2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1)
        # MLP head
        self.mlp = MLP()

    def decompose_hdr(self, hdr_latents):  # hdr_latents: (16, 3, 32, 32); worth visualizing this step
        batch_size, channels, height, width = hdr_latents.shape
        device = hdr_latents.device
        # generate a 4x4 mask (batch_size, 1, 4, 4) drawn from a uniform distribution
        mask_small = torch.rand(batch_size, 1, 4, 4, device=device)
        threshold = 0.5  # tune this threshold to control the ratio of masked-out area
        mask_small = (mask_small > threshold).float()
        # resize the mask to the input resolution (batch_size, 1, height, width), i.e. (16, 1, 32, 32)
        mask = torch.nn.functional.interpolate(mask_small, size=(height, width), mode='bilinear', align_corners=False)
        # bilinear upsampling keeps soft edges; no second binarization here
        # NOTE: worth visualizing the random mask and the masked results
        mask = mask.expand(-1, channels, -1, -1)  # expand the mask channels to match hdr_latents
        # apply the mask to obtain L1 and L2
        hdr_latents_1 = hdr_latents * mask        # L1 = masked part
        hdr_latents_2 = hdr_latents * (1 - mask)  # L2 = unmasked part
        return hdr_latents_1, hdr_latents_2
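
    # Mask sketch (illustration only): a 4x4 Bernoulli(0.5) grid is upsampled bilinearly to 32x32,
    # splitting each lighting map into two soft, complementary halves whose sum recovers the input:
    #     m = (torch.rand(1, 1, 4, 4) > 0.5).float()
    #     m = F.interpolate(m, size=(32, 32), mode='bilinear', align_corners=False)
    #     x1, x2 = x * m, x * (1 - m)              # x1 + x2 == x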
    def blur_image(self, hdr_latents):
        # Gaussian blur; input (16, 3, 256, 256)
        processed_images = []
        kernel_size = (15, 15)
        sigmaX = 10
        # process each frame independently
        for i in range(hdr_latents.size(0)):  # iterate over the 16 frames
            # take the i-th frame
            image = hdr_latents[i].permute(1, 2, 0).cpu().numpy()  # -> (256, 256, 3)
            # apply the Gaussian blur
            blurred_image = cv2.GaussianBlur(image, kernel_size, sigmaX)
            # resize to (32, 32, 3)
            resized_image = cv2.resize(blurred_image, (32, 32), interpolation=cv2.INTER_AREA)
            # convert back from numpy to tensor
            resized_image_tensor = torch.tensor(resized_image, dtype=torch.uint8, device=hdr_latents.device).permute(2, 0, 1)  # -> (3, 32, 32)
            # collect the processed frame
            processed_images.append(resized_image_tensor)
        # stack all frames into one tensor
        processed_images_tensor = torch.stack(processed_images)  # (16, 3, 32, 32)
        return processed_images_tensor

    def normalize_hdr(self, img):
        img = img / 255.0
        return img * 2 - 1

    def forward(self, hdr_latents):
        # TODO: use the mask to get hdr1, hdr2; the input hdr_latents is actually LDR for now
        # input shape is (n, 16, 3, 256, 256)
        # hdr_latents = hdr_latents.squeeze(0)  # -> (16, 3, 256, 256)
        batch_size_ori = hdr_latents.shape[0]
        # frame_num = hdr_latents.shape[1]
        hdr_latents = rearrange(hdr_latents, 'b f c h w -> (b f) c h w')
        batch_size = hdr_latents.shape[0]
        # input is already NCHW; numpy2tensor permuted it before this point
        # Gaussian blur
        hdr_latents = self.blur_image(hdr_latents)  # (16, 3, 32, 32); worth visualizing
        # decompose into masked / unmasked parts
        hdr_latents_1, hdr_latents_2 = self.decompose_hdr(hdr_latents)  # [16, 3, 32, 32], [16, 3, 32, 32]
        # normalize: divide by 255, then map to [-1, 1]
        hdr_latents, hdr_latents_1, hdr_latents_2 = self.normalize_hdr(hdr_latents), self.normalize_hdr(hdr_latents_1), self.normalize_hdr(hdr_latents_2)
        # flatten before feeding the MLP
        hdr_latents = hdr_latents.reshape(hdr_latents.size(0), -1)        # [16, 3072]
        hdr_latents_1 = hdr_latents_1.reshape(hdr_latents_1.size(0), -1)  # [16, 3072]
        hdr_latents_2 = hdr_latents_2.reshape(hdr_latents_2.size(0), -1)
        # run the MLP
        hdr_latents = self.mlp(hdr_latents)  # (16, 2304), 3072 -> 2304
        hdr_latents_1 = self.mlp(hdr_latents_1)
        hdr_latents_2 = self.mlp(hdr_latents_2)
        # reshape the outputs
        hdr_latents = hdr_latents.reshape(batch_size, 3, 768)  # (16*n, 3, 768)
        hdr_latents_1 = hdr_latents_1.reshape(batch_size, 3, 768)
        hdr_latents_2 = hdr_latents_2.reshape(batch_size, 3, 768)
        hdr_latents = rearrange(hdr_latents, '(b f) n c -> b f n c', b=batch_size_ori)
        hdr_latents_1 = rearrange(hdr_latents_1, '(b f) n c -> b f n c', b=batch_size_ori)
        hdr_latents_2 = rearrange(hdr_latents_2, '(b f) n c -> b f n c', b=batch_size_ori)
        # NOTE two open details: 1) with LDR only, whether to concat HDR or add a linear transform;
        # 2) whether the mask should differ across frames
        return hdr_latents, hdr_latents_1, hdr_latents_2  # 3 x (b, 16, 3, 768)
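
# Usage sketch (illustration only: one clip of 16 LDR frames with values in [0, 255]):
#     hdr_encoder = HDRCtrlModel()
#     ldr = torch.rand(1, 16, 3, 256, 256) * 255
#     full, part1, part2 = hdr_encoder(ldr)     # each (1, 16, 3, 768), injected via cross-attention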
class InstructP2PVideoTrainer(DDIMLDMTextTraining):
    def __init__(
        self, *args,
        cond_image_dropout=0.1,
        cond_text_dropout=0.1,
        cond_hdr_dropout=0.1,
        prompt_type='output_prompt',
        text_cfg=7.5,
        img_cfg=1.2,
        hdr_cfg=7.5,
        hdr_rate=0.1,
        ic_condition='bg',
        hdr_train=False,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.hdr_train = hdr_train
        if self.hdr_train:
            self.hdr_encoder = HDRCtrlModel()
            self.hdr_encoder = self.hdr_encoder.to(self.unet.device)
            self.mlp = CombineMLP()
            self.cond_hdr_dropout = cond_hdr_dropout
            self.hdr_rate = hdr_rate
        self.cond_image_dropout = cond_image_dropout
        self.cond_text_dropout = cond_text_dropout
        assert ic_condition in ['fg', 'bg']
        assert prompt_type in ['output_prompt', 'edit_prompt', 'mixed_prompt']
        self.prompt_type = prompt_type
        self.ic_condition = ic_condition
        self.text_cfg = text_cfg
        self.img_cfg = img_cfg
        self.hdr_cfg = hdr_cfg
        # optional xformers / gradient-checkpointing training settings
        # self.unet.enable_xformers_memory_efficient_attention()
        # self.unet.enable_gradient_checkpointing()
    def encode_text(self, text):
        with torch.cuda.amp.autocast(dtype=torch.float16):
            encoded_text = super().encode_text(text)
        return encoded_text

    def encode_image_to_latent(self, image):
        # NOTE: overridden by the video-aware definition below; kept for reference
        # with torch.cuda.amp.autocast(dtype=torch.float16):
        latent = super().encode_image_to_latent(image)
        return latent

    # @torch.cuda.amp.autocast(dtype=torch.float16)
    def get_prompt(self, batch, mode):
        # if mode == 'train':
        #     if self.prompt_type == 'output_prompt':
        #         prompt = batch['output_prompt']
        #     elif self.prompt_type == 'edit_prompt':  # training uses the edit prompt
        #         prompt = batch['edit_prompt']
        #     elif self.prompt_type == 'mixed_prompt':
        #         if int(torch.rand(1)) > 0.5:
        #             prompt = batch['output_prompt']
        #         else:
        #             prompt = batch['edit_prompt']
        # else:
        #     prompt = batch['output_prompt']
        prompt = batch['text_prompt']
        if not self.hdr_train:  # NOTE: if text is later added to the HDR branch, is this still needed?
            if torch.rand(1).item() < self.cond_text_dropout:
                prompt = 'change the background'
        cond_text = self.encode_text(prompt)
        if mode == 'train':
            if torch.rand(1).item() < self.cond_text_dropout:
                cond_text = torch.zeros_like(cond_text)
        return cond_text
    # @torch.cuda.amp.autocast(dtype=torch.float16)
    def encode_image_to_latent(self, image):
        b, f, c, h, w = image.shape
        image = rearrange(image, 'b f c h w -> (b f) c h w')
        latent = super().encode_image_to_latent(image)
        latent = rearrange(latent, '(b f) c h w -> b f c h w', b=b)
        return latent

    # @torch.cuda.amp.autocast(dtype=torch.float16)
    def decode_latent_to_image(self, latent):
        b, f, c, h, w = latent.shape
        latent = rearrange(latent, 'b f c h w -> (b f) c h w')
        image = []
        for latent_ in latent:
            image_ = super().decode_latent_to_image(latent_[None])
            image.append(image_.sample)  # NOTE: .sample is needed now; the decoder used to be a different class and returned plain tensors
        image = torch.cat(image, dim=0)
        # image = super().decode_latent_to_image(latent)
        image = rearrange(image, '(b f) c h w -> b f c h w', b=b)
        return image
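
    # Frame-flattening sketch (illustration only): the VAE is 2D, so videos are folded into the
    # batch dimension before encoding and unfolded afterwards (vae_encode is a hypothetical stand-in):
    #     video = torch.randn(1, 16, 3, 512, 512)                              # (b, f, c, h, w)
    #     flat = rearrange(video, 'b f c h w -> (b f) c h w')                  # (16, 3, 512, 512)
    #     lat = rearrange(vae_encode(flat), '(b f) c h w -> b f c h w', b=1)   # (1, 16, 4, 64, 64)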
    def get_cond_image(self, batch, mode):
        cond_fg_image = batch['fg_video']  # the condition here is the input video, used via concat (or ControlNet-style)
        cond_fg_image = self.encode_image_to_latent(cond_fg_image)
        if self.ic_condition == 'bg':
            cond_bg_image = batch['bg_video']
            if torch.all(cond_bg_image == 0):
                cond_bg_image = torch.zeros_like(cond_fg_image)  # the background condition is zeroed with some probability (e.g. 0.3)
            else:
                cond_bg_image = self.encode_image_to_latent(cond_bg_image)
            cond_image = torch.cat((cond_fg_image, cond_bg_image), dim=2)  # (1, 16, 8, 64, 64)
        else:
            cond_image = cond_fg_image
        # test code: visualization
        # from PIL import Image
        # Image.fromarray(((batch['input_video'] + 1) / 2 * 255).byte()[0,0].permute(1,2,0).cpu().numpy()).save('img1.png')
        # ip2p does not scale the cond image, so we leave the cond latent unscaled
        # cond_image = self.encode_image_to_latent(cond_image) / self.scale_factor  # plain VAE encode, no scale factor
        if mode == 'train':
            # with probability cond_image_dropout the condition is zeroed, for classifier-free guidance robustness
            # (the earlier `int(torch.rand(1)) < self.cond_image_dropout` always evaluated 0 < 0.1 and so always zeroed the condition)
            if torch.rand(1).item() < self.cond_image_dropout:
                cond_image = torch.zeros_like(cond_image)
        return cond_image
    def get_diffused_image(self, batch, mode):
        x = batch['tgt_video']  # the edited (target) video is the ground truth that gets noised and denoised
        # from PIL import Image
        # Image.fromarray(((batch['edited_video'] + 1) / 2 * 255).byte()[0,0].permute(1,2,0).cpu().numpy()).save('img2.png')
        b, *_ = x.shape
        x = self.encode_image_to_latent(x)  # (1, 16, 4, 64, 64) after the VAE encode
        eps = torch.randn_like(x)
        if mode == 'train':
            t = torch.randint(0, self.num_timesteps, (b,), device=x.device).long()
        else:
            t = torch.full((b,), self.num_timesteps - 1, device=x.device, dtype=torch.long)
        x_t = self.add_noise(x, t, eps)  # add t steps of noise; eps is the Gaussian noise used by the scheduler
        if self.prediction_type == 'epsilon':
            return x_t, eps, t
        else:
            return x_t, x, t
    def get_hdr_image(self, batch, mode):
        x = batch['ldr_video']  # TODO: (1, 16, 3, 256, 256) float tensor; note that only LDR is available for now
        hdr_latents, hdr_latents_1, hdr_latents_2 = self.hdr_encoder(x)
        if mode == 'train':  # NOTE: reconsider this dropout, since the consistency loss depends on these latents
            if torch.rand(1).item() < self.cond_hdr_dropout:
                hdr_latents = torch.zeros_like(hdr_latents)
                hdr_latents_1 = torch.zeros_like(hdr_latents_1)
                hdr_latents_2 = torch.zeros_like(hdr_latents_2)
        return hdr_latents, hdr_latents_1, hdr_latents_2
    # the batch must provide the foreground mask
    def get_mask(self, batch, mode, target):
        # mask shape: (1, 16, 1, 512, 512)
        mask = batch['fg_mask']  # (n, 16, 1, 512, 512)
        bs = mask.shape[0]
        target_height, target_width = target.shape[-2:]  # target: (n, 16, c, 64, 64)
        mask = rearrange(mask, 'b f c h w -> (b f) c h w')
        resized_mask = F.interpolate(mask, size=(target_height, target_width), mode='bilinear', align_corners=False)
        # resized_mask = resized_mask.unsqueeze(0)
        resized_mask = rearrange(resized_mask, '(b f) c h w -> b f c h w', b=bs)
        if target.shape[2] != resized_mask.shape[2]:
            resized_mask = resized_mask.expand(-1, -1, target.shape[2], -1, -1)  # match the target channel count
        return resized_mask
    def process_batch(self, batch, mode):
        # NOTE: when debugging, visualize the images here; the earlier issue turned out to be the random dropout
        cond_image = self.get_cond_image(batch, mode)  # VAE-encode the source video; no scale factor applied (ip2p itself does not apply it)
        diffused_image, target, t = self.get_diffused_image(batch, mode)  # diffused_image: VAE-encoded and noised by the scheduler, the standard denoising input
        # target: the epsilon target, so the loss is computed against epsilon; t: random during training, usually 1000 at inference
        prompt = self.get_prompt(batch, mode)
        model_kwargs = {
            'encoder_hidden_states': prompt
        }
        if self.hdr_train:
            hdr_image, hdr_image_1, hdr_image_2 = self.get_hdr_image(batch, mode)  # (16, 3, 768)
            fg_mask = self.get_mask(batch, mode, target)  # resize the foreground mask to the target size
            model_kwargs = {
                'encoder_hidden_states': {'hdr_latents': hdr_image, 'encoder_hidden_states': prompt, 'hdr_latents_1': hdr_image_1, 'hdr_latents_2': hdr_image_2, 'fg_mask': fg_mask}
            }
        return {
            'diffused_input': diffused_image,  # (1, 16, 4, 64, 64), VAE-encoded and noised by the scheduler
            'condition': cond_image,           # (1, 16, 8, 64, 64), VAE-encoded source video, no scale factor
            'target': target,                  # the Gaussian noise added to the target video
            't': t,                            # a timestep in [0, 1000)
            'model_kwargs': model_kwargs,      # text (and optionally HDR) hidden states
        }
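
    # Expected batch layout (shapes are illustrative only, inferred from the comments above):
    #     batch = {
    #         'fg_video':  torch.randn(1, 16, 3, 512, 512),       # foreground condition frames
    #         'bg_video':  torch.randn(1, 16, 3, 512, 512),       # background condition (may be all zeros)
    #         'tgt_video': torch.randn(1, 16, 3, 512, 512),       # ground-truth relit video
    #         'text_prompt': ['change the background'],
    #         'ldr_video': torch.rand(1, 16, 3, 256, 256) * 255,  # only used when hdr_train=True
    #         'fg_mask':   torch.ones(1, 16, 1, 512, 512),        # only used when hdr_train=True
    #     }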
    def training_step(self, batch, batch_idx):  # NOTE: only the motion layers are trained
        processed_batch = self.process_batch(batch, mode='train')  # raw batch: (1, 16, 3, 256, 256) frames, only normalized
        diffused_input = processed_batch['diffused_input']  # (1, 16, 4, 64, 64), edited frames after VAE encode and scheduler noising
        condition = processed_batch['condition']  # (1, 16, 8, 64, 64), VAE-encoded source frames, no scale factor
        target = processed_batch['target']  # (1, 16, 4, 64, 64), the Gaussian noise that was added
        t = processed_batch['t']  # e.g. tensor([257]); a random timestep in [0, 1000)
        model_kwargs = processed_batch['model_kwargs']  # dict with encoder_hidden_states, e.g. [1, 77, 768] text hidden states
        model_input = torch.cat([diffused_input, condition], dim=2)  # (b, f, c, h, w) = (1, 16, 12, 64, 64); channel-wise concat, the classic editing-model trick
        # half precision
        # model_input = model_input.float()
        # model_kwargs['encoder_hidden_states'] = model_kwargs['encoder_hidden_states'].half()
        model_input = rearrange(model_input, 'b f c h w -> b c f h w')  # (1, 12, 16, 64, 64)
        pred = self.unet(model_input, t, **model_kwargs).sample  # (1, 4, 16, 64, 64)
        pred = rearrange(pred, 'b c f h w -> b f c h w')  # (1, 16, 4, 64, 64)
        if not self.hdr_train:
            loss = self.get_loss(pred, target, t)  # e.g. 0.320
        else:
            fg_mask = model_kwargs['encoder_hidden_states']['fg_mask']
            loss = self.get_hdr_loss(fg_mask, pred, target)
        ### consistency loss (disabled) ###
        # TODO: three copies of model_input with different model_kwargs (stack them together; the attention logic also needs changes)
        # if self.hdr_train:
        #     fg_mask = model_kwargs['encoder_hidden_states']['fg_mask']
        #     hdr_latents = model_kwargs['encoder_hidden_states']['hdr_latents_1']
        #     hdr_latents_1 = model_kwargs['encoder_hidden_states']['hdr_latents_1']
        #     hdr_latents_2 = model_kwargs['encoder_hidden_states']['hdr_latents_2']
        #     model_input = torch.cat([diffused_input, condition], dim=2)
        #     model_input = rearrange(model_input, 'b f c h w -> b c f h w')
        #     model_input_1 = model_input.clone()
        #     model_input_2 = model_input.clone()
        #     model_input_all = torch.cat([model_input, model_input_1, model_input_2], dim=0)
        #     prompt = model_kwargs['encoder_hidden_states']['encoder_hidden_states']  # (1*n, 77, 768)
        #     prompt_all = torch.cat([prompt, prompt, prompt], dim=0)  # (3*n, 77, 768)
        #     model_kwargs['encoder_hidden_states']['encoder_hidden_states'] = prompt_all
        #     hdr_latents_all = torch.cat([hdr_latents, hdr_latents_1, hdr_latents_2], dim=0)  # (3*n, 16, 77, 768)
        #     model_kwargs['encoder_hidden_states']['hdr_latents'] = hdr_latents_all
        #     pred_all = self.unet(model_input_all, t, **model_kwargs).sample  # (1, 4, 16, 64, 64)
        #     pred_all = rearrange(pred_all, 'b c f h w -> b f c h w')
        #     pred, pred1, pred2 = pred_all.chunk(3, dim=0)
        #     loss_ori = self.get_hdr_loss(fg_mask, pred, target)
        #     # assume L1 and L2 have been obtained
        #     # hdr_latents_1 = mask(hdr_latents)  # build a random mask + fix the logic
        #     # model_kwargs['encoder_hidden_states']['hdr_latents'] = hdr_latents_1
        #     # pred1 = self.unet(model_input, t, **model_kwargs).sample  # prediction under L1, (1, 16, 4, 64, 64)
        #     # pred1 = rearrange(pred1, 'b c f h w -> b f c h w')
        #     # model_input = torch.cat([diffused_input, condition], dim=2)
        #     # model_input = rearrange(model_input, 'b f c h w -> b c f h w')
        #     # # hdr_latents_2 = 1 - mask(hdr_latents)
        #     # model_kwargs['encoder_hidden_states']['hdr_latents'] = hdr_latents_2
        #     # pred2 = self.unet(model_input, t, **model_kwargs).sample  # prediction under L2
        #     # pred2 = rearrange(pred2, 'b c f h w -> b f c h w')
        #     pred_combine = self.mlp(pred1, pred2)  # TODO: the MLP loss here is wrong; fix it, the inputs need to be flattened correctly
        #     loss_c = self.get_hdr_loss(fg_mask, pred, pred_combine)
        #     # loss_c = MSELoss(mask*pred, mask*pred_combine)  # TODO: move into a function and fix the logic
        #     loss = loss_ori + self.hdr_rate * loss_c  # weighting factor to control the contribution
        ### end ###
        self.log('train_loss', loss, sync_dist=True)
        latent_pred = self.predict_x_0_from_x_t(pred, t, diffused_input)  # (1, 16, 4, 64, 64)
        image_pred = self.decode_latent_to_image(latent_pred)  # the predicted x_0, (1, 16, 3, 256, 256)
        drop_out = torch.all(condition == 0).item()
        res_dict = {
            'loss': loss,
            'pred': image_pred,
            'drop_out': drop_out,
            'time': t[0].item()
        }
        return res_dict
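
    # Channel-layout sketch of the UNet input (illustration only): 4 noisy target-latent channels
    # plus 4 foreground-latent and 4 background-latent condition channels, concatenated to 12:
    #     diffused_input = torch.randn(1, 16, 4, 64, 64)
    #     condition = torch.randn(1, 16, 8, 64, 64)
    #     model_input = torch.cat([diffused_input, condition], dim=2)      # (1, 16, 12, 64, 64)
    #     model_input = rearrange(model_input, 'b f c h w -> b c f h w')   # (1, 12, 16, 64, 64)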
    def validation_step(self, batch, batch_idx):  # not fully finished; can be skipped for now
        if not self.hdr_train:
            from .inference.inference import InferenceIP2PVideo
            inf_pipe = InferenceIP2PVideo(
                self.unet,
                beta_start=self.scheduler.config.beta_start,
                beta_end=self.scheduler.config.beta_end,
                beta_schedule=self.scheduler.config.beta_schedule,
                num_ddim_steps=20
            )
            processed_batch = self.process_batch(batch, mode='val')
            diffused_input = torch.randn_like(processed_batch['diffused_input'])  # (1, 16, 4, 64, 64)
            condition = processed_batch['condition']  # the condition hook, (1, 16, 8, 64, 64)
            img_cond = condition
            text_cond = processed_batch['model_kwargs']['encoder_hidden_states']
            res = inf_pipe(
                latent=diffused_input,
                text_cond=text_cond,
                text_uncond=self.encode_text(['']),
                img_cond=img_cond,
                text_cfg=self.text_cfg,
                img_cfg=self.img_cfg,
                hdr_cfg=self.hdr_cfg
            )
            latent_pred = res['latent']
            image_pred = self.decode_latent_to_image(latent_pred)
            res_dict = {
                'pred': image_pred,
            }
        else:
            from .inference.inference import InferenceIP2PVideoHDR
            inf_pipe = InferenceIP2PVideoHDR(
                self.unet,
                beta_start=self.scheduler.config.beta_start,
                beta_end=self.scheduler.config.beta_end,
                beta_schedule=self.scheduler.config.beta_schedule,
                num_ddim_steps=20
            )
            processed_batch = self.process_batch(batch, mode='val')
            diffused_input = torch.randn_like(processed_batch['diffused_input'])  # (1, 16, 4, 64, 64)
            condition = processed_batch['condition']  # the condition hook, (1, 16, 8, 64, 64)
            model_kwargs = processed_batch['model_kwargs']
            img_cond = condition
            text_cond = model_kwargs['encoder_hidden_states']['encoder_hidden_states']
            hdr_cond = model_kwargs['encoder_hidden_states']['hdr_latents']
            res = inf_pipe(
                latent=diffused_input,
                text_cond=text_cond,
                text_uncond=self.encode_text(['']),
                hdr_cond=hdr_cond,
                img_cond=img_cond,
                text_cfg=self.text_cfg,
                img_cfg=self.img_cfg,
            )
            latent_pred = res['latent']
            image_pred = self.decode_latent_to_image(latent_pred)
            res_dict = {
                'pred': image_pred,
            }
        return res_dict
    def configure_optimizers(self):
        # optimizer = torch.optim.AdamW(self.unet.parameters(), lr=self.optim_args['lr'])
        import bitsandbytes as bnb
        params = []
        for name, p in self.unet.named_parameters():
            if ('transformer_in' in name) or ('temp_' in name):
                # p.requires_grad = True
                params.append(p)
            else:
                pass
                # p.requires_grad = False
        optimizer = bnb.optim.Adam8bit(params, lr=self.optim_args['lr'], betas=(0.9, 0.999))
        return optimizer
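
    # Optimizer sketch (assumption: bitsandbytes may not be installed in every environment):
    #     try:
    #         import bitsandbytes as bnb
    #         optimizer = bnb.optim.Adam8bit(params, lr=1e-4, betas=(0.9, 0.999))
    #     except ImportError:
    #         optimizer = torch.optim.AdamW(params, lr=1e-4, betas=(0.9, 0.999))  # full-precision fallback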
    def initialize_unet(self, unet_init_weights):
        if unet_init_weights is not None:
            print(f'INFO: initialize denoising UNet from {unet_init_weights}')
            sd = torch.load(unet_init_weights, map_location='cpu')
            model_sd = self.unet.state_dict()
            # fit input conv size
            for k in model_sd.keys():
                if k in sd.keys():
                    pass
                else:
                    # handling temporal layers
                    if (('temp_' in k) or ('transformer_in' in k)) and 'proj_out' in k:
                        # print(f'INFO: initialize {k} from {model_sd[k].shape} to zeros')
                        sd[k] = torch.zeros_like(model_sd[k])
                    else:
                        # print(f'INFO: initialize {k} from {model_sd[k].shape} to random')
                        sd[k] = model_sd[k]
            self.unet.load_state_dict(sd)
class InstructP2PVideoTrainerTemporal(InstructP2PVideoTrainer):
    def initialize_unet(self, unet_init_weights):  # compared with the parent class, this rewrites how the UNet weights are loaded
        if unet_init_weights is not None:
            print(f'INFO: initialize denoising UNet from {unet_init_weights}')
            sd_init_weights, motion_module_init_weights, iclight_init_weights = unet_init_weights
            os.makedirs('models', exist_ok=True)  # make sure the download directory exists
            sd_init_weights, motion_module_init_weights, iclight_init_weights = f'models/{sd_init_weights}', f'models/{motion_module_init_weights}', f'models/{iclight_init_weights}'
            if not os.path.exists(sd_init_weights):
                url = 'https://huggingface.co/stablediffusionapi/realistic-vision-v51/resolve/main/unet/diffusion_pytorch_model.safetensors'
                download_url_to_file(url=url, dst=sd_init_weights)
            if not os.path.exists(motion_module_init_weights):
                url = 'https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc.pth'
                download_url_to_file(url=url, dst=motion_module_init_weights)
            if not os.path.exists(iclight_init_weights):
                url = 'https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fbc.safetensors'
                download_url_to_file(url=url, dst=iclight_init_weights)
            sd = load_file(sd_init_weights)  # NOTE: loading the IC-Light UNet could later move into the yaml; only the UNet needs to change, the VAE and text encoder are basically the same
            # sd = torch.load(sd_init_weights, map_location='cpu')  # for debugging: inspect this and print the existing vs. loaded keys
            if self.unet.use_motion_module:
                motion_sd = torch.load(motion_module_init_weights, map_location='cpu')
                # this assert guarantees that at least the number of keys matches: self.unet is defined here,
                # while the two checkpoints were trained elsewhere (probably with diffusers)
                assert len(sd) + len(motion_sd) == len(self.unet.state_dict()), f'Improper state dict length, got {len(sd) + len(motion_sd)} expected {len(self.unet.state_dict())}'
                sd.update(motion_sd)
                for k, v in self.unet.state_dict().items():
                    if 'pos_encoder.pe' in k:  # from the original iv2v code; temporal_position_encoding_max_len is set to 32
                        sd[k] = v  # the size of pe may change, since the temporal layers depend on the input max_len
                    # if 'conv_in.weight' in k:  # temporary test
                    #     sd[k] = v
            else:
                assert len(sd) == len(self.unet.state_dict())
            self.unet.load_state_dict(sd)
            # TODO: change conv_in.weight to 12 input channels; change forward to support multiple input conditions; load the IC-Light sd_offset
            unet = self.unet  # keep a handle
            # widen conv_in; note this has to be the 3D (inflated) version of the conv
            with torch.no_grad():
                # new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
                new_conv_in = InflatedConv3d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
                new_conv_in.weight.zero_()
                new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
                new_conv_in.bias = unet.conv_in.bias
                unet.conv_in = new_conv_in
            ###### -- hooked forward (disabled) -- ######
            # this would rewire unet.forward; the call site after main would also need to change
            # unet_original_forward = unet.forward
            # def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
            #     c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)  # (1, 8, 67, 120)
            #     c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)  # (2, 8, 67, 120), duplicated for CFG
            #     new_sample = torch.cat([sample, c_concat], dim=2)  # (2, 12, 67, 120), still a channel-wise concat; change to cat on dim 2 for (2, 1, 12, 67, 120)
            #     # TODO: a frame dimension could be added here (b, c, f, h, w); alternatively change the data itself, in which case the concat above also changes
            #     # new_sample = new_sample.unsqueeze(2)  # (2, 12, 1, 67, 120); better to fix the shape before the forward (it depends on the input), so the concat above changes slightly too
            #     new_sample = rearrange(new_sample, 'b f c h w -> b c f h w')
            #     kwargs['cross_attention_kwargs'] = {}
            #     # return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
            #     result = unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
            #     # return (result[0].squeeze(2),)
            #     return (rearrange(result[0], 'b c f h w -> b f c h w'),)
            # unet.forward = hooked_unet_forward
            ###### -- hooked forward (disabled) -- ######
            # load the IC-Light offset weights (applied as an additive, LoRA-style delta)
            sd_offset = load_file(iclight_init_weights)
            sd_origin = unet.state_dict()
            keys = sd_origin.keys()
            for k in sd_offset.keys():
                sd_origin[k] = sd_origin[k] + sd_offset[k]
            # sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
            self.unet.load_state_dict(sd_origin, strict=True)
            del sd_offset, sd_origin, unet, keys
            # debug: dump parameter keys
            # with open('logs/sd_keys.txt', 'w') as f:
            #     f.write("SD Keys:\n")
            #     for key in sd_ori.keys():
            #         f.write(f"{key}\n")
            # unet_state_dict = self.unet.state_dict()
            # with open('logs/unet_state_dict_keys.txt', 'w') as f:
            #     f.write("UNet State Dict Keys:\n")
            #     for key in unet_state_dict.keys():
            #         f.write(f"{key}\n")
        else:
            with torch.no_grad():
                new_conv_in = InflatedConv3d(12, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
                self.unet.conv_in = new_conv_in
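
    # conv_in inflation sketch (illustration only): the pretrained 4-channel input conv is widened
    # to 12 channels (noisy latent + fg latent + bg latent); the extra channels start at zero so the
    # model initially behaves like the pretrained one:
    #     old = unet.conv_in
    #     new = InflatedConv3d(12, old.out_channels, old.kernel_size, old.stride, old.padding)
    #     new.weight.zero_(); new.weight[:, :4].copy_(old.weight); new.bias = old.bias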
    def configure_optimizers(self):  # selects only the motion-module (and attention) parameters for training; a pl.Trainer-specific hook
        import bitsandbytes as bnb
        motion_params = []
        remaining_params = []
        train_names = []  # for debug
        for name, p in self.unet.named_parameters():
            if ('motion' in name):  # this is where the trainable parameters are chosen: only motion-related weights
                motion_params.append(p)
                train_names.append(name)
            elif ('attentions' in name):
                motion_params.append(p)
                train_names.append(name)
            else:
                remaining_params.append(p)
        optimizer = bnb.optim.Adam8bit([
            {'params': motion_params, 'lr': self.optim_args['lr']},
        ], betas=(0.9, 0.999))
        return optimizer
class InstructP2PVideoTrainerTemporalText(InstructP2PVideoTrainerTemporal):
    def initialize_unet(self, unet_init_weights):  # compared with the parent class, this rewrites how the UNet weights are loaded
        if unet_init_weights is not None:
            print(f'INFO: initialize denoising UNet from {unet_init_weights}')
            sd_init_weights, motion_module_init_weights, iclight_init_weights = unet_init_weights
            if self.base_path:
                sd_init_weights = f"{self.base_path}/{sd_init_weights}"
            if '.safetensors' in sd_init_weights:  # .safetensors checkpoint
                sd = load_file(sd_init_weights)  # NOTE: loading the IC-Light UNet could later move into the yaml; only the UNet needs to change, the VAE and text encoder are basically the same
            else:  # '.ckpt' checkpoint
                sd = torch.load(sd_init_weights, map_location='cpu')
            # sd = torch.load(sd_init_weights, map_location='cpu')  # for debugging: inspect this and print the existing vs. loaded keys
            if self.unet.use_motion_module:
                motion_sd = torch.load(motion_module_init_weights, map_location='cpu')
                # this assert guarantees that at least the number of keys matches: self.unet is defined here,
                # while the two checkpoints were trained elsewhere (probably with diffusers)
                assert len(sd) + len(motion_sd) == len(self.unet.state_dict()), f'Improper state dict length, got {len(sd) + len(motion_sd)} expected {len(self.unet.state_dict())}'
                sd.update(motion_sd)
                for k, v in self.unet.state_dict().items():
                    if 'pos_encoder.pe' in k:  # from the original iv2v code; temporal_position_encoding_max_len is set to 32
                        sd[k] = v  # the size of pe may change, since the temporal layers depend on the input max_len
                    # if 'conv_in.weight' in k:  # temporary test
                    #     sd[k] = v
            else:
                assert len(sd) == len(self.unet.state_dict())
            self.unet.load_state_dict(sd)
            # TODO: change conv_in.weight to 8 input channels; change forward to support multiple input conditions; load the IC-Light sd_offset
            unet = self.unet  # keep a handle
            # widen conv_in; note this has to be the 3D (inflated) version of the conv
            with torch.no_grad():
                # new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
                new_conv_in = InflatedConv3d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
                new_conv_in.weight.zero_()
                new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
                new_conv_in.bias = unet.conv_in.bias
                unet.conv_in = new_conv_in
            # load the IC-Light offset weights (applied as an additive, LoRA-style delta)
            sd_offset = load_file(iclight_init_weights)
            sd_origin = unet.state_dict()
            keys = sd_origin.keys()
            for k in sd_offset.keys():
                sd_origin[k] = sd_origin[k] + sd_offset[k]
            # sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
            self.unet.load_state_dict(sd_origin, strict=True)
            del sd_offset, sd_origin, unet, keys
        else:
            with torch.no_grad():
                new_conv_in = InflatedConv3d(8, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
                self.unet.conv_in = new_conv_in
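
    # conv_in sketch for this text-conditioned variant (illustration only): here only one condition
    # latent is concatenated with the noisy latent, so conv_in takes 4 + 4 = 8 channels instead of 12:
    #     model_input = torch.cat([diffused_input, cond_latent], dim=2)   # (1, 16, 8, 64, 64)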