Spaces:
Sleeping
Sleeping
# app.py
import base64
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure

from agents import ExperimentAgent, ExplanationAgent
from database import DB
from model import ModelWrapper
# Initialize components once at import time; run_experiment reads them as
# module-level globals on every request.
MODEL_NAME = "gpt2"

# Model wrapper for MODEL_NAME, used for text generation and layer scoring.
model = ModelWrapper(MODEL_NAME)
# Experiment log stored in "experiments.db" (SQLite, per the UI description).
db = DB("experiments.db")

# exp_agent is built over the shared model and db but is not referenced by
# run_experiment; expl_agent turns layer scores into a textual explanation.
exp_agent, expl_agent = ExperimentAgent(model, db), ExplanationAgent()
def _render_layer_heatmap(layer_scores):
    """Render per-layer scores as a one-row heatmap and return it as an RGBA array.

    Uses the object-oriented Figure + Agg canvas API instead of pyplot.
    Pyplot keeps every figure it creates in a global registry until it is
    explicitly closed, so a figure per request would accumulate for the
    lifetime of the server. Pyplot is also not thread-safe, and Gradio runs
    callbacks on worker threads.

    Args:
        layer_scores: Per-layer importance scores (assumed to be a 1-D
            sequence with one score per layer; TODO confirm against
            ModelWrapper.layer_importance).

    Returns:
        A ``(height, width, 4)`` ``uint8`` NumPy array (RGBA), which
        ``gr.Image`` accepts as an output value.
    """
    fig = Figure(figsize=(6, 1.5))
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(111)
    # Wrap in a list so imshow receives a 2-D array: one row, one column per layer.
    ax.imshow([layer_scores], aspect='auto')
    ax.set_yticks([])
    ax.set_xlabel('Layer')
    ax.set_title('Layer importance (proxy)')
    fig.tight_layout()
    canvas.draw()
    # np.array copies, so the result does not alias the canvas' render buffer.
    return np.array(canvas.buffer_rgba())


def run_experiment(prompt, experiment_type, top_k, max_length):
    """Generate text, score layer importance, log the run, and plot the scores.

    Gradio callback for the Interface defined in this module.

    Args:
        prompt: User-entered text, used for both generation and the
            layer-importance analysis.
        experiment_type: One of the Radio choices (e.g. "story_continuation"),
            forwarded unchanged to ``model.layer_importance``.
        top_k: Top-k cutoff for generation. Gradio sliders may deliver a
            float, so the value is coerced to ``int``.
        max_length: Maximum generation length, coerced to ``int`` for the
            same reason.

    Returns:
        A ``(generated_text, explanation, heatmap)`` tuple. ``heatmap`` is an
        RGBA ``uint8`` NumPy array suitable for a ``gr.Image`` output. (The
        previous ``io.BytesIO`` return value is not an image type Gradio can
        display.)

    Raises:
        gr.Error: If ``prompt`` is empty or whitespace-only, so the UI shows a
            readable message instead of a stack trace.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    # 1) generate (integer generation settings; slider values may be floats)
    gen_text = model.generate_text(
        prompt, max_length=int(max_length), top_k=int(top_k)
    )
    # 2) run layer importance analysis (proxy for activation patching)
    layer_scores = model.layer_importance(prompt, experiment_type=experiment_type)
    # 3) save to DB; the returned experiment id is not needed here
    db.save_experiment(prompt, gen_text, layer_scores)
    # 4) explanation
    explanation = expl_agent.explain_layer_importance(layer_scores)
    # 5) heatmap image
    heatmap = _render_layer_heatmap(layer_scores)
    return gen_text, explanation, heatmap
# Gradio UI. The four inputs map positionally onto run_experiment's
# parameters (prompt, experiment_type, top_k, max_length), and the three
# outputs receive its returned tuple in order.
demo = gr.Interface(
    title="Mechanistic Analysis Prototype (GPT-2 + Layer Importance)",
    description=(
        "Quick prototype: GPT-2 generation + layer importance "
        "(proxy for activation patching) + SQLite logging"
    ),
    fn=run_experiment,
    inputs=[
        gr.Textbox(
            label="Prompt",
            lines=3,
            placeholder="Enter a sentence or prompt...",
        ),
        gr.Radio(
            label="Experiment type",
            choices=[
                "story_continuation",
                "sentence_completion",
                "token_prediction",
            ],
            value="story_continuation",
        ),
        gr.Slider(
            label="Top-k (generation)",
            minimum=1,
            maximum=50,
            step=1,
            value=10,
        ),
        gr.Slider(
            label="Max generation length",
            minimum=10,
            maximum=200,
            step=1,
            value=50,
        ),
    ],
    outputs=[
        gr.Textbox(label="Generated text"),
        gr.Textbox(label="Explanation"),
        gr.Image(label="Layer importance heatmap", type="pil"),
    ],
)
if __name__ == "__main__":
    # Bind to all interfaces (not just localhost) so the app is reachable
    # from outside its host/container; 7860 is Gradio's default port.
    demo.launch(server_port=7860, server_name="0.0.0.0")