from typing import List, Optional, Callable

import torch
import torch.nn.functional as F

from config import RunConfig
from constants import OUT_INDEX, STRUCT_INDEX, STYLE_INDEX
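# Note: judging from how they are used below, OUT_INDEX marks the generated output image,
# STYLE_INDEX the appearance image, and STRUCT_INDEX the structure image within the stacked
# batch (the concrete values are assumed to live in constants.py).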
from models.stable_diffusion import CrossImageAttentionStableDiffusionPipeline
from utils import attention_utils
from utils.adain import masked_adain
from utils.model_utils import get_stable_diffusion_model
from utils.segmentation import Segmentor


class AppearanceTransferModel:

    def __init__(self, config: RunConfig, pipe: Optional[CrossImageAttentionStableDiffusionPipeline] = None):
        self.config = config
        self.pipe = get_stable_diffusion_model() if pipe is None else pipe
        self.register_attention_control()
        self.segmentor = Segmentor(prompt=config.prompt, object_nouns=[config.object_noun])
        self.latents_app, self.latents_struct = None, None
        self.zs_app, self.zs_struct = None, None
        self.image_app_mask_32, self.image_app_mask_64 = None, None
        self.image_struct_mask_32, self.image_struct_mask_64 = None, None
        self.enable_edit = False
        self.step = 0

    def set_latents(self, latents_app: torch.Tensor, latents_struct: torch.Tensor):
        self.latents_app = latents_app
        self.latents_struct = latents_struct

    def set_noise(self, zs_app: torch.Tensor, zs_struct: torch.Tensor):
        self.zs_app = zs_app
        self.zs_struct = zs_struct

    def set_masks(self, masks: List[torch.Tensor]):
        self.image_app_mask_32, self.image_struct_mask_32, self.image_app_mask_64, self.image_struct_mask_64 = masks

    def get_adain_callback(self) -> Callable:

        def callback(st: int, timestep: int, latents: torch.FloatTensor) -> None:
            self.step = st
            # Compute the masks using prompt-mixing self-segmentation and use them for the AdaIN operation
            if self.step == self.config.adain_range.start:
                masks = self.segmentor.get_object_masks()
                self.set_masks(masks)
            # Apply the AdaIN operation using the computed masks
            if self.config.adain_range.start <= self.step < self.config.adain_range.end:
                latents[0] = masked_adain(latents[0], latents[1], self.image_struct_mask_64, self.image_app_mask_64)

        return callback
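
    # For intuition, a minimal sketch of what the masked AdaIN step above computes, assuming
    # masked_adain matches the masked region of the output latent to the channel-wise mean/std
    # of the masked region of the appearance latent (names are illustrative, not the actual
    # utils.adain implementation):
    #
    #   def masked_adain_sketch(out, app, struct_mask, app_mask):
    #       out_region, app_region = out[:, struct_mask], app[:, app_mask]   # (C, N) per mask
    #       mu_o, sigma_o = out_region.mean(dim=1, keepdim=True), out_region.std(dim=1, keepdim=True)
    #       mu_a, sigma_a = app_region.mean(dim=1, keepdim=True), app_region.std(dim=1, keepdim=True)
    #       out[:, struct_mask] = (out_region - mu_o) / sigma_o * sigma_a + mu_a
    #       return out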

    def register_attention_control(self):

        model_self = self

        class AttentionProcessor:

            def __init__(self, place_in_unet: str):
                self.place_in_unet = place_in_unet
                if not hasattr(F, "scaled_dot_product_attention"):
                    raise ImportError("AttentionProcessor requires torch>=2.0 for F.scaled_dot_product_attention; "
                                      "please upgrade torch to use it.")

            def __call__(self,
                         attn,
                         hidden_states: torch.Tensor,
                         encoder_hidden_states: Optional[torch.Tensor] = None,
                         attention_mask=None,
                         temb=None,
                         perform_swap: bool = False):
                residual = hidden_states

                if attn.spatial_norm is not None:
                    hidden_states = attn.spatial_norm(hidden_states, temb)

                input_ndim = hidden_states.ndim
                if input_ndim == 4:
                    batch_size, channel, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                batch_size, sequence_length, _ = (
                    hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                )

                if attention_mask is not None:
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                    attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

                if attn.group_norm is not None:
                    hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                query = attn.to_q(hidden_states)

                is_cross = encoder_hidden_states is not None
                if not is_cross:
                    encoder_hidden_states = hidden_states
                elif attn.norm_cross:
                    encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                key = attn.to_k(encoder_hidden_states)
                value = attn.to_v(encoder_hidden_states)

                inner_dim = key.shape[-1]
                head_dim = inner_dim // attn.heads

                should_mix = False

                # Potentially apply the cross-image attention operation. To do so, we need to be in a
                # self-attention layer in the decoder part of the denoising network
                if perform_swap and not is_cross and "up" in self.place_in_unet and model_self.enable_edit:
                    if attention_utils.should_mix_keys_and_values(model_self, hidden_states):
                        should_mix = True
                        if model_self.step % 5 == 0 and model_self.step < 40:
                            # Inject the structure image's keys and values
                            key[OUT_INDEX] = key[STRUCT_INDEX]
                            value[OUT_INDEX] = value[STRUCT_INDEX]
                        else:
                            # Inject the appearance image's keys and values
                            key[OUT_INDEX] = key[STYLE_INDEX]
                            value[OUT_INDEX] = value[STYLE_INDEX]

                query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                # Compute the attention and apply the contrasting operation
                hidden_states, attn_weight = attention_utils.compute_scaled_dot_product_attention(
                    query, key, value,
                    edit_map=perform_swap and model_self.enable_edit and should_mix,
                    is_cross=is_cross,
                    contrast_strength=model_self.config.contrast_strength,
                )

                # Update the attention maps used by the segmentor for self-segmentation
                if model_self.config.use_masked_adain and model_self.step == model_self.config.adain_range.start - 1:
                    model_self.segmentor.update_attention(attn_weight, is_cross)

                hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                hidden_states = hidden_states.to(query[OUT_INDEX].dtype)

                # Linear projection
                hidden_states = attn.to_out[0](hidden_states)
                # Dropout
                hidden_states = attn.to_out[1](hidden_states)

                if input_ndim == 4:
                    hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                if attn.residual_connection:
                    hidden_states = hidden_states + residual

                hidden_states = hidden_states / attn.rescale_output_factor

                return hidden_states
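
        # For intuition, the key/value injection above lets the output image's queries attend to
        # another image's keys and values. A standalone sketch with dummy tensors (shapes, index
        # assignments, and names are illustrative only):
        #
        #   q = torch.randn(3, 8, 4096, 40)   # (batch: out/style/struct, heads, tokens, head_dim)
        #   k, v = torch.randn_like(q), torch.randn_like(q)
        #   k[0], v[0] = k[1], v[1]           # output queries now attend to the appearance K/V
        #   out = F.scaled_dot_product_attention(q, k, v)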
        def register_recr(net_, count, place_in_unet):
            if net_.__class__.__name__ == 'Attention':
                net_.set_processor(AttentionProcessor(place_in_unet + f"_{count + 1}"))
                return count + 1
            elif hasattr(net_, 'children'):
                for net__ in net_.children():
                    count = register_recr(net__, count, place_in_unet)
            return count

        cross_att_count = 0
        for name, module in self.pipe.unet.named_children():
            if "down" in name:
                cross_att_count += register_recr(module, 0, "down")
            elif "up" in name:
                cross_att_count += register_recr(module, 0, "up")
            elif "mid" in name:
                cross_att_count += register_recr(module, 0, "mid")
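

# A minimal usage sketch, assuming the inverted latents and per-step noises come from an
# inversion step elsewhere in the repo; `run_config`, `latents_app`, `latents_struct`,
# `zs_app`, and `zs_struct` are placeholders, and the pipeline call is assumed to follow
# the standard diffusers callback convention used by get_adain_callback above:
#
#   model = AppearanceTransferModel(config=run_config)
#   model.set_latents(latents_app, latents_struct)
#   model.set_noise(zs_app, zs_struct)
#   model.enable_edit = True
#   images = model.pipe(
#       prompt=[run_config.prompt] * 3,            # one prompt per batch slot: out/style/struct
#       callback=model.get_adain_callback(),
#       callback_steps=1,
#   ).images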