import torch
from typing import List, Union, Tuple
from tqdm import tqdm

from .inference import Inference

class InferenceDAMO(Inference):
    """Plain DDIM-style sampling loop with optional classifier-free guidance.

    If per-step `null_embedding`s are given (e.g. from null-text inversion),
    they replace the fixed `uncond_context` as the unconditional branch.
    """

    def __call__(
        self,
        latent: torch.Tensor,
        context: torch.Tensor,
        uncond_context: torch.Tensor = None,
        start_time: int = 0,
        null_embedding: List[torch.Tensor] = None,
    ):
        all_latent = []
        all_pred = []  # x0_hat (pred_original_sample) at each step
        do_classifier_free_guidance = self.guidance_scale > 1 and (
            (uncond_context is not None) or (null_embedding is not None)
        )
        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)
            if do_classifier_free_guidance:
                latent_input = torch.cat([latent, latent], dim=0)
                if null_embedding is not None:
                    context_input = torch.cat([null_embedding[i], context], dim=0)
                else:
                    context_input = torch.cat([uncond_context, context], dim=0)
            else:
                latent_input = latent
                context_input = context
            noise_pred = self.unet(
                latent_input,
                torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
                context_input,
            )
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
            pred_samples = self.scheduler.step(noise_pred, t, latent)
            latent = pred_samples.prev_sample
            pred = pred_samples.pred_original_sample
            all_latent.append(latent.detach())
            all_pred.append(pred.detach())
        return {
            'latent': latent,
            'all_latent': all_latent,
            'all_pred': all_pred,
        }

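
# A minimal usage sketch (illustrative only). It assumes an `InferenceDAMO`
# instance has already been set up elsewhere with a UNet, a scheduler exposing
# `timesteps` / `step()`, and a `guidance_scale`, as the loop above requires.
# The latent shape and the `encode_prompt` callable are hypothetical stand-ins,
# not part of this module.
def _example_usage_damo(pipe: "InferenceDAMO", encode_prompt):
    latent = torch.randn(1, 4, 16, 32, 32)              # assumed (B, C, frames, H, W) layout
    context = encode_prompt("a cat walking on grass")   # hypothetical text-encoder helper
    uncond_context = encode_prompt("")
    out = pipe(latent, context, uncond_context=uncond_context)
    # out['latent'] is the final latent; out['all_pred'] holds the per-step x0 estimates.
    return out['latent']
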
class InferenceDAMO_PTP(Inference):
    """Prompt-to-prompt editing with two denoising branches.

    The source ("old") branch is always denoised with `old_context`. The edited
    ("new") branch copies the source branch before `sa_end_time`, uses the
    `old_to_new_context` pair between `sa_end_time` and `ca_end_time`, and uses
    the target `context` afterwards.
    """

    def infer_old_context(self, latent, context, t, uncond_context=None):
        do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None)
        if do_classifier_free_guidance:
            latent_input = torch.cat([latent, latent], dim=0)
            context_input = torch.cat([uncond_context, context], dim=0)
        else:
            latent_input = latent
            context_input = context
        noise_pred = self.unet(
            latent_input,
            torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
            context_input,
        )
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
        pred_samples = self.scheduler.step(noise_pred, t, latent)
        latent = pred_samples.prev_sample
        pred = pred_samples.pred_original_sample
        return latent, pred

    def infer_new_context(self, latent, context, t, uncond_context=None):
        do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None)
        if do_classifier_free_guidance:
            latent_input = torch.cat([latent, latent], dim=0)
            # `context` may be an (old, new) pair; prepend the unconditional
            # context to each element so the UNet receives matched batches.
            if isinstance(context, (list, tuple)):
                context_input = (
                    torch.cat([uncond_context, context[0]], dim=0),
                    torch.cat([uncond_context, context[1]], dim=0),
                )
            else:
                context_input = torch.cat([uncond_context, context], dim=0)
        else:
            latent_input = latent
            context_input = context
        noise_pred = self.unet(
            latent_input,
            torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
            context_input,
        )
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
        pred_samples = self.scheduler.step(noise_pred, t, latent)
        latent = pred_samples.prev_sample
        pred = pred_samples.pred_original_sample
        return latent, pred

    def __call__(
        self,
        latent: torch.Tensor,
        context: torch.Tensor,                           # edited branch once i >= ca_end_time * num_ddim_steps
        old_context: torch.Tensor = None,                # source branch throughout; also the edited branch while i < sa_end_time * num_ddim_steps
        old_to_new_context: Union[Tuple, List] = None,   # edited branch in the intermediate phase
        uncond_context: torch.Tensor = None,
        sa_end_time: float = 0.3,
        ca_end_time: float = 0.8,
        start_time: int = 0,
    ):
        assert sa_end_time < ca_end_time, (
            f"sa_end_time must be less than ca_end_time, got {sa_end_time} and {ca_end_time} respectively"
        )
        all_latent = []
        all_pred = []
        all_latent_old = []
        all_pred_old = []
        old_latent = latent.clone()
        new_latent = latent.clone()
        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)
            old_latent_next_t, pred_old = self.infer_old_context(old_latent, old_context, t, uncond_context)
            if i < sa_end_time * self.num_ddim_steps:
                new_latent_next_t, pred_new = old_latent_next_t, pred_old
            elif sa_end_time * self.num_ddim_steps <= i < ca_end_time * self.num_ddim_steps:
                new_latent_next_t, pred_new = self.infer_new_context(
                    new_latent, old_to_new_context, t, uncond_context
                )
            else:
                new_latent_next_t, pred_new = self.infer_new_context(
                    new_latent, context, t, uncond_context
                )
            old_latent = old_latent_next_t
            new_latent = new_latent_next_t
            all_latent.append(new_latent_next_t.detach())
            all_pred.append(pred_new.detach())
            all_latent_old.append(old_latent_next_t.detach())
            all_pred_old.append(pred_old.detach())
        return {
            'latent': new_latent,
            'latent_old': old_latent,
            'all_latent': all_latent,
            'all_pred': all_pred,
            'all_latent_old': all_latent_old,
            'all_pred_old': all_pred_old,
        }

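
# A minimal usage sketch for the two-branch prompt-to-prompt loop above
# (illustrative only). `encode_prompt` and the latent shape are hypothetical
# stand-ins; `old_to_new_context` is forwarded to the UNet as a (source, target)
# pair during the intermediate phase, so its exact contents depend on how the
# UNet consumes it.
def _example_usage_ptp(pipe: "InferenceDAMO_PTP", encode_prompt):
    latent = torch.randn(1, 4, 16, 32, 32)              # assumed (B, C, frames, H, W) layout
    old_context = encode_prompt("a cat walking on grass")
    new_context = encode_prompt("a dog walking on grass")
    uncond_context = encode_prompt("")
    out = pipe(
        latent,
        context=new_context,
        old_context=old_context,
        old_to_new_context=(old_context, new_context),
        uncond_context=uncond_context,
        sa_end_time=0.3,   # edited branch copies the source branch for the first 30% of steps
        ca_end_time=0.8,   # then uses the (old, new) pair until 80% of steps, then new_context alone
    )
    return out['latent'], out['latent_old']
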
class InferenceDAMO_PTP_v2(Inference):
    """Prompt-to-prompt editing with self-attention replacement.

    Like `InferenceDAMO_PTP`, but during the early phase (before `sa_end_time`)
    both branches are denoised in a single batch with the `ptp_sa_replace` flag
    enabled on the UNet's CrossAttention modules, so prompt-to-prompt-style
    attention replacement can be applied across the two branches instead of
    simply copying the source latent.
    """

    def set_ptp_in_xattn_layers(self, prompt_to_prompt: bool, num_frames=1):
        # Toggle the attention-replacement flag on every CrossAttention module.
        for m in self.unet.modules():
            if m.__class__.__name__ == 'CrossAttention':
                m.ptp_sa_replace = prompt_to_prompt
                m.num_frames = num_frames

    def infer_both_with_sa_replace(self, old_latent, new_latent, old_context, new_context, t, uncond_context=None):
        do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None)
        if do_classifier_free_guidance:
            # Batch layout: [old uncond, new uncond, old cond, new cond], so a
            # single chunk(2) separates the unconditional and conditional halves.
            latent_input = torch.cat([old_latent, new_latent, old_latent, new_latent], dim=0)
            context_input = torch.cat([uncond_context, uncond_context, old_context, new_context], dim=0)
        else:
            latent_input = torch.cat([old_latent, new_latent], dim=0)
            context_input = torch.cat([old_context, new_context], dim=0)
        noise_pred = self.unet(
            latent_input,
            torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
            context_input,
        )
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
        noise_pred_old, noise_pred_new = noise_pred.chunk(2, dim=0)
        pred_samples_old = self.scheduler.step(noise_pred_old, t, old_latent)
        pred_samples_new = self.scheduler.step(noise_pred_new, t, new_latent)
        old_latent = pred_samples_old.prev_sample
        new_latent = pred_samples_new.prev_sample
        old_pred = pred_samples_old.pred_original_sample
        new_pred = pred_samples_new.pred_original_sample
        return old_latent, new_latent, old_pred, new_pred

    def infer_old_context(self, latent, context, t, uncond_context=None):
        do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None)
        if do_classifier_free_guidance:
            latent_input = torch.cat([latent, latent], dim=0)
            context_input = torch.cat([uncond_context, context], dim=0)
        else:
            latent_input = latent
            context_input = context
        noise_pred = self.unet(
            latent_input,
            torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
            context_input,
        )
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
        pred_samples = self.scheduler.step(noise_pred, t, latent)
        latent = pred_samples.prev_sample
        pred = pred_samples.pred_original_sample
        return latent, pred

    def infer_new_context(self, latent, context, t, uncond_context=None):
        do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None)
        if do_classifier_free_guidance:
            latent_input = torch.cat([latent, latent], dim=0)
            # `context` may be an (old, new) pair; prepend the unconditional
            # context to each element so the UNet receives matched batches.
            if isinstance(context, (list, tuple)):
                context_input = (
                    torch.cat([uncond_context, context[0]], dim=0),
                    torch.cat([uncond_context, context[1]], dim=0),
                )
            else:
                context_input = torch.cat([uncond_context, context], dim=0)
        else:
            latent_input = latent
            context_input = context
        noise_pred = self.unet(
            latent_input,
            torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long),
            context_input,
        )
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
        pred_samples = self.scheduler.step(noise_pred, t, latent)
        latent = pred_samples.prev_sample
        pred = pred_samples.pred_original_sample
        return latent, pred

    def __call__(
        self,
        latent: torch.Tensor,
        context: torch.Tensor,                           # edited branch in the self-attention-replacement phase and once i >= ca_end_time * num_ddim_steps
        old_context: torch.Tensor = None,                # source branch throughout
        old_to_new_context: Union[Tuple, List] = None,   # edited branch while sa_end_time <= i / num_ddim_steps < ca_end_time
        uncond_context: torch.Tensor = None,
        sa_end_time: float = 0.3,
        ca_end_time: float = 0.8,
        start_time: int = 0,
    ):
        assert sa_end_time < ca_end_time, (
            f"sa_end_time must be less than ca_end_time, got {sa_end_time} and {ca_end_time} respectively"
        )
        all_latent = []
        all_pred = []
        all_latent_old = []
        all_pred_old = []
        old_latent = latent.clone()
        new_latent = latent.clone()
        for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])):
            t = int(t)
            if i < sa_end_time * self.num_ddim_steps:
                self.set_ptp_in_xattn_layers(True, num_frames=latent.shape[2])
                old_latent_next_t, new_latent_next_t, pred_old, pred_new = self.infer_both_with_sa_replace(
                    old_latent, new_latent, old_context, context, t, uncond_context
                )
            elif sa_end_time * self.num_ddim_steps <= i < ca_end_time * self.num_ddim_steps:
                self.set_ptp_in_xattn_layers(False)
                old_latent_next_t, pred_old = self.infer_old_context(old_latent, old_context, t, uncond_context)
                new_latent_next_t, pred_new = self.infer_new_context(
                    new_latent, old_to_new_context, t, uncond_context
                )
            else:
                self.set_ptp_in_xattn_layers(False)
                old_latent_next_t, pred_old = self.infer_old_context(old_latent, old_context, t, uncond_context)
                new_latent_next_t, pred_new = self.infer_new_context(
                    new_latent, context, t, uncond_context
                )
            old_latent = old_latent_next_t
            new_latent = new_latent_next_t
            all_latent.append(new_latent_next_t.detach())
            all_pred.append(pred_new.detach())
            all_latent_old.append(old_latent_next_t.detach())
            all_pred_old.append(pred_old.detach())
        return {
            'latent': new_latent,
            'latent_old': old_latent,
            'all_latent': all_latent,
            'all_pred': all_pred,
            'all_latent_old': all_latent_old,
            'all_pred_old': all_pred_old,
        }

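
# A minimal usage sketch for InferenceDAMO_PTP_v2 (illustrative only). In the
# early phase it toggles `ptp_sa_replace` on every module whose class is named
# `CrossAttention` inside the UNet, so the attention replacement only takes
# effect if those modules honour that flag. The encoder call and latent shape
# are hypothetical stand-ins.
def _example_usage_ptp_v2(pipe: "InferenceDAMO_PTP_v2", encode_prompt):
    latent = torch.randn(1, 4, 16, 32, 32)              # num_frames is read from latent.shape[2]
    old_context = encode_prompt("a cat walking on grass")
    new_context = encode_prompt("a dog walking on grass")
    uncond_context = encode_prompt("")
    out = pipe(
        latent,
        context=new_context,
        old_context=old_context,
        old_to_new_context=(old_context, new_context),
        uncond_context=uncond_context,
    )
    return out['latent']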