import matplotlib.pyplot as plt import matplotlib.cm as cm import matplotlib.colors as clrs import requests import json import pandas as pd import torch import spaces # Function to get tokens given text @spaces.GPU def get_tokens(tokenizer, text): token_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=False).to("cuda" if torch.cuda.is_available() else "cpu") tokens = tokenizer.convert_ids_to_tokens(token_ids[0]) return tokens, token_ids # Function to apply chat template to prompt @spaces.GPU def decorate_prompt(tokenizer, prompt): chat = [ {"role": "user", "content": prompt}, {"role": "assistant", "content": ""}, ] text = tokenizer.apply_chat_template(chat, tokenize=False) token_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=False).to("cuda" if torch.cuda.is_available() else "cpu") return token_ids # Function to get response to prompt def get_response(model_pipe, prompt): response = model_pipe(prompt)[0]['generated_text'] return response # Function to highlight tokens based on given values def plot_tokens_with_highlights(tokens, values, concept, cmap_name='Oranges', vmin=None, vmax=None): if len(tokens) != len(values): raise ValueError("The number of tokens and values must be the same.") # Set color map cmap = cm.get_cmap(cmap_name) norm = clrs.Normalize(vmin=vmin if vmin is not None else values.detach().min(), vmax=vmax if vmax is not None else values.detach().max()) html_output = f"