File size: 2,158 Bytes
85acc7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82f6668
 
7d9d543
82f6668
 
7d9d543
82f6668
 
7d9d543
82f6668
 
7d9d543
82f6668
 
 
 
 
 
 
7d9d543
82f6668
 
 
 
85acc7e
7d9d543
85acc7e
 
 
82f6668
 
7d9d543
 
 
 
82f6668
 
7d9d543
 
 
82f6668
 
 
7d9d543
85acc7e
 
7d9d543
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# app.py
import gradio as gr
from model import ModelWrapper
from agents import ExperimentAgent, ExplanationAgent
from database import DB
import matplotlib.pyplot as plt
import io
import base64

# Initialize components.
# Module-level singletons: run_experiment reads `model`, `db` and `expl_agent`
# as globals, so every request handled by the Gradio app shares these instances.
# NOTE(review): these are constructed at import time, so importing this module
# presumably loads the model and opens the database — confirm in model/database.
MODEL_NAME = "gpt2"  # model id for ModelWrapper; NOTE: the UI title hard-codes "GPT-2" separately
model = ModelWrapper(MODEL_NAME)
db = DB("experiments.db")  # presumably a SQLite file (the UI description mentions SQLite logging) — confirm
# NOTE(review): exp_agent is not referenced anywhere else in this file; verify it
# is used by another module before relying on (or removing) it.
exp_agent = ExperimentAgent(model, db)
expl_agent = ExplanationAgent()  # turns layer scores into the "Explanation" text output


def _render_heatmap(layer_scores):
    """Render per-layer scores as a one-row heatmap and return its RGBA pixels.

    A standalone ``Figure`` with an Agg canvas is used instead of ``pyplot``,
    so nothing is registered in pyplot's global figure manager. The figure is
    freed when it goes out of scope, so repeated requests do not leak figures,
    and concurrent Gradio worker threads do not share pyplot state.

    Args:
        layer_scores: Sequence of numeric scores, presumably one per layer,
            drawn left to right as a single heatmap row.

    Returns:
        A ``(height, width, 4)`` ``uint8`` NumPy array, which ``gr.Image``
        can display directly.
    """
    import numpy as np
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    fig = Figure(figsize=(6, 1.5))
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(111)
    ax.imshow([layer_scores], aspect="auto")
    ax.set_yticks([])
    ax.set_xlabel("Layer")
    ax.set_title("Layer importance (proxy)")
    fig.tight_layout()
    canvas.draw()
    # Copy so the result owns its pixels instead of aliasing the canvas's
    # render buffer.
    return np.asarray(canvas.buffer_rgba()).copy()


def run_experiment(prompt, experiment_type, top_k, max_length):
    """Gradio handler: generate text, score layers, log, explain and plot.

    Args:
        prompt: User prompt to generate from and analyse.
        experiment_type: One of the radio-button choices; forwarded to
            ``model.layer_importance``.
        top_k: Top-k cutoff passed to ``model.generate_text``.
        max_length: Maximum generation length passed to
            ``model.generate_text``.

    Returns:
        ``(generated_text, explanation, heatmap)``. ``heatmap`` is an RGBA
        ``uint8`` array for the ``gr.Image`` output.

    Raises:
        gr.Error: If ``prompt`` is empty or whitespace-only. Gradio shows the
            message to the user.
    """
    if not prompt or not prompt.strip():
        # Don't spend a generation, an analysis pass and a DB row on nothing.
        raise gr.Error("Please enter a prompt.")

    # Gradio sliders may deliver floats; top-k and length are integer counts.
    top_k = int(top_k)
    max_length = int(max_length)

    # 1) generate
    gen_text = model.generate_text(prompt, max_length=max_length, top_k=top_k)

    # 2) run layer importance analysis (proxy for activation patching)
    layer_scores = model.layer_importance(prompt, experiment_type=experiment_type)

    # 3) save to DB (the returned experiment id is not needed here)
    db.save_experiment(prompt, gen_text, layer_scores)

    # 4) explanation
    explanation = expl_agent.explain_layer_importance(layer_scores)

    # 5) heatmap image. gr.Image outputs accept a NumPy array, PIL image or
    # file path, not a raw PNG byte stream, so return decoded pixels.
    heatmap = _render_heatmap(layer_scores)

    return gen_text, explanation, heatmap


# Wire the handler into a single Gradio form. Input components map
# positionally onto run_experiment's parameters, and output components map
# onto the items of its returned tuple.
demo = gr.Interface(
    fn=run_experiment,
    inputs=[
        gr.Textbox(
            lines=3,
            label="Prompt",
            placeholder="Enter a sentence or prompt...",
        ),
        gr.Radio(
            choices=[
                "story_continuation",
                "sentence_completion",
                "token_prediction",
            ],
            value="story_continuation",
            label="Experiment type",
        ),
        gr.Slider(
            minimum=1,
            maximum=50,
            step=1,
            value=10,
            label="Top-k (generation)",
        ),
        gr.Slider(
            minimum=10,
            maximum=200,
            step=1,
            value=50,
            label="Max generation length",
        ),
    ],
    outputs=[
        gr.Textbox(label="Generated text"),
        gr.Textbox(label="Explanation"),
        gr.Image(type="pil", label="Layer importance heatmap"),
    ],
    title="Mechanistic Analysis Prototype (GPT-2 + Layer Importance)",
    description=(
        "Quick prototype: GPT-2 generation + layer importance "
        "(proxy for activation patching) + SQLite logging"
    ),
)

if __name__ == "__main__":
    import os

    # Passing host/port explicitly would override Gradio's own
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT settings, so read them here.
    # The defaults keep the original behaviour: bind all interfaces, port 7860.
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )