import abc

import torch
from torch import inference_mode
from tqdm import tqdm

"""
Inversion code taken from:
1. The official implementation of Edit-Friendly DDPM Inversion: https://github.com/inbarhub/DDPM_inversion
2. The LEDITS demo: https://huggingface.co/spaces/editing-images/ledits/tree/main
"""

# With LOW_RESOURCE, the conditional and unconditional UNet passes run separately
# (halving peak memory); the attention controller below relies on this ordering.
LOW_RESOURCE = True
def invert(x0, pipe, prompt_src="", num_diffusion_steps=100, cfg_scale_src=3.5, eta=1):
    # Inverts a real image according to Algorithm 1 in https://arxiv.org/pdf/2304.06140.pdf,
    # based on the code in https://github.com/inbarhub/DDPM_inversion
    # Returns:
    #   zs  - noise maps
    #   wts - intermediate inverted latents
    pipe.scheduler.set_timesteps(num_diffusion_steps)

    # VAE-encode the image and apply the Stable Diffusion latent scaling factor
    with inference_mode():
        w0 = (pipe.vae.encode(x0).latent_dist.mode() * 0.18215).float()

    # find zs and wts - forward process
    wt, zs, wts = inversion_forward_process(pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src,
                                            prog_bar=True, num_inference_steps=num_diffusion_steps)
    return zs, wts
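# Usage sketch -- the pipeline, model id and image loader below are illustrative
# assumptions, not part of this file; `x0` must be a (1, 3, 512, 512) tensor in [-1, 1]:
#
#   from diffusers import StableDiffusionPipeline
#   pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
#   x0 = preprocess_image("cat.png").to("cuda")   # hypothetical helper
#   zs, wts = invert(x0, pipe, prompt_src="a photo of a cat",
#                    num_diffusion_steps=100, cfg_scale_src=3.5, eta=1)
#   # wts[-1] is the clean latent w0; wts[0] is the most-noised latent.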
def inversion_forward_process(model, x0,
                              etas=None,
                              prog_bar=False,
                              prompt="",
                              cfg_scale=3.5,
                              num_inference_steps=50, eps=None):
    if prompt != "":
        text_embeddings = encode_text(model, prompt)
    uncond_embedding = encode_text(model, "")
    timesteps = model.scheduler.timesteps.to(model.device)
    variance_noise_shape = (
        num_inference_steps,
        model.unet.in_channels,
        model.unet.sample_size,
        model.unet.sample_size)
    if etas is None or (isinstance(etas, (int, float)) and etas == 0):
        # eta == 0 degenerates to deterministic DDIM inversion: no noise maps
        eta_is_zero = True
        zs = None
        xts = None
    else:
        eta_is_zero = False
        if isinstance(etas, (int, float)):
            etas = [etas] * model.scheduler.num_inference_steps
        # pre-sample the whole trajectory x_1..x_T directly from q(x_t | x_0)
        xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
        alpha_bar = model.scheduler.alphas_cumprod
        zs = torch.zeros(size=variance_noise_shape, device=model.device)

    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    xt = x0
    op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)

    for t in op:
        idx = t_to_idx[int(t)]
        # 1. predict noise residual
        if not eta_is_zero:
            xt = xts[idx][None]

        with torch.no_grad():
            out = model.unet.forward(xt, timestep=t, encoder_hidden_states=uncond_embedding)
            if prompt != "":
                cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states=text_embeddings)

        if prompt != "":
            # classifier-free guidance
            noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
        else:
            noise_pred = out.sample

        if eta_is_zero:
            # 2. compute more noisy image and set x_t -> x_{t+1}
            xt = forward_step(model, noise_pred, t, xt)
        else:
            xtm1 = xts[idx + 1][None]
            # prediction of x_0
            pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
            # direction pointing to x_t
            prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
            alpha_prod_t_prev = model.scheduler.alphas_cumprod[
                prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
            variance = get_variance(model, t)
            pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** 0.5 * noise_pred
            mu_xt = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
            # solve x_{t-1} = mu_t(x_t) + eta * sigma_t * z_t for the noise map z_t
            z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
            zs[idx] = z
            # correction to avoid error accumulation
            xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
            xts[idx + 1] = xtm1

    if zs is not None:
        zs[-1] = torch.zeros_like(zs[-1])
    return xt, zs, xts
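# The eta > 0 branch implements the edit-friendly reparametrization: the
# trajectory x_1..x_T is fixed first, and each noise map is solved for as
#     z_t = (x_{t-1} - mu_t(x_t)) / (eta * sigma_t),
# so that a DDPM sampling pass driven by these z_t reproduces x_0 exactly.
# Matching reverse-step sketch (illustrative names, not defined in this file):
#
#   xt = wts[idx][None]
#   # mu_xt: same posterior mean as computed in the loop above
#   xtm1 = mu_xt + etas[idx] * variance ** 0.5 * zs[idx]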
def encode_text(model, prompts):
    """Returns the text-encoder hidden states for `prompts` (used as UNet conditioning)."""
    text_input = model.tokenizer(
        prompts,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
    return text_encoding
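# For example (sketch, assuming `pipe` is a loaded Stable Diffusion v1-x pipeline):
#
#   cond = encode_text(pipe, "a photo of a cat")   # shape (1, 77, 768) for SD v1-x
#   uncond = encode_text(pipe, "")                 # null-prompt embedding for CFG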
def sample_xts_from_x0(model, x0, num_inference_steps=50):
    """
    Samples from P(x_1:T|x_0)
    """
    # torch.manual_seed(43256465436)
    alpha_bar = model.scheduler.alphas_cumprod
    sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
    variance_noise_shape = (
        num_inference_steps,
        model.unet.in_channels,
        model.unet.sample_size,
        model.unet.sample_size)
    timesteps = model.scheduler.timesteps.to(model.device)
    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    xts = torch.zeros(variance_noise_shape).to(x0.device)
    for t in reversed(timesteps):
        idx = t_to_idx[int(t)]
        # each x_t is drawn independently from the marginal q(x_t | x_0)
        xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
    # append the clean latent so that xts[idx + 1] is "x_{t-1}" for the last step
    xts = torch.cat([xts, x0], dim=0)
    return xts
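# Each row uses the closed-form forward-diffusion marginal
#     q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I),
# with noise drawn independently per t. Sampling the trajectory this way,
# rather than chaining x_t from x_{t-1}, is what makes the extracted noise
# maps "edit-friendly" in the sense of the paper.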
def forward_step(model, model_output, timestep, sample):
    next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
                        timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)
    # compute alphas, betas
    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
    beta_prod_t = 1 - alpha_prod_t
    # compute predicted original sample from predicted noise, also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample = model.scheduler.add_noise(pred_original_sample,
                                            model_output,
                                            torch.LongTensor([next_timestep]))
    return next_sample
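# forward_step re-noises the x_0 prediction toward the *next* (noisier) timestep:
#     x_{t+1} = sqrt(alpha_bar_{t+1}) * x0_pred + sqrt(1 - alpha_bar_{t+1}) * eps_pred,
# i.e. the DDIM update run in reverse; scheduler.add_noise implements exactly
# this affine mixing of the predicted clean sample and the predicted noise.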
def get_variance(model, timestep):
    prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = model.scheduler.alphas_cumprod[
        prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    return variance
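# This matches the DDIM scheduler's posterior variance:
#     sigma_t^2 = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_{t-1}),
# which scales the stochastic term in both the inversion and sampling steps.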
class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        # in LOW_RESOURCE mode the unconditional pass runs first and is skipped
        return self.num_att_layers if LOW_RESOURCE else 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                h = attn.shape[0]
                # only the conditional half of the batch is edited/recorded
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            # accumulate attention maps across diffusion steps
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
                             self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
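# Usage sketch (the pipeline and prompt are illustrative assumptions):
#
#   controller = AttentionStore()
#   register_attention_control(pipe, controller)
#   _ = pipe("a photo of a cat", num_inference_steps=50)
#   avg = controller.get_average_attention()       # maps averaged over steps
#   up_cross = avg["up_cross"]                     # list of (batch*heads, pixels, tokens) tensors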
def register_attention_control(model, controller):
    def ca_forward(self, place_in_unet):
        # to_out may be a ModuleList (Linear followed by Dropout); keep the Linear
        if isinstance(self.to_out, torch.nn.ModuleList):
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward(x, context=None, mask=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            q = self.to_q(x)
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q = self.reshape_heads_to_batch_dim(q)
            k = self.reshape_heads_to_batch_dim(k)
            v = self.reshape_heads_to_batch_dim(v)
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
            if mask is not None:
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~mask, max_neg_value)
            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            # hand the attention map to the controller (it may record or edit it)
            attn = controller(attn, is_cross, place_in_unet)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.reshape_batch_dim_to_heads(out)
            return to_out(out)

        return forward
    class DummyController:
        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()
    def register_recr(net_, count, place_in_unet):
        # 'CrossAttention' matches older diffusers releases (the class was later
        # renamed 'Attention'); adjust the name check for newer versions.
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")
    # the controller needs the layer count to know when a diffusion step ends
    controller.num_att_layers = cross_att_count
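# End-to-end sketch (illustrative; `pipe`, `x0` and the prompts are assumptions):
#
#   controller = AttentionStore()
#   register_attention_control(pipe, controller)
#   zs, wts = invert(x0, pipe, prompt_src="a photo of a cat")
#   # editing then re-runs DDPM sampling from an intermediate wts[skip] with the
#   # stored noise maps zs and a target prompt, as in the LEDITS demo.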