import os
import argparse
import tempfile
import shutil
from typing import List, Optional, Tuple, Union

import numpy as np
import cv2
import yaml
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from accelerate.logging import get_logger

logger = get_logger(__name__)


def get_args():
    parser = argparse.ArgumentParser(description="Training script for CogVideoX using config file.")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the YAML config file.",
    )
    args = parser.parse_args()
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)
    # Convert the loaded config dict to an argparse.Namespace for
    # attribute-style access downstream.
    args = argparse.Namespace(**config)
    return args
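

# Sketch of the flat YAML layout this loader expects: every top-level key
# becomes an attribute on the returned Namespace (args.learning_rate, etc.).
# The keys and values below are illustrative placeholders, not defaults from
# the original project:
#
#   learning_rate: 1.0e-4
#   optimizer: adamw
#   use_8bit_adam: false
#   adam_beta1: 0.9
#   adam_beta2: 0.95
#   adam_epsilon: 1.0e-8
#   adam_weight_decay: 0.01
#
# Invoked as, e.g.: python train.py --config config.yaml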


def atomic_save(save_path, accelerator):
    parent = os.path.dirname(save_path)
    tmp_dir = tempfile.mkdtemp(dir=parent)
    backup_dir = save_path + "_backup"
    try:
        # Save state into the temp directory
        accelerator.save_state(tmp_dir)
        # Back up the existing save_path if it exists
        if os.path.exists(save_path):
            os.rename(save_path, backup_dir)
        # Atomically move the temp directory into place
        os.rename(tmp_dir, save_path)
        # Clean up the backup directory
        if os.path.exists(backup_dir):
            shutil.rmtree(backup_dir)
    except Exception:
        # Clean up the temp directory on failure
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        # Restore from the backup if the replacement failed
        if os.path.exists(backup_dir):
            if os.path.exists(save_path):
                shutil.rmtree(save_path)
            os.rename(backup_dir, save_path)
        raise
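

def _example_checkpoint_step(args, accelerator, global_step):
    # Illustrative wiring only: the `checkpoint-{step}` naming and the
    # `checkpointing_steps` / `output_dir` config keys are assumptions for
    # this sketch, not code from the original script. atomic_save() ensures
    # the directory is either the previous complete state or the new one,
    # never a half-written checkpoint.
    if global_step % args.checkpointing_steps == 0:
        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
        atomic_save(save_path, accelerator)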


def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    # Use the DeepSpeed optimizer
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )

    # Optimizer creation
    supported_optimizers = ["adam", "adamw", "prodigy"]
    if args.optimizer not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include "
            f"{supported_optimizers}. Defaulting to AdamW."
        )
        args.optimizer = "adamw"

    if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
        logger.warning(
            f"use_8bit_adam is ignored when the optimizer is not set to 'adam' or 'adamw'. Optimizer was "
            f"set to {args.optimizer.lower()}."
        )

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

    if args.optimizer.lower() == "adamw":
        optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "adam":
        optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam
        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")

        optimizer_class = prodigyopt.Prodigy

        if args.learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using Prodigy, it is generally better to set the learning rate around 1.0."
            )

        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            beta3=args.prodigy_beta3,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=args.prodigy_decouple,
            use_bias_correction=args.prodigy_use_bias_correction,
            safeguard_warmup=args.prodigy_safeguard_warmup,
        )

    return optimizer
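

def _example_build_optimizer(args, transformer):
    # Sketch of assumed wiring (not from the original script): collect the
    # trainable parameters and hand them to get_optimizer(). Note that the
    # adam/adamw branches above do not pass `lr`, so the learning rate is
    # expected to come from a per-group dict like this one.
    params_to_optimize = [
        {"params": [p for p in transformer.parameters() if p.requires_grad], "lr": args.learning_rate}
    ]
    return get_optimizer(args, params_to_optimize, use_deepspeed=False)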


def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
    base_size_height = base_height // (vae_scale_factor_spatial * patch_size)

    grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        crops_coords=grid_crops_coords,
        grid_size=(grid_height, grid_width),
        temporal_size=num_frames,
    )

    freqs_cos = freqs_cos.to(device=device)
    freqs_sin = freqs_sin.to(device=device)
    return freqs_cos, freqs_sin
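

def _example_rope_shapes():
    # Worked sketch: at the 480x720 base resolution with an 8x spatial VAE
    # downscale and patch size 2, the latent grid is
    # (480 // 16, 720 // 16) = (30, 45). The 13 latent frames used here are
    # an assumption for illustration (e.g. 49 pixel frames under 4x temporal
    # compression, plus the first frame); each returned cos/sin table then
    # covers 13 * 30 * 45 positions with attention_head_dim-sized rows.
    freqs_cos, freqs_sin = prepare_rotary_positional_embeddings(
        height=480, width=720, num_frames=13, device=torch.device("cpu")
    )
    return freqs_cos, freqs_sin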


def _get_t5_prompt_embeds(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # Duplicate text embeddings for each generation per prompt, using an mps-friendly method
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds


def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt_embeds = _get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
        text_input_ids=text_input_ids,
    )
    return prompt_embeds


def compute_prompt_embeddings(
    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
):
    if requires_grad:
        prompt_embeds = encode_prompt(
            tokenizer,
            text_encoder,
            prompt,
            num_videos_per_prompt=1,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )
    else:
        with torch.no_grad():
            prompt_embeds = encode_prompt(
                tokenizer,
                text_encoder,
                prompt,
                num_videos_per_prompt=1,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )
    return prompt_embeds
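

def _example_prompt_embeddings(tokenizer, text_encoder, device):
    # Sketch: embed a single training caption with gradients disabled. The
    # prompt text and bf16 dtype are placeholders, not values from the
    # original script. The result has shape
    # (batch_size * num_videos_per_prompt, max_sequence_length, hidden_dim),
    # here (1, 226, hidden_dim).
    return compute_prompt_embeddings(
        tokenizer,
        text_encoder,
        prompt="a timelapse of clouds drifting over a mountain lake",
        max_sequence_length=226,
        device=device,
        dtype=torch.bfloat16,
        requires_grad=False,
    )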


def save_frames_as_pngs(
    video_array,
    output_dir,
    downsample_spatial=1,   # e.g. 2 to halve width & height
    downsample_temporal=1,  # e.g. 2 to keep every 2nd frame
):
    """
    Save each frame of a (T, H, W, C) uint8 numpy array as a PNG with no compression.
    """
    assert video_array.ndim == 4 and video_array.shape[-1] == 3, "Expected (T, H, W, C=3) array"
    assert video_array.dtype == np.uint8, "Expected uint8 array"
    os.makedirs(output_dir, exist_ok=True)

    # Temporal downsample
    frames = video_array[::downsample_temporal]

    # Compute the spatially downsampled size
    _, H, W, _ = frames.shape
    new_size = (W // downsample_spatial, H // downsample_spatial)

    # PNG compression param: 0 = no compression
    png_params = [cv2.IMWRITE_PNG_COMPRESSION, 0]

    for idx, frame in enumerate(frames):
        # Frames are RGB; convert to BGR for OpenCV (cvtColor also yields a
        # contiguous array, unlike a negative-stride [..., ::-1] view)
        bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        if downsample_spatial > 1:
            bgr = cv2.resize(bgr, new_size, interpolation=cv2.INTER_NEAREST)
        filename = os.path.join(output_dir, "frame_{:05d}.png".format(idx))
        success = cv2.imwrite(filename, bgr, png_params)
        if not success:
            raise RuntimeError(f"Failed to write frame: {filename}")
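

def _example_dump_video(video: np.ndarray) -> None:
    # Sketch: write a decoded clip to disk at half resolution, keeping every
    # other frame. The output directory name is a placeholder.
    save_frames_as_pngs(
        video,  # (T, H, W, 3) uint8, RGB
        output_dir="debug_frames",
        downsample_spatial=2,   # halve H and W
        downsample_temporal=2,  # keep every 2nd frame
    )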