import os
import sys
import time
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
sys.path.append(os.path.dirname(__file__))

import numpy as np
import PIL.Image
import torch
from diffusers import LCMScheduler, StableDiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    retrieve_latents,
)

from streamdiffusion.image_filter import SimilarImageFilter

class StreamDiffusion:
    """Streaming wrapper around a StableDiffusionPipeline that denoises with an
    LCM scheduler and, when `use_denoising_batch` is enabled, stacks the
    denoising steps of consecutive frames into a single UNet call per frame."""

    def __init__(
        self,
        pipe: StableDiffusionPipeline,
        t_index_list: List[int],
        torch_dtype: torch.dtype = torch.float16,
        width: int = 512,
        height: int = 512,
        do_add_noise: bool = True,
        use_denoising_batch: bool = True,
        frame_buffer_size: int = 1,
        cfg_type: Literal["none", "full", "self", "initialize"] = "self",
    ) -> None:
        self.device = pipe.device
        self.dtype = torch_dtype
        self.generator = None

        self.height = height
        self.width = width

        self.latent_height = int(height // pipe.vae_scale_factor)
        self.latent_width = int(width // pipe.vae_scale_factor)

        self.frame_bff_size = frame_buffer_size
        self.denoising_steps_num = len(t_index_list)

        self.cfg_type = cfg_type

        if use_denoising_batch:
            self.batch_size = self.denoising_steps_num * frame_buffer_size
            if self.cfg_type == "initialize":
                self.trt_unet_batch_size = (
                    self.denoising_steps_num + 1
                ) * self.frame_bff_size
            elif self.cfg_type == "full":
                self.trt_unet_batch_size = (
                    2 * self.denoising_steps_num * self.frame_bff_size
                )
            else:
                self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size
        else:
            self.trt_unet_batch_size = self.frame_bff_size
            self.batch_size = frame_buffer_size

        self.t_list = t_index_list

        self.do_add_noise = do_add_noise
        self.use_denoising_batch = use_denoising_batch

        self.similar_image_filter = False
        self.similar_filter = SimilarImageFilter()
        self.prev_image_result = None

        self.pipe = pipe
        self.image_processor = VaeImageProcessor(pipe.vae_scale_factor)

        self.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet
        self.vae = pipe.vae

        self.inference_time_ema = 0

    def load_lcm_lora(
        self,
        pretrained_model_name_or_path_or_dict: Union[
            str, Dict[str, torch.Tensor]
        ] = "latent-consistency/lcm-lora-sdv1-5",
        adapter_name: Optional[Any] = None,
        **kwargs,
    ) -> None:
        self.pipe.load_lora_weights(
            pretrained_model_name_or_path_or_dict, adapter_name, **kwargs
        )

    def load_lora(
        self,
        pretrained_lora_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[Any] = None,
        **kwargs,
    ) -> None:
        self.pipe.load_lora_weights(
            pretrained_lora_model_name_or_path_or_dict, adapter_name, **kwargs
        )

    def fuse_lora(
        self,
        fuse_unet: bool = True,
        fuse_text_encoder: bool = True,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
    ) -> None:
        self.pipe.fuse_lora(
            fuse_unet=fuse_unet,
            fuse_text_encoder=fuse_text_encoder,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
        )

    def enable_similar_image_filter(
        self, threshold: float = 0.98, max_skip_frame: float = 10
    ) -> None:
        self.similar_image_filter = True
        self.similar_filter.set_threshold(threshold)
        self.similar_filter.set_max_skip_frame(max_skip_frame)

    def disable_similar_image_filter(self) -> None:
        self.similar_image_filter = False

    def prepare(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 1.2,
        delta: float = 1.0,
        generator: Optional[torch.Generator] = None,
        seed: int = 2,
    ) -> None:
        if generator is None:
            generator = torch.Generator()
        self.generator = generator
        self.generator.manual_seed(seed)

        # initialize x_t_latent (it can be any random tensor)
        if self.denoising_steps_num > 1:
            self.x_t_latent_buffer = torch.zeros(
                (
                    (self.denoising_steps_num - 1) * self.frame_bff_size,
                    4,
                    self.latent_height,
                    self.latent_width,
                ),
                dtype=self.dtype,
                device=self.device,
            )
        else:
            self.x_t_latent_buffer = None

        if self.cfg_type == "none":
            self.guidance_scale = 1.0
        else:
            self.guidance_scale = guidance_scale
        self.delta = delta

        do_classifier_free_guidance = False
        if self.guidance_scale > 1.0:
            do_classifier_free_guidance = True

        encoder_output = self.pipe.encode_prompt(
            prompt=prompt,
            device=self.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
        )
        self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)

        if self.use_denoising_batch and self.cfg_type == "full":
            uncond_prompt_embeds = encoder_output[1].repeat(self.batch_size, 1, 1)
        elif self.cfg_type == "initialize":
            uncond_prompt_embeds = encoder_output[1].repeat(self.frame_bff_size, 1, 1)

        if self.guidance_scale > 1.0 and (
            self.cfg_type == "initialize" or self.cfg_type == "full"
        ):
            self.prompt_embeds = torch.cat(
                [uncond_prompt_embeds, self.prompt_embeds], dim=0
            )

        self.scheduler.set_timesteps(num_inference_steps, self.device)
        self.timesteps = self.scheduler.timesteps.to(self.device)

        # build the sub-timestep schedule by indexing the full timestep list with t_list
        self.sub_timesteps = []
        for t in self.t_list:
            self.sub_timesteps.append(self.timesteps[t])

        sub_timesteps_tensor = torch.tensor(
            self.sub_timesteps, dtype=torch.long, device=self.device
        )
        self.sub_timesteps_tensor = torch.repeat_interleave(
            sub_timesteps_tensor,
            repeats=self.frame_bff_size if self.use_denoising_batch else 1,
            dim=0,
        )

        self.init_noise = torch.randn(
            (self.batch_size, 4, self.latent_height, self.latent_width),
            generator=generator,
        ).to(device=self.device, dtype=self.dtype)

        self.stock_noise = torch.zeros_like(self.init_noise)

        # precompute the LCM boundary-condition scalings for each sub-timestep
        c_skip_list = []
        c_out_list = []
        for timestep in self.sub_timesteps:
            c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(
                timestep
            )
            c_skip_list.append(c_skip)
            c_out_list.append(c_out)

        self.c_skip = (
            torch.stack(c_skip_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )
        self.c_out = (
            torch.stack(c_out_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )

        # precompute sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t) for each sub-timestep
        alpha_prod_t_sqrt_list = []
        beta_prod_t_sqrt_list = []
        for timestep in self.sub_timesteps:
            alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt()
            beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt()
            alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt)
            beta_prod_t_sqrt_list.append(beta_prod_t_sqrt)

        alpha_prod_t_sqrt = (
            torch.stack(alpha_prod_t_sqrt_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )
        beta_prod_t_sqrt = (
            torch.stack(beta_prod_t_sqrt_list)
            .view(len(self.t_list), 1, 1, 1)
            .to(dtype=self.dtype, device=self.device)
        )
        self.alpha_prod_t_sqrt = torch.repeat_interleave(
            alpha_prod_t_sqrt,
            repeats=self.frame_bff_size if self.use_denoising_batch else 1,
            dim=0,
        )
        self.beta_prod_t_sqrt = torch.repeat_interleave(
            beta_prod_t_sqrt,
            repeats=self.frame_bff_size if self.use_denoising_batch else 1,
            dim=0,
        )

    def update_prompt(self, prompt: str) -> None:
        encoder_output = self.pipe.encode_prompt(
            prompt=prompt,
            device=self.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
        )
        self.prompt_embeds = encoder_output[0].repeat(self.batch_size, 1, 1)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        t_index: int,
    ) -> torch.Tensor:
        noisy_samples = (
            self.alpha_prod_t_sqrt[t_index] * original_samples
            + self.beta_prod_t_sqrt[t_index] * noise
        )
        return noisy_samples

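    # The helper below applies the LCM consistency-model update in closed form:
    # given a noise prediction eps at timestep t, it first recovers the predicted
    # clean latent
    #     F_theta = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
    # and then applies the boundary-condition scalings precomputed in prepare():
    #     x_0 = c_out * F_theta + c_skip * x_t
    # With idx=None the per-step scalings broadcast row-wise over the whole
    # denoising batch; with an integer idx a single step's scalings are used.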
    def scheduler_step_batch(
        self,
        model_pred_batch: torch.Tensor,
        x_t_latent_batch: torch.Tensor,
        idx: Optional[int] = None,
    ) -> torch.Tensor:
        # TODO: use t_list to select beta_prod_t_sqrt
        if idx is None:
            F_theta = (
                x_t_latent_batch - self.beta_prod_t_sqrt * model_pred_batch
            ) / self.alpha_prod_t_sqrt
            denoised_batch = self.c_out * F_theta + self.c_skip * x_t_latent_batch
        else:
            F_theta = (
                x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch
            ) / self.alpha_prod_t_sqrt[idx]
            denoised_batch = (
                self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch
            )
        return denoised_batch

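    # Note on the "self" / "initialize" cfg types handled below: instead of running
    # a second unconditional UNet pass at every step (as "full" CFG does), the class
    # keeps a running noise buffer (`stock_noise`) and uses `stock_noise * delta` as
    # the unconditional prediction. With "initialize", one extra unconditional row
    # (the first frame) is evaluated and its output re-seeds the buffer. This appears
    # to correspond to the residual-CFG idea used by StreamDiffusion.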
    def unet_step(
        self,
        x_t_latent: torch.Tensor,
        t_list: Union[torch.Tensor, List[int]],
        idx: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
            x_t_latent_plus_uc = torch.concat([x_t_latent[0:1], x_t_latent], dim=0)
            t_list = torch.concat([t_list[0:1], t_list], dim=0)
        elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
            x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0)
            t_list = torch.concat([t_list, t_list], dim=0)
        else:
            x_t_latent_plus_uc = x_t_latent

        model_pred = self.unet(
            x_t_latent_plus_uc,
            t_list,
            encoder_hidden_states=self.prompt_embeds,
            return_dict=False,
        )[0]

        if self.guidance_scale > 1.0 and (self.cfg_type == "initialize"):
            noise_pred_text = model_pred[1:]
            self.stock_noise = torch.concat(
                [model_pred[0:1], self.stock_noise[1:]], dim=0
            )  # commenting this line out falls back to self-CFG behavior
        elif self.guidance_scale > 1.0 and (self.cfg_type == "full"):
            noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
        else:
            noise_pred_text = model_pred
        if self.guidance_scale > 1.0 and (
            self.cfg_type == "self" or self.cfg_type == "initialize"
        ):
            noise_pred_uncond = self.stock_noise * self.delta
        if self.guidance_scale > 1.0 and self.cfg_type != "none":
            model_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        else:
            model_pred = noise_pred_text

        # compute the previous noisy sample x_t -> x_t-1
        if self.use_denoising_batch:
            denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)
            if self.cfg_type == "self" or self.cfg_type == "initialize":
                scaled_noise = self.beta_prod_t_sqrt * self.stock_noise
                delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx)
                alpha_next = torch.concat(
                    [
                        self.alpha_prod_t_sqrt[1:],
                        torch.ones_like(self.alpha_prod_t_sqrt[0:1]),
                    ],
                    dim=0,
                )
                delta_x = alpha_next * delta_x
                beta_next = torch.concat(
                    [
                        self.beta_prod_t_sqrt[1:],
                        torch.ones_like(self.beta_prod_t_sqrt[0:1]),
                    ],
                    dim=0,
                )
                delta_x = delta_x / beta_next
                init_noise = torch.concat(
                    [self.init_noise[1:], self.init_noise[0:1]], dim=0
                )
                self.stock_noise = init_noise + delta_x
        else:
            # denoised_batch = self.scheduler.step(model_pred, t_list[0], x_t_latent).denoised
            denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)

        return denoised_batch, model_pred

    def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor:
        image_tensors = image_tensors.to(
            device=self.device,
            dtype=self.vae.dtype,
        )
        img_latent = retrieve_latents(self.vae.encode(image_tensors), self.generator)
        img_latent = img_latent * self.vae.config.scaling_factor
        x_t_latent = self.add_noise(img_latent, self.init_noise[0], 0)
        return x_t_latent

    def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor:
        output_image = self.vae.decode(
            x_0_pred_out / self.vae.config.scaling_factor, return_dict=False
        )[0]
        return output_image

    def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor:
        prev_latent_batch = self.x_t_latent_buffer

        if self.use_denoising_batch:
            t_list = self.sub_timesteps_tensor
            if self.denoising_steps_num > 1:
                x_t_latent = torch.cat((x_t_latent, prev_latent_batch), dim=0)
                self.stock_noise = torch.cat(
                    (self.init_noise[0:1], self.stock_noise[:-1]), dim=0
                )
            x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list)

            if self.denoising_steps_num > 1:
                x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0)
                if self.do_add_noise:
                    self.x_t_latent_buffer = (
                        self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
                        + self.beta_prod_t_sqrt[1:] * self.init_noise[1:]
                    )
                else:
                    self.x_t_latent_buffer = (
                        self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
                    )
            else:
                x_0_pred_out = x_0_pred_batch
                self.x_t_latent_buffer = None
        else:
            self.init_noise = x_t_latent
            for idx, t in enumerate(self.sub_timesteps_tensor):
                t = t.view(
                    1,
                ).repeat(
                    self.frame_bff_size,
                )
                x_0_pred, model_pred = self.unet_step(x_t_latent, t, idx)
                if idx < len(self.sub_timesteps_tensor) - 1:
                    if self.do_add_noise:
                        x_t_latent = self.alpha_prod_t_sqrt[
                            idx + 1
                        ] * x_0_pred + self.beta_prod_t_sqrt[
                            idx + 1
                        ] * torch.randn_like(
                            x_0_pred, device=self.device, dtype=self.dtype
                        )
                    else:
                        x_t_latent = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred
            x_0_pred_out = x_0_pred

        return x_0_pred_out

    def __call__(
        self, x: Optional[Union[torch.Tensor, PIL.Image.Image, np.ndarray]] = None
    ) -> torch.Tensor:
        # start = torch.cuda.Event(enable_timing=True)
        # end = torch.cuda.Event(enable_timing=True)
        # start.record()
        if x is not None:
            x = self.image_processor.preprocess(x, self.height, self.width).to(
                device=self.device, dtype=self.dtype
            )
            if self.similar_image_filter:
                x = self.similar_filter(x)
                if x is None:
                    time.sleep(self.inference_time_ema)
                    return self.prev_image_result
            x_t_latent = self.encode_image(x)
        else:
            # TODO: check the dimension of x_t_latent
            x_t_latent = torch.randn((1, 4, self.latent_height, self.latent_width)).to(
                device=self.device, dtype=self.dtype
            )
        x_0_pred_out = self.predict_x0_batch(x_t_latent)
        x_output = self.decode_image(x_0_pred_out).detach().clone()

        self.prev_image_result = x_output
        # end.record()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # inference_time = start.elapsed_time(end) / 1000
        # self.inference_time_ema = 0.9 * self.inference_time_ema + 0.1 * inference_time
        return x_output

    def txt2img(self, batch_size: int = 1) -> torch.Tensor:
        x_0_pred_out = self.predict_x0_batch(
            torch.randn((batch_size, 4, self.latent_height, self.latent_width)).to(
                device=self.device, dtype=self.dtype
            )
        )
        x_output = self.decode_image(x_0_pred_out).detach().clone()
        return x_output

    def txt2img_sd_turbo(self, batch_size: int = 1) -> torch.Tensor:
        x_t_latent = torch.randn(
            (batch_size, 4, self.latent_height, self.latent_width),
            device=self.device,
            dtype=self.dtype,
        )
        model_pred = self.unet(
            x_t_latent,
            self.sub_timesteps_tensor,
            encoder_hidden_states=self.prompt_embeds,
            return_dict=False,
        )[0]
        x_0_pred_out = (
            x_t_latent - self.beta_prod_t_sqrt * model_pred
        ) / self.alpha_prod_t_sqrt
        return self.decode_image(x_0_pred_out)

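# ---------------------------------------------------------------------------
# Minimal usage sketch for streaming img2img with this class. The model id,
# t_index_list, prompt, image path, and guidance settings below are
# illustrative assumptions, not values required by the class; any SD
# 1.5-family checkpoint with the default LCM-LoRA should behave similarly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipe = StableDiffusionPipeline.from_pretrained(
        "KBlueLeaf/kohaku-v2.1", torch_dtype=torch.float16
    ).to("cuda")

    # Two denoising steps (t_index_list) keep per-frame latency low.
    stream = StreamDiffusion(pipe, t_index_list=[32, 45], torch_dtype=torch.float16)
    stream.load_lcm_lora()  # default LCM-LoRA weights for SD 1.5
    stream.fuse_lora()
    stream.prepare(prompt="a watercolor painting of a cat", guidance_scale=1.2)

    frame = PIL.Image.open("input.png").resize((512, 512))

    # Warm-up calls fill the stream batch buffer before outputs become meaningful.
    for _ in range(stream.batch_size):
        stream(frame)

    out = stream(frame)
    image = stream.image_processor.postprocess(out, output_type="pil")[0]
    image.save("output.png")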