import math

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn import Parameter


class Residual(nn.Module):
    """Wraps a module and adds a skip connection around it."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal position/timestep embedding: maps a 1-D tensor of positions to [batch, dim]."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
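
# Usage sketch (illustrative, not part of the original file): embedding a batch of
# diffusion timesteps. The module expects a 1-D float tensor and returns [batch, dim];
# `dim` is assumed to be even so the sin and cos halves have equal size.
#
#     t = torch.randint(0, 1000, (8,)).float()
#     emb = SinusoidalPosEmb(128)(t)      # -> shape [8, 128]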


class Mish(nn.Module):
    # Mish activation: x * tanh(softplus(x))
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class Rezero(nn.Module):
    """ReZero gate: scales the wrapped module's output by a learnable scalar initialised to zero."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


# building block modules
class Block(nn.Module):
    """Reflection-padded 3x3 conv, optional GroupNorm, then Mish; groups=0 disables the norm."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        if groups == 0:
            self.block = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim_out, 3),
                Mish()
            )
        else:
            self.block = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim_out, 3),
                nn.GroupNorm(groups, dim_out),
                Mish()
            )

    def forward(self, x):
        return self.block(x)


class ResnetBlock(nn.Module):
    """Two Blocks with a residual connection, optionally conditioned on a time embedding and an extra feature map."""

    def __init__(self, dim, dim_out, *, time_emb_dim=0, groups=8):
        super().__init__()
        if time_emb_dim > 0:
            self.mlp = nn.Sequential(
                Mish(),
                nn.Linear(time_emb_dim, dim_out)
            )
        else:
            self.mlp = None
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None, cond=None):
        h = self.block1(x)
        if time_emb is not None:
            # project the time embedding to dim_out channels and broadcast over H, W
            h += self.mlp(time_emb)[:, :, None, None]
        if cond is not None:
            h += cond
        h = self.block2(h)
        return h + self.res_conv(x)
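
# Usage sketch (illustrative, not part of the original file): a ResnetBlock
# conditioned on a time embedding. Sizes are arbitrary; dim_out must be divisible
# by `groups` for the GroupNorm inside Block.
#
#     block = ResnetBlock(32, 64, time_emb_dim=128, groups=8)
#     x = torch.randn(2, 32, 40, 40)
#     y = block(x, time_emb=torch.randn(2, 128))    # -> shape [2, 64, 40, 40]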


class Upsample(nn.Module):
    """Doubles the spatial resolution with a stride-2 transposed convolution."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
        )

    def forward(self, x):
        return self.conv(x)


class Downsample(nn.Module):
    """Halves the spatial resolution with a reflection-padded stride-2 convolution."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3, 2),
        )

    def forward(self, x):
        return self.conv(x)
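
# Shape note (illustrative, not part of the original file): for even spatial sizes,
# Downsample halves H and W and Upsample doubles them, e.g.
#
#     x = torch.randn(1, 16, 64, 64)
#     Downsample(16)(x).shape                              # -> [1, 16, 32, 32]
#     Upsample(16)(torch.randn(1, 16, 32, 32)).shape       # -> [1, 16, 64, 64]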


class LinearAttention(nn.Module):
    """Linear attention over the flattened spatial positions of a 2-D feature map."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        # contract keys with values first, so cost is linear in the number of pixels
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)
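
# Usage sketch (illustrative, not part of the original file): LinearAttention
# preserves the input shape; channel and head counts below are arbitrary.
#
#     attn = LinearAttention(dim=64, heads=4, dim_head=32)
#     x = torch.randn(2, 64, 20, 20)
#     attn(x).shape                        # -> [2, 64, 20, 20]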


class MultiheadAttention(nn.Module):
    """Multi-head attention on batch-first [B, T, C] inputs, wrapping F.multi_head_attention_forward."""

    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        if self.qkv_same_dim:
            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")
        self.last_attn_probs = None

    def reset_parameters(self):
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)
            nn.init.xavier_uniform_(self.q_proj_weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
            self,
            query, key, value,
            key_padding_mask=None,
            need_weights=True,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
    ):
        """Input shape: [B, T, C]

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        # F.multi_head_attention_forward expects [T, B, C]
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        value = value.transpose(0, 1)
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v,
            self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias,
            self.training, key_padding_mask, need_weights, attn_mask)
        attn_output = attn_output.transpose(0, 1)
        return attn_output, attn_output_weights

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_q(self, query):
        if self.qkv_same_dim:
            return self._in_proj(query, end=self.embed_dim)
        else:
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[:self.embed_dim]
            return F.linear(query, self.q_proj_weight, bias)

    def in_proj_k(self, key):
        if self.qkv_same_dim:
            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
        else:
            weight = self.k_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[self.embed_dim:2 * self.embed_dim]
            return F.linear(key, weight, bias)

    def in_proj_v(self, value):
        if self.qkv_same_dim:
            return self._in_proj(value, start=2 * self.embed_dim)
        else:
            weight = self.v_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[2 * self.embed_dim:]
            return F.linear(value, weight, bias)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)
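
# Usage sketch (illustrative, not part of the original file): self-attention over a
# batch-first sequence. Note that forward always takes the
# F.multi_head_attention_forward path using in_proj_weight, which assumes
# qkv_same_dim (i.e. kdim/vdim left at their defaults).
#
#     mha = MultiheadAttention(embed_dim=256, num_heads=4)
#     seq = torch.randn(2, 50, 256)                 # [B, T, C]
#     out, attn_w = mha(seq, seq, seq)              # out: [2, 50, 256]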


class ResidualDenseBlock_5C(nn.Module):
    """Residual dense block with 5 convolutions (ESRGAN-style); gc is the growth channel count, i.e. intermediate channels."""

    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # initialization
        # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        # each conv sees the input concatenated with all previous intermediate features
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # residual scaling by 0.2
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """Residual in Residual Dense Block: three ResidualDenseBlock_5C with an outer scaled skip connection."""

    def __init__(self, nf, gc=32):
        super().__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x
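

# Minimal smoke test (added for illustration; not part of the original module).
# It only checks that the main blocks run and produce the expected shapes;
# all sizes below are arbitrary choices.
if __name__ == "__main__":
    t = torch.randint(0, 1000, (4,)).float()
    assert SinusoidalPosEmb(128)(t).shape == (4, 128)

    x = torch.randn(2, 32, 16, 16)
    res = ResnetBlock(32, 64, time_emb_dim=128, groups=8)
    assert res(x, time_emb=torch.randn(2, 128)).shape == (2, 64, 16, 16)

    assert LinearAttention(dim=32)(x).shape == x.shape

    mha = MultiheadAttention(embed_dim=64, num_heads=4)
    seq = torch.randn(2, 10, 64)
    out, _ = mha(seq, seq, seq)
    assert out.shape == (2, 10, 64)

    assert RRDB(nf=32, gc=16)(x).shape == x.shape

    print("all shape checks passed")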