import gradio as gr
import spaces
from huggingface_hub import hf_hub_download
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import pyvene as pv
from utils import (
    get_tokens,
    select_concepts,
    get_concepts_dictionary,
    get_response,
    plot_tokens_with_highlights,
)


def launch_app():

    # Process user input to the app. Only this inner function needs the GPU,
    # so it alone carries the ZeroGPU decorator; decorating launch_app itself
    # would hold a GPU allocation for the whole (blocking) demo.launch() call.
    @spaces.GPU
    def process_user_input(prompt, concept):

        # Check whether the prompt or the concept is empty.
        if not prompt or not concept:
            # The original HTML warning markup was lost when this file was
            # flattened; a plain placeholder message stands in for it.
            return "<p>Please enter both a prompt and a concept.</p>"

        # -------------------------------------------------------------------
        # The core of this function was also lost. The lines below are a
        # hedged reconstruction built only from the helpers imported from
        # utils; their argument lists are assumptions and may not match the
        # original implementation.
        # -------------------------------------------------------------------
        # Generate the model's response to the prompt.
        response = get_response(pipe, prompt)
        # Tokenize the response so individual tokens can be highlighted.
        tokens = get_tokens(response, tokenizer)
        # Match the user's concept against the interpreter's concept dictionary.
        selected_concepts = select_concepts(all_concepts, concept)
        # Render the response tokens, highlighted by how strongly the collected
        # layer-20 activations fire for the selected concept(s).
        response_html = plot_tokens_with_highlights(
            tokens, selected_concepts, pv_model
        )
        # Placeholder explanatory blurb (the original copy was lost).
        documentation_html = (
            "<p>Highlighted tokens indicate where the model's internal "
            "representation activates the selected concept.</p>"
        )

        # The HTML wrapper that originally surrounded these two pieces was
        # lost; they are concatenated directly here.
        output_html = response_html + documentation_html
        return output_html

    # Set model, interpreter, and dictionary choices.
    model_name = "google/gemma-2-2b-it"
    interpreter_name = "pyvene/gemma-reft-r1-2b-it-res"
    interpreter_path = "l20/weight.pt"
    interpreter_component = "model.layers[20].output"
    dictionary_url = "https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl"

    # Interpreter: a pyvene CollectIntervention that projects the collected
    # hidden states into the concept (latent) space and keeps positive scores.
    class Encoder(pv.CollectIntervention):
        def __init__(self, **kwargs):
            super().__init__(**kwargs, keep_last_dim=True)
            self.proj = torch.nn.Linear(
                self.embed_dim, kwargs["latent_dim"], bias=False)

        def forward(self, base, source=None, subspaces=None):
            return torch.relu(self.proj(base))

    # Load tokenizer and model.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # device_map='auto' combined with .to('cuda') is redundant and can error
    # under accelerate, so the model is simply moved to the GPU.
    model = AutoModelForCausalLM.from_pretrained(model_name).to('cuda')

    # Fast inference pipeline; reuse the already-loaded model and tokenizer
    # instead of instantiating a second copy from the Hub.
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
    )

    # Download the interpreter weights and load them into the encoder.
    path_to_params = hf_hub_download(
        repo_id=interpreter_name,
        filename=interpreter_path,
        force_download=False,
    )
    params = torch.load(path_to_params)
    encoder = Encoder(embed_dim=params.shape[0], latent_dim=params.shape[1]).cuda()
    encoder.proj.weight.data = params.float()

    # Wrap the base model so the encoder collects layer-20 residual-stream
    # activations during the forward pass.
    pv_model = pv.IntervenableModel({
        "component": interpreter_component,
        "intervention": encoder}, model=model).cuda()

    # Load the concept dictionary.
    all_concepts = get_concepts_dictionary(dictionary_url)

    description_text = """
    ## Does an LLM Think Like You?

    Input a prompt and a concept that you think is most relevant for your prompt.
    See how much (if at all) the LLM uses that concept when processing your prompt.

    Examples:
    - **Prompt**: What is 2+2? **Concept**: math
    - **Prompt**: I really like anchovies on pizza but I know a lot of people don't. **Concept**: food
    """

    with gr.Blocks() as demo:
        gr.Markdown(description_text)
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter a prompt",
                value="I really like anchovies on pizza but I know a lot of people don't.",
            )
            concept_input = gr.Textbox(
                label="Enter a concept that you think is most relevant for your prompt",
                value="food",
            )
        process_button = gr.Button("See if an LLM thinks like you!")
        output_html = gr.HTML()
        process_button.click(
            process_user_input,
            inputs=[prompt_input, concept_input],
            outputs=output_html,
        )

    demo.launch()


if __name__ == "__main__":
    launch_app()