import os
import math
from typing import Union, Optional

import cv2
import imageio
import moviepy.editor as mpy
import numpy as np
import torch
import torchvision
import torch.distributed as dist
from tqdm import tqdm
from einops import rearrange
from PIL import Image
# We recommend using the following affinity scores (motion magnitudes).
# You are also encouraged to construct your own scores.
RANGE_LIST = [
    [1.0, 0.9, 0.85, 0.85, 0.85, 0.8],                                                          # 0 Small Motion
    [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75],                                                     # 1 Moderate Motion
    [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5],                          # 2 Large Motion
    [1.0, 0.9, 0.85, 0.85, 0.85, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.85, 0.85, 0.9, 1.0],      # 3 Loop
    [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8, 0.8, 1.0],  # 4 Loop
    [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5, 0.6, 0.7, 0.7, 0.7, 0.7, 0.8, 1.0],           # 5 Loop
    [0.5, 0.2],                                                                                 # 6 Style Transfer Large Motion
    [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2],                                           # 7 Style Transfer Moderate Motion
    [0.5, 0.4, 0.4, 0.4, 0.35, 0.3],                                                            # 8 Style Transfer Candidate Small Motion
]
def zero_rank_print(s):
    # Print only on rank 0 when torch.distributed is initialized.
    if (not dist.is_initialized()) or dist.get_rank() == 0:
        print("### " + s)
def save_videos_mp4(video: torch.Tensor, path: str, fps: int = 8):
    video = rearrange(video, "b c t h w -> t b c h w")
    num_frames, batch_size, channels, height, width = video.shape
    assert batch_size == 1, 'Only support batch size == 1'
    video = video.squeeze(1)
    video = rearrange(video, "t c h w -> t h w c")

    def make_frame(t):
        # Map the clip timestamp (seconds) to a frame index; clamp so sampling
        # at t == duration cannot index past the last frame.
        frame_index = min(int(t * fps), num_frames - 1)
        frame_tensor = video[frame_index]
        frame_np = (frame_tensor * 255).cpu().numpy().astype('uint8')
        return frame_np

    clip = mpy.VideoClip(make_frame, duration=num_frames / fps)
    clip.write_videofile(path, fps=fps, codec='libx264')
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        x = torch.clamp(x * 255, 0, 255).cpu().numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)
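
# A minimal usage sketch for the two savers above. The tensor shape and output
# paths are illustrative only: both functions expect a (b, c, t, h, w) video
# with values in [0, 1] (or [-1, 1] with rescale=True for save_videos_grid).
def _example_save_videos():
    dummy = torch.rand(1, 3, 16, 64, 64)  # one 16-frame RGB video, 64x64
    save_videos_grid(dummy, "samples/demo.gif", fps=8)
    save_videos_mp4(dummy, "samples/demo.mp4", fps=8)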
# DDIM Inversion
def init_prompt(prompt, pipeline):
    uncond_input = pipeline.tokenizer(
        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt"
    )
    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
    text_input = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
    context = torch.cat([uncond_embeddings, text_embeddings])
    return context
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    # One reversed DDIM step: predict x_0 from the current sample, then move it
    # forward to the next (noisier) timestep.
    timestep, next_timestep = min(
        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample
def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred
@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    for i in tqdm(range(num_inv_steps)):
        # Walk the scheduler's timesteps in reverse (clean -> noisy).
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent
@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
    return ddim_latents
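
# Hedged sketch of running the inversion. `pipeline`, `ddim_scheduler`, and
# `latents` are assumed to come from a diffusers-style setup; nothing here is
# specific to this file beyond ddim_inversion itself.
def _example_ddim_inversion(pipeline, ddim_scheduler, latents, steps=50):
    ddim_scheduler.set_timesteps(steps)
    all_latents = ddim_inversion(pipeline, ddim_scheduler, latents, steps, prompt="")
    # all_latents[0] is the input latent; all_latents[-1] is the fully noised one.
    return all_latents[-1]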
def prepare_mask_coef(video_length: int, cond_frame: int, sim_range: list = [0.2, 1.0]):
    assert len(sim_range) == 2, \
        'sim_range should have length 2, containing the min and max similarity'
    assert video_length > 1, \
        'video_length should be greater than 1'
    assert video_length > cond_frame, \
        'video_length should be greater than cond_frame'

    diff = abs(sim_range[0] - sim_range[1]) / (video_length - 1)
    coef = [1.0] * video_length
    for f in range(video_length):
        f_diff = diff * abs(cond_frame - f)
        f_diff = 1 - f_diff
        coef[f] *= f_diff

    return coef
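
# Worked example: with video_length=5, cond_frame=0, and the default
# sim_range=[0.2, 1.0], diff = 0.8 / 4 = 0.2, so the coefficients decay
# linearly with distance from the conditioning frame:
#   prepare_mask_coef(5, 0)  ->  approximately [1.0, 0.8, 0.6, 0.4, 0.2]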
def prepare_mask_coef_by_statistics(video_length: int, cond_frame: int, sim_range: int):
    assert video_length > 0, \
        'video_length should be greater than 0'
    assert video_length > cond_frame, \
        'video_length should be greater than cond_frame'

    range_list = RANGE_LIST
    assert sim_range < len(range_list), \
        f'sim_range {sim_range} is not implemented'

    coef = range_list[sim_range]
    # Extend the preset with its last value so it covers the whole video.
    coef = coef + [coef[-1]] * (video_length - len(coef))
    # Each frame takes the coefficient matching its distance to the conditioning frame.
    order = [abs(i - cond_frame) for i in range(video_length)]
    coef = [coef[order[i]] for i in range(video_length)]

    return coef
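
# Worked example: sim_range=0 selects the "Small Motion" preset, so with
# cond_frame=0 a 16-frame video gets RANGE_LIST[0] padded with its last value:
#   prepare_mask_coef_by_statistics(16, 0, 0)
#   -> [1.0, 0.9, 0.85, 0.85, 0.85, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]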
def prepare_mask_coef_multi_cond(video_length: int, cond_frames: list, sim_range: list = [0.2, 1.0]):
    assert len(sim_range) == 2, \
        'sim_range should have length 2, containing the min and max similarity'
    assert video_length > 1, \
        'video_length should be greater than 1'
    assert isinstance(cond_frames, list), \
        'cond_frames should be a list'
    assert video_length > max(cond_frames), \
        'video_length should be greater than cond_frame'

    if max(sim_range) == min(sim_range):
        cond_coefs = [sim_range[0]] * video_length
        return cond_coefs

    cond_coefs = []
    for cond_frame in cond_frames:
        cond_coef = prepare_mask_coef(video_length, cond_frame, sim_range)
        cond_coefs.append(cond_coef)

    mixed_coef = [0] * video_length
    for conds in range(len(cond_frames)):
        # Fold each conditioning frame's coefficients into the mix, then
        # renormalize to [0, 1] after every fold past the first.
        for f in range(video_length):
            mixed_coef[f] = abs(cond_coefs[conds][f] - mixed_coef[f])
        if conds > 0:
            min_num = min(mixed_coef)
            max_num = max(mixed_coef)
            for f in range(video_length):
                mixed_coef[f] = (mixed_coef[f] - min_num) / (max_num - min_num)

    # Rescale to sim_range and clamp any numerical stragglers into the range.
    mixed_max = max(mixed_coef)
    mixed_min = min(mixed_coef)
    for f in range(video_length):
        mixed_coef[f] = (max(sim_range) - min(sim_range)) * (mixed_coef[f] - mixed_min) / (mixed_max - mixed_min) + min(sim_range)
    mixed_coef = [min(max(x, min(sim_range)), max(sim_range)) for x in mixed_coef]

    return mixed_coef
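
# Hedged usage sketch: condition a 16-frame video on its first and last frames.
# The exact values depend on the normalization above; this only shows the call.
#   coef = prepare_mask_coef_multi_cond(16, [0, 15], sim_range=[0.2, 1.0])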
def prepare_masked_latent_cond(video_length: int, cond_frames: list):
    for cond_frame in cond_frames:
        assert cond_frame < video_length, \
            'cond_frame should be smaller than video_length'
        assert cond_frame > -1, \
            f'cond_frame should be in the range of [0, {video_length})'

    cond_frames.sort()
    nearest = [cond_frames[0]] * video_length
    for f in range(video_length):
        for cond_frame in cond_frames:
            if abs(nearest[f] - f) > abs(cond_frame - f):
                nearest[f] = cond_frame

    masked_latent_cond = nearest
    return masked_latent_cond
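
# Worked example: each frame maps to its nearest conditioning frame, so
#   prepare_masked_latent_cond(8, [0, 7])  ->  [0, 0, 0, 0, 7, 7, 7, 7]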
def estimated_kernel_size(frame_width: int, frame_height: int) -> int:
    """Estimate kernel size based on video resolution."""
    # TODO: This equation is based on manual estimation from a few videos.
    # Create a more comprehensive test suite to optimize against.
    size: int = 4 + round(math.sqrt(frame_width * frame_height) / 192)
    if size % 2 == 0:
        size += 1
    return size
def detect_edges(lum: np.ndarray) -> np.ndarray:
    """Detect edges using the luma channel of a frame.

    Arguments:
        lum: 2D 8-bit image representing the luma channel of a frame.

    Returns:
        2D 8-bit image of the same size as the input, where pixels with values of 255
        represent edges, and all other pixels are 0.
    """
    # Initialize kernel.
    kernel_size = estimated_kernel_size(lum.shape[1], lum.shape[0])
    kernel = np.ones((kernel_size, kernel_size), np.uint8)

    # Estimate levels for thresholding.
    # TODO(0.6.3): Add config file entries for sigma, aperture/kernel size, etc.
    sigma: float = 1.0 / 3.0
    median = np.median(lum)
    low = int(max(0, (1.0 - sigma) * median))
    high = int(min(255, (1.0 + sigma) * median))

    # Calculate edges using the Canny algorithm, and reduce noise by dilating
    # the edges. This increases edge overlap, leading to improved robustness
    # against noise and slow camera movement. Note that very large kernel sizes
    # can negatively affect accuracy.
    edges = cv2.Canny(lum, low, high)
    return cv2.dilate(edges, kernel)
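
# Hedged sketch of one way to turn detect_edges into a motion estimate between
# two frames; the frame arguments and the scoring rule are illustrative, not
# part of this file's pipeline.
def _example_edge_motion(frame_a: np.ndarray, frame_b: np.ndarray) -> float:
    # Convert BGR frames to luma, detect edges, and score motion as the
    # fraction of pixels whose edge status changed between the two frames.
    lum_a = cv2.cvtColor(frame_a, cv2.COLOR_BGR2GRAY)
    lum_b = cv2.cvtColor(frame_b, cv2.COLOR_BGR2GRAY)
    edges_a = detect_edges(lum_a)
    edges_b = detect_edges(lum_b)
    changed = np.logical_xor(edges_a > 0, edges_b > 0)
    return float(changed.mean())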
def prepare_mask_coef_by_score(video_shape: list, cond_frame_idx: list, sim_range: list = [0.2, 1.0],
                               statistic: list = [1, 100], coef_max: float = 0.98,
                               score: Optional[torch.Tensor] = None):
    '''
    video_shape is [batch_size, video_length], i.e. (b, f) of the (b f c h w) video
    cond_frame_idx is a list with length batch_size
    the shape of statistic is (f 2), holding per-frame [min, max] score statistics
    the shape of score is (b f 1)
    the shape of the returned coef is (b f)
    '''
    assert len(video_shape) == 2, \
        f'video_shape should be [batch_size, video_length], but got {len(video_shape)} entries'
    assert score is not None, 'score must be provided'

    batch_size, frame_num = video_shape[0], video_shape[1]

    score = score.permute(0, 2, 1).squeeze(0)

    # list -> (b 1)
    cond_frame_mat = torch.tensor(cond_frame_idx).unsqueeze(-1)

    statistic = torch.tensor(statistic)
    # (f 2) -> (b f 2)
    statistic = statistic.repeat(batch_size, 1, 1)

    # shape of order: (b f); shape of cond_mat: (b f)
    order = torch.arange(0, frame_num, 1)
    order = order.repeat(batch_size, 1)
    cond_mat = torch.ones((batch_size, frame_num)) * cond_frame_mat
    order = abs(order - cond_mat)

    # Re-index the statistics by each frame's distance to its conditioning frame.
    statistic = statistic[:, order.to(torch.long)][0, :, :, :]

    # score: (b f); max_stats/min_stats: (b f)
    max_stats = torch.max(statistic, dim=2).values.to(dtype=score.dtype)
    min_stats = torch.min(statistic, dim=2).values.to(dtype=score.dtype)

    # Clip outlier scores into the per-frame statistical range.
    score[score > max_stats] = max_stats[score > max_stats] * 0.95
    score[score < min_stats] = min_stats[score < min_stats]

    eps = 1e-10
    coef = 1 - abs((score / (max_stats + eps)) * (max(sim_range) - min(sim_range)))

    # Conditioning frames always keep a coefficient of 1.
    indices = torch.arange(coef.shape[0]).unsqueeze(1)
    coef[indices, cond_frame_mat] = 1.0

    return coef
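
# Hedged call sketch (illustrative values; the statistics must match the
# documented (f, 2) shape, so the flat default of [1, 100] is replaced here):
#   score = torch.rand(1, 16, 1) * 50       # motion scores for one 16-frame video
#   stats = [[1, 100]] * 16                 # per-frame [min, max] statistics
#   coef = prepare_mask_coef_by_score([1, 16], [0], statistic=stats, score=score)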
def preprocess_img(img_path, max_size: int = 512):
    ori_image = Image.open(img_path).convert('RGB')

    # Downscale so the long edge is at most max_size, keeping the aspect ratio.
    width, height = ori_image.size
    long_edge = max(width, height)
    if long_edge > max_size:
        scale_factor = max_size / long_edge
    else:
        scale_factor = 1
    width = int(width * scale_factor)
    height = int(height * scale_factor)
    ori_image = ori_image.resize((width, height))

    # Round both sides down to multiples of 8.
    if (width % 8 != 0) or (height % 8 != 0):
        in_width = (width // 8) * 8
        in_height = (height // 8) * 8
    else:
        in_width = width
        in_height = height

    in_image = ori_image.resize((in_width, in_height))
    in_image_np = np.array(in_image)
    return in_image_np, in_height, in_width
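
# Hedged usage sketch; 'input.png' is an illustrative path. The returned array
# is RGB uint8 with both sides rounded down to multiples of 8 (convenient for
# a VAE with 8x spatial downsampling):
#   img_np, height, width = preprocess_img('input.png', max_size=512)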