from types import MethodType
from typing import Optional

import torch
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention

from .feature import *
from .utils import *
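
# Note: `feature_injection`, `appearance_transfer`, `normalize`, `get_schedule`, and
# `get_elem` are assumed to come from the star imports above (.feature / .utils);
# they are not defined in this file.
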
def convolution_forward(  # From <class 'diffusers.models.resnet.ResnetBlock2D'>, forward (diffusers==0.28.0)
    self,
    input_tensor: torch.Tensor,
    temb: torch.Tensor,
    *args,
    **kwargs,
) -> torch.Tensor:
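    # Same flow as the stock ResnetBlock2D.forward (diffusers 0.28.0), with structure-control
    # hooks added: at scheduled timesteps, features are injected into `hidden_states` and/or
    # `output_tensor` depending on `self.structure_target`.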
    do_structure_control = self.do_control and self.t in self.structure_schedule

    hidden_states = input_tensor

    hidden_states = self.norm1(hidden_states)
    hidden_states = self.nonlinearity(hidden_states)

    if self.upsample is not None:
        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            input_tensor = input_tensor.contiguous()
            hidden_states = hidden_states.contiguous()
        input_tensor = self.upsample(input_tensor)
        hidden_states = self.upsample(hidden_states)
    elif self.downsample is not None:
        input_tensor = self.downsample(input_tensor)
        hidden_states = self.downsample(hidden_states)

    hidden_states = self.conv1(hidden_states)

    if self.time_emb_proj is not None:
        if not self.skip_time_act:
            temb = self.nonlinearity(temb)
        temb = self.time_emb_proj(temb)[:, :, None, None]

    if self.time_embedding_norm == "default":
        if temb is not None:
            hidden_states = hidden_states + temb
        hidden_states = self.norm2(hidden_states)
    elif self.time_embedding_norm == "scale_shift":
        if temb is None:
            raise ValueError(
                f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
            )
        time_scale, time_shift = torch.chunk(temb, 2, dim=1)
        hidden_states = self.norm2(hidden_states)
        hidden_states = hidden_states * (1 + time_scale) + time_shift
    else:
        hidden_states = self.norm2(hidden_states)

    hidden_states = self.nonlinearity(hidden_states)

    hidden_states = self.dropout(hidden_states)
    hidden_states = self.conv2(hidden_states)

    # Feature injection and AdaIN (hidden_states)
    if do_structure_control and "hidden_states" in self.structure_target:
        hidden_states = feature_injection(hidden_states, batch_order=self.batch_order)

    if self.conv_shortcut is not None:
        input_tensor = self.conv_shortcut(input_tensor)

    output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

    # Feature injection and AdaIN (output_tensor)
    if do_structure_control and "output_tensor" in self.structure_target:
        output_tensor = feature_injection(output_tensor, batch_order=self.batch_order)

    return output_tensor
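
# convolution_forward above is bound onto existing ResnetBlock2D modules via MethodType
# in register_control() below, so each block keeps its pretrained weights and only its
# forward pass changes.
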
class AttnProcessor2_0:  # From <class 'diffusers.models.attention_processor.AttnProcessor2_0'> (diffusers==0.28.0)
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
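        # Same flow as the stock AttnProcessor2_0, with two kinds of hooks added:
        # structure control (feature injection into query/key/value) and appearance
        # control (appearance transfer before attention, on value, or after attention),
        # each gated by the per-timestep schedules attached in register_control().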
        do_structure_control = attn.do_control and attn.t in attn.structure_schedule
        do_appearance_control = attn.do_control and attn.t in attn.appearance_schedule

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        no_encoder_hidden_states = encoder_hidden_states is None
        if no_encoder_hidden_states:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if do_appearance_control:  # Assume we only have this for self-attention
            hidden_states_normed = normalize(hidden_states, dim=-2)  # B H D C
            encoder_hidden_states_normed = normalize(encoder_hidden_states, dim=-2)

            query_normed = attn.to_q(hidden_states_normed)
            key_normed = attn.to_k(encoder_hidden_states_normed)

            inner_dim = key_normed.shape[-1]
            head_dim = inner_dim // attn.heads
            query_normed = query_normed.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key_normed = key_normed.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # Match query and key injection with structure injection (if injection is happening this layer)
            if do_structure_control:
                if "query" in attn.structure_target:
                    query_normed = feature_injection(query_normed, batch_order=attn.batch_order)
                if "key" in attn.structure_target:
                    key_normed = feature_injection(key_normed, batch_order=attn.batch_order)
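
        # `query_normed` / `key_normed` exist only when appearance control is active;
        # the appearance_transfer calls below are gated on the same flag, so they are
        # never referenced otherwise.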
        # Appearance transfer (before)
        if do_appearance_control and "before" in attn.appearance_target:
            hidden_states = hidden_states.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            hidden_states = appearance_transfer(hidden_states, query_normed, key_normed, batch_order=attn.batch_order)
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

            if no_encoder_hidden_states:
                encoder_hidden_states = hidden_states
            elif attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Feature injection (query, key, and/or value)
        if do_structure_control:
            if "query" in attn.structure_target:
                query = feature_injection(query, batch_order=attn.batch_order)
            if "key" in attn.structure_target:
                key = feature_injection(key, batch_order=attn.batch_order)
            if "value" in attn.structure_target:
                value = feature_injection(value, batch_order=attn.batch_order)

        # Appearance transfer (value)
        if do_appearance_control and "value" in attn.appearance_target:
            value = appearance_transfer(value, query_normed, key_normed, batch_order=attn.batch_order)

        # The output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Appearance transfer (after)
        if do_appearance_control and "after" in attn.appearance_target:
            hidden_states = appearance_transfer(hidden_states, query_normed, key_normed, batch_order=attn.batch_order)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Linear projection
        hidden_states = attn.to_out[0](hidden_states, *args)
        # Dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def register_control(
    model,
    timesteps,
    control_schedule,  # structure_conv, structure_attn, appearance_attn
    control_target=[["output_tensor"], ["query", "key"], ["before"]],
):
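    # `control_schedule` is expected to be a dict with "encoder", "decoder", and "middle"
    # entries, each holding per-layer (structure_conv, structure_attn, appearance_attn)
    # schedules; `control_target` names the tensors hooked in the convolution,
    # structure-attention, and appearance-attention paths.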
    # Assume timesteps in reverse order (T -> 0)
    for block_type in ["encoder", "decoder", "middle"]:
        blocks = {
            "encoder": model.unet.down_blocks,
            "decoder": model.unet.up_blocks,
            "middle": [model.unet.mid_block],
        }[block_type]

        control_schedule_block = control_schedule[block_type]
        if block_type == "middle":
            control_schedule_block = [control_schedule_block]

        for layer in range(len(control_schedule_block)):
            # Convolution
            num_blocks = len(blocks[layer].resnets) if hasattr(blocks[layer], "resnets") else 0
            for block in range(num_blocks):
                convolution = blocks[layer].resnets[block]
                convolution.structure_target = control_target[0]
                convolution.structure_schedule = get_schedule(
                    timesteps, get_elem(control_schedule_block[layer][0], block)
                )
                convolution.forward = MethodType(convolution_forward, convolution)

            # Self-attention
            num_blocks = len(blocks[layer].attentions) if hasattr(blocks[layer], "attentions") else 0
            for block in range(num_blocks):
                for transformer_block in blocks[layer].attentions[block].transformer_blocks:
                    attention = transformer_block.attn1
                    attention.structure_target = control_target[1]
                    attention.structure_schedule = get_schedule(
                        timesteps, get_elem(control_schedule_block[layer][1], block)
                    )
                    attention.appearance_target = control_target[2]
                    attention.appearance_schedule = get_schedule(
                        timesteps, get_elem(control_schedule_block[layer][2], block)
                    )
                    attention.processor = AttnProcessor2_0()


def register_attr(model, t, do_control, batch_order):
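    # Stamp per-step state onto every hooked module: the current timestep `t`, whether
    # control is active at all, and the `batch_order` passed through to the injection
    # and transfer helpers.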
    for layer_type in ["encoder", "decoder", "middle"]:
        blocks = {
            "encoder": model.unet.down_blocks,
            "decoder": model.unet.up_blocks,
            "middle": [model.unet.mid_block],
        }[layer_type]
        for layer in blocks:
            # Convolution
            for module in layer.resnets:
                module.t = t
                module.do_control = do_control
                module.batch_order = batch_order
            # Self-attention
            if hasattr(layer, "attentions"):
                for block in layer.attentions:
                    for module in block.transformer_blocks:
                        module.attn1.t = t
                        module.attn1.do_control = do_control
                        module.attn1.batch_order = batch_order
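
# Usage sketch (illustrative only, not part of the original module). `pipe`,
# `control_schedule`, and `batch_order` are hypothetical names standing in for
# whatever the calling pipeline defines:
#
#     register_control(pipe, timesteps, control_schedule)
#     for t in timesteps:
#         register_attr(pipe, t=t, do_control=True, batch_order=batch_order)
#         noise_pred = pipe.unet(latents, t, encoder_hidden_states=prompt_embeds).sample
#         ...  # scheduler step, etc.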