AmnaHassan commited on
Commit
85acc7e
·
verified ·
1 Parent(s): 0339477

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ from model import ModelWrapper
4
+ from agents import ExperimentAgent, ExplanationAgent
5
+ from database import DB
6
+ import matplotlib.pyplot as plt
7
+ import io
8
+ import base64
9
+
10
+
11
+ # Initialize components
12
+ MODEL_NAME = "gpt2"
13
+ model = ModelWrapper(MODEL_NAME)
14
+ db = DB("experiments.db")
15
+ exp_agent = ExperimentAgent(model, db)
16
+ expl_agent = ExplanationAgent()
17
+
18
+
19
+
20
+
21
+ def run_experiment(prompt, experiment_type, top_k, max_length):
22
+ # 1) generate
23
+ gen_text = model.generate_text(prompt, max_length=max_length, top_k=top_k)
24
+
25
+
26
+ # 2) run layer importance analysis (proxy for activation patching)
27
+ layer_scores = model.layer_importance(prompt, experiment_type=experiment_type)
28
+
29
+
30
+ # 3) save to DB
31
+ exp_id = db.save_experiment(prompt, gen_text, layer_scores)
32
+
33
+
34
+ # 4) explanation
35
+ explanation = expl_agent.explain_layer_importance(layer_scores)
36
+
37
+
38
+ # 5) heatmap figure
39
+ fig = plt.figure(figsize=(6,1.5))
40
+ ax = fig.add_subplot(111)
41
+ ax.imshow([layer_scores], aspect='auto')
42
+ ax.set_yticks([])
43
+ ax.set_xlabel('Layer')
44
+ ax.set_title('Layer importance (proxy)')
45
+
46
+
47
+ buf = io.BytesIO()
48
+ fig.tight_layout()
49
+ fig.savefig(buf, format='png')
50
+ buf.seek(0)
51
+
52
+
53
+ return gen_text, explanation, buf
54
+
55
+
56
+
57
+
58
+ demo = gr.Interface(
59
+ fn=run_experiment,
60
+ inputs=[
61
+ gr.Textbox(lines=3, label="Prompt", placeholder="Enter a sentence or prompt..."),
62
+ gr.Radio(choices=["story_continuation", "sentence_completion", "token_prediction"], value="story_continuation", label="Experiment type"),
63
+ gr.Slider(minimum=1, maximum=50, step=1, value=10, label="Top-k (generation)"),
64
+ gr.Slider(minimum=10, maximum=200, step=1, value=50, label="Max generation length")
65
+ ],
66
+ outputs=[
67
+ gr.Textbox(label="Generated text"),
68
+ gr.Textbox(label="Explanation"),
69
+ gr.Image(type="pil", label="Layer importance heatmap")
70
+ ],
71
+ title="Mechanistic Analysis Prototype (GPT-2 + Layer Importance)",
72
+ description="Quick prototype: GPT-2 generation + layer importance (proxy for activation patching) + SQLite logging"
73
+ )
74
+
75
+
76
+ if __name__ == "__main__":
77
+ demo.launch(server_name="0.0.0.0", server_port=7860)