from dataclasses import dataclass

import torch
import numpy as np


@dataclass
class Edit:
    old: str
    new: str
    weight: float = 1.0


@dataclass
class Insert:
    text: str
    weight: float = 1.0

    @property
    def old(self):
        return ""

    @property
    def new(self):
        return self.text


@dataclass
class Delete:
    text: str
    weight: float = 1.0

    @property
    def old(self):
        return self.text

    @property
    def new(self):
        return ""


@dataclass
class Text:
    text: str
    weight: float = 1.0

    @property
    def old(self):
        return self.text

    @property
    def new(self):
        return self.text

def get_text_embedding(prompt, tokenizer, text_encoder):
    text_input_ids = tokenizer(
        prompt,
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids
    text_embeddings = text_encoder(text_input_ids.to(text_encoder.device))[0]
    return text_embeddings
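
# Illustrative usage sketch, not from the original module: get_text_embedding
# only needs a Hugging Face tokenizer/text-encoder pair. The CLIP checkpoint
# id and the prompt below are assumptions made for this example.
def _example_get_text_embedding():
    from transformers import CLIPTokenizer, CLIPTextModel

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    emb = get_text_embedding("a photo of a cat", tokenizer, text_encoder)
    # emb is the per-token hidden state, shape [1, tokenizer.model_max_length, 768]
    return emb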

def encode_text(text_pieces, tokenizer, text_encoder):
    # Walk the pieces once to build an alignment from every token of the new
    # (edited) prompt to a token of the old (original) prompt, together with a
    # per-token weight.
    n_old_tokens = 0
    n_new_tokens = 0
    new_id_to_old_id = []
    weights = []
    for piece in text_pieces:
        old, new = piece.old, piece.new
        old_tokens = tokenizer.tokenize(old)
        new_tokens = tokenizer.tokenize(new)
        if len(old_tokens) == 0 and len(new_tokens) == 0:
            continue
        elif old == new:
            # unchanged text: each new token maps to the same old position
            n_old_tokens += len(old_tokens)
            n_new_tokens += len(new_tokens)
            new_id_to_old_id.extend(range(n_old_tokens - len(old_tokens), n_old_tokens))
        elif len(old_tokens) == 0:
            # insert: the new tokens have no counterpart in the old prompt
            new_id_to_old_id.extend([-1] * len(new_tokens))
            n_new_tokens += len(new_tokens)
        elif len(new_tokens) == 0:
            # delete: the old tokens are skipped
            n_old_tokens += len(old_tokens)
        else:
            # replace: spread the new tokens evenly over the old token span
            n_old_tokens += len(old_tokens)
            n_new_tokens += len(new_tokens)
            start = n_old_tokens - len(old_tokens)
            end = n_old_tokens
            ids = np.linspace(start, end, len(new_tokens), endpoint=False).astype(int)
            new_id_to_old_id.extend(list(ids))
        weights.extend([piece.weight] * len(new_tokens))

    old_prompt = " ".join([piece.old for piece in text_pieces])
    new_prompt = " ".join([piece.new for piece in text_pieces])
    old_text_input_ids = tokenizer(
        old_prompt,
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids
    new_text_input_ids = tokenizer(
        new_prompt,
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids
    old_text_embeddings = text_encoder(old_text_input_ids.to(text_encoder.device))[0]
    new_text_embeddings = text_encoder(new_text_input_ids.to(text_encoder.device))[0]

    # key takes the old-prompt embedding wherever an aligned old token exists;
    # value keeps the new-prompt embedding, scaled by the piece weight.
    value = new_text_embeddings.clone()  # batch (1), seq, dim
    key = new_text_embeddings.clone()
    for i, (j, weight) in enumerate(zip(new_id_to_old_id, weights)):
        if 0 <= j < old_text_embeddings.shape[1]:
            key[0, i] = old_text_embeddings[0, j]
        value[0, i] *= weight
    return key, value
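
# Illustrative usage sketch, not from the original module: build key/value
# embeddings for an edited prompt with the same CLIP text stack as above. The
# checkpoint id, pieces, and weight are assumptions made for this example.
def _example_encode_text():
    from transformers import CLIPTokenizer, CLIPTextModel

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    pieces = [Text("a photo of a"), Edit("cat", "dog", weight=1.2)]
    key, value = encode_text(pieces, tokenizer, text_encoder)
    # Both tensors are [1, tokenizer.model_max_length, hidden_dim]: key rows
    # aligned to unchanged text come from the old prompt, value rows are the
    # new-prompt embeddings scaled by each piece's weight.
    return key, value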

def get_text_embedding_openclip(prompt, text_encoder, device='cuda'):
    import open_clip

    text_input_ids = open_clip.tokenize(prompt)
    text_embeddings = text_encoder(text_input_ids.to(device))
    return text_embeddings

def encode_text_openclip(text_pieces, text_encoder, device='cuda'):
    import open_clip

    n_old_tokens = 0
    n_new_tokens = 0
    new_id_to_old_id = []
    weights = []
    for piece in text_pieces:
        old, new = piece.old, piece.new
        # open_clip.tokenize returns a zero-padded [1, context_length] tensor
        # with SOT/EOT markers, so strip the markers and the padding to get
        # the actual per-piece tokens (len() of the raw tensor is always 1).
        old_ids = open_clip.tokenize(old)[0]
        new_ids = open_clip.tokenize(new)[0]
        old_tokens = old_ids[old_ids != 0][1:-1]
        new_tokens = new_ids[new_ids != 0][1:-1]
        if len(old_tokens) == 0 and len(new_tokens) == 0:
            continue
        elif old == new:
            n_old_tokens += len(old_tokens)
            n_new_tokens += len(new_tokens)
            new_id_to_old_id.extend(range(n_old_tokens - len(old_tokens), n_old_tokens))
        elif len(old_tokens) == 0:
            # insert
            new_id_to_old_id.extend([-1] * len(new_tokens))
            n_new_tokens += len(new_tokens)
        elif len(new_tokens) == 0:
            # delete
            n_old_tokens += len(old_tokens)
        else:
            # replace
            n_old_tokens += len(old_tokens)
            n_new_tokens += len(new_tokens)
            start = n_old_tokens - len(old_tokens)
            end = n_old_tokens
            ids = np.linspace(start, end, len(new_tokens), endpoint=False).astype(int)
            new_id_to_old_id.extend(list(ids))
        weights.extend([piece.weight] * len(new_tokens))

    old_prompt = " ".join([piece.old for piece in text_pieces])
    new_prompt = " ".join([piece.new for piece in text_pieces])
    old_text_input_ids = open_clip.tokenize(old_prompt)
    new_text_input_ids = open_clip.tokenize(new_prompt)
    old_text_embeddings = text_encoder(old_text_input_ids.to(device))
    new_text_embeddings = text_encoder(new_text_input_ids.to(device))

    value = new_text_embeddings.clone()  # batch (1), seq, dim
    key = new_text_embeddings.clone()
    for i, (j, weight) in enumerate(zip(new_id_to_old_id, weights)):
        if 0 <= j < old_text_embeddings.shape[1]:
            key[0, i] = old_text_embeddings[0, j]
        value[0, i] *= weight
    return key, value
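
# Illustrative usage sketch, not from the original module: encode_text_openclip
# expects `text_encoder` to map OpenCLIP token ids to per-token embeddings of
# shape [1, seq, dim] (for example, a wrapper around an OpenCLIP text tower
# that returns the hidden states before pooling). Such a wrapper is assumed
# here rather than shown; the pieces below are made up for this example.
def _example_encode_text_openclip(text_encoder):
    pieces = [Text("a photo of a"), Edit("cat", "dog", weight=1.2)]
    return encode_text_openclip(pieces, text_encoder, device='cuda')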