Spaces:
Runtime error
Runtime error
| from typing import Tuple, List | |
| import nltk | |
| import numpy as np | |
| import torch | |
| from sklearn.cluster import KMeans | |
| from constants import STYLE_INDEX, STRUCT_INDEX | |
# Tokenizer and part-of-speech tagger data needed by ``Segmentor.__init__``
# (``nltk.word_tokenize`` / ``nltk.pos_tag``). NLTK >= 3.9 looks this data up
# under new resource ids (``punkt_tab``, ``averaged_perceptron_tagger_eng``);
# the legacy ids are still fetched for older NLTK releases. With the default
# ``raise_on_error=False``, an id unknown to the installed NLTK release is
# reported and skipped rather than raising.
for _nltk_resource in (
    "punkt",
    "punkt_tab",
    "averaged_perceptron_tagger",
    "averaged_perceptron_tagger_eng",
):
    nltk.download(_nltk_resource)
del _nltk_resource

# Self-segmentation technique taken from Prompt Mixing:
# https://github.com/orpatashnik/local-prompt-mixing
class Segmentor:
    """Cluster self-attention maps into segments and label them with prompt nouns.

    Attention tensors are pushed in through :meth:`update_attention`; the
    clustering and labelling methods then read the stored tensors back, so
    they must be called only after the relevant resolutions were recorded.

    Args:
        prompt: Text prompt whose nouns are used to label the segments.
        object_nouns: Nouns whose segments count as "object" in
            :meth:`create_mask`.
        num_segments: Number of KMeans clusters per attention map.
        res: Stored as ``self.resolution``; not read by the methods below.
    """

    # A cluster is labelled with its best noun only if that noun's score exceeds
    # this multiple of the mean best score of the remaining clusters; otherwise
    # it is labelled "BG".
    BACKGROUND_SCORE_RATIO = 1.4

    def __init__(self, prompt: str, object_nouns: List[str], num_segments: int = 5, res: int = 32):
        self.prompt = prompt
        self.num_segments = num_segments
        self.resolution = res
        self.object_nouns = object_nouns
        tokenized_prompt = nltk.word_tokenize(prompt)
        # Generic framing words are tagged as nouns but never name a segment.
        forbidden_words = {word.upper() for word in ["photo", "image", "picture"]}
        # (position in the NLTK tokenization, word) for every tag starting with "NN".
        self.nouns = [(i, word) for (i, (word, pos)) in enumerate(nltk.pos_tag(tokenized_prompt))
                      if pos[:2] == 'NN' and word.upper() not in forbidden_words]

    def update_attention(self, attn, is_cross):
        """Store an attention tensor if its spatial resolution is one we use.

        ``res`` is sqrt(attn.shape[2]); presumably ``attn`` is
        (batch, heads, res*res, keys) -- TODO confirm against the caller.
        Cross-attention is kept at half the clustering resolution: the 16x16
        map is stored for the 32x32 clustering and the 32x32 map for the 64x64
        one (``cluster2noun`` upsamples it by 2). Other resolutions are ignored.
        """
        res = int(attn.shape[2] ** 0.5)
        if is_cross:
            if res == 16:
                self.cross_attention_32 = attn
            elif res == 32:
                self.cross_attention_64 = attn
        elif res == 32:
            self.self_attention_32 = attn
        elif res == 64:
            self.self_attention_64 = attn

    def __call__(self, *args, **kwargs):
        """Label the 32x32 style and structure clusters with prompt nouns.

        Positional and keyword arguments are accepted but ignored.

        Returns:
            ``(style_labels, struct_labels)``: two dicts as returned by
            :meth:`cluster2noun`, computed with ``self.cross_attention_32``.
        """
        style_clusters, struct_clusters = self.cluster()
        style_labels = self.cluster2noun(style_clusters, self.cross_attention_32, STYLE_INDEX)
        struct_labels = self.cluster2noun(struct_clusters, self.cross_attention_32, STRUCT_INDEX)
        return style_labels, struct_labels

    def cluster(self, res: int = 32):
        """KMeans-cluster the stored self-attention at ``res`` (32 or 64).

        Returns:
            ``(style_clusters, struct_clusters)``: two ``res x res`` label
            arrays, one for the STYLE_INDEX and one for the STRUCT_INDEX entry
            of the self-attention tensor.

        Raises:
            ValueError: if ``res`` is neither 32 nor 64.
        """
        if res not in (32, 64):
            raise ValueError(f"Unsupported clustering resolution {res}; expected 32 or 64.")
        # Presumably for reproducibility: sklearn's KMeans falls back to NumPy's
        # global RNG when no random_state is given.
        np.random.seed(1)
        self_attn = self.self_attention_32 if res == 32 else self.self_attention_64
        # Style first, then structure: the order of KMeans calls determines the
        # RNG draws each one sees.
        style_clusters = self._kmeans_labels(self_attn[STYLE_INDEX], res)
        struct_clusters = self._kmeans_labels(self_attn[STRUCT_INDEX], res)
        return style_clusters, struct_clusters

    def _kmeans_labels(self, attn, res):
        """Average ``attn`` over its first axis and cluster its rows into a res x res label map."""
        attn_map = attn.mean(dim=0).cpu().numpy()
        kmeans = KMeans(n_clusters=self.num_segments, n_init=10).fit(attn_map)
        return kmeans.labels_.reshape(res, res)

    def cluster2noun(self, clusters, cross_attn, attn_index):
        """Map each cluster id to the prompt noun it overlaps most, or "BG".

        Args:
            clusters: 2D integer label map; assumed to be twice the spatial
                resolution of ``cross_attn`` (see ``update_attention``).
            cross_attn: cross-attention tensor whose ``[attn_index]`` entry is
                averaged over its first axis and reshaped to (res, res, tokens).
            attn_index: which entry of ``cross_attn`` to use.

        Returns:
            dict mapping every id in ``range(self.num_segments)`` to either an
            ``(index, word)`` tuple from ``self.nouns`` or the string "BG".
            Clusters with no pixels, or a prompt without nouns, yield "BG".
        """
        if not self.nouns:
            return {c: "BG" for c in range(self.num_segments)}

        res = int(cross_attn.shape[2] ** 0.5)
        nouns_indices = [index for (index, _word) in self.nouns]
        cross_attn = cross_attn[attn_index].mean(dim=0).reshape(res, res, -1)
        # Token column i + 1: presumably skips the text encoder's start token;
        # assumes NLTK word positions line up with encoder token positions -- TODO confirm.
        nouns_maps = cross_attn.cpu().numpy()[:, :, [i + 1 for i in nouns_indices]]
        # Upsample each noun map 2x to the cluster-map resolution, then rescale.
        normalized_nouns_maps = np.zeros_like(nouns_maps).repeat(2, axis=0).repeat(2, axis=1)
        for i in range(nouns_maps.shape[-1]):
            curr_noun_map = nouns_maps[:, :, i].repeat(2, axis=0).repeat(2, axis=1)
            normalized_nouns_maps[:, :, i] = (curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()

        # Per-cluster list of mean normalized noun activation inside the cluster.
        # Empty clusters are skipped (their score would be 0/0 = NaN).
        cluster_scores = {}
        for c in range(self.num_segments):
            cluster_mask = np.zeros_like(clusters)
            cluster_mask[clusters == c] = 1
            mask_size = cluster_mask.sum()
            if mask_size == 0:
                continue
            cluster_scores[c] = [(cluster_mask * normalized_nouns_maps[:, :, i]).sum() / mask_size
                                 for i in range(len(nouns_indices))]

        # Threshold relative to the mean best score of all but the top cluster.
        best_scores = [max(scores) for scores in cluster_scores.values()]
        if not best_scores:
            return {c: "BG" for c in range(self.num_segments)}
        other_scores = list(best_scores)
        other_scores.remove(max(other_scores))
        mean_score = sum(other_scores) / len(other_scores) if other_scores else 0.0
        threshold = self.BACKGROUND_SCORE_RATIO * mean_score

        result = {}
        for c in range(self.num_segments):
            scores = cluster_scores.get(c)
            if scores is not None and max(scores) > threshold:
                result[c] = self.nouns[int(np.argmax(scores))]
            else:
                result[c] = "BG"
        return result

    def create_mask(self, clusters, cross_attention, attn_index, device="cuda"):
        """Return a binary mask of the clusters labelled with one of ``self.object_nouns``.

        The mask has the shape and dtype of ``clusters``: 1 where the pixel's
        cluster maps to an object noun, 0 elsewhere (including "BG" clusters
        and any label outside ``range(self.num_segments)``).

        Args:
            device: torch device the mask is moved to (default "cuda").
        """
        cluster2noun = self.cluster2noun(clusters, cross_attention, attn_index)
        # Skip the "BG" sentinel explicitly: indexing it ("BG"[1] == "G") would
        # otherwise be compared against the object nouns.
        obj_segments = [c for c, noun in cluster2noun.items()
                        if noun != "BG" and noun[1] in self.object_nouns]
        mask = np.isin(clusters, obj_segments).astype(clusters.dtype)
        return torch.from_numpy(mask).to(device)

    def get_object_masks(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return object masks for style/struct at 32x32 and 64x64, in that order.

        Returns:
            ``(mask_style_32, mask_struct_32, mask_style_64, mask_struct_64)``.
        """
        clusters_style_32, clusters_struct_32 = self.cluster(res=32)
        clusters_style_64, clusters_struct_64 = self.cluster(res=64)
        mask_style_32 = self.create_mask(clusters_style_32, self.cross_attention_32, STYLE_INDEX)
        mask_struct_32 = self.create_mask(clusters_struct_32, self.cross_attention_32, STRUCT_INDEX)
        mask_style_64 = self.create_mask(clusters_style_64, self.cross_attention_64, STYLE_INDEX)
        mask_struct_64 = self.create_mask(clusters_struct_64, self.cross_attention_64, STRUCT_INDEX)
        return mask_style_32, mask_struct_32, mask_style_64, mask_struct_64