# Author: AmnaHassan
# Last commit: 7d9d543 "Update app.py" (verified)
# app.py
import gradio as gr
from model import ModelWrapper
from agents import ExperimentAgent, ExplanationAgent
from database import DB
import matplotlib.pyplot as plt
import io
import base64
# Initialize components.
# These objects are created once at import time and shared by every request
# the Gradio app serves: run_experiment reads model, db and expl_agent as
# module-level globals.
# Model identifier handed to ModelWrapper (presumably a Hugging Face model id
# for GPT-2, per the UI title — confirm in model.py).
MODEL_NAME: str = "gpt2"
model = ModelWrapper(MODEL_NAME)
# Experiment log; the UI description calls this "SQLite logging". The path is
# relative, so it resolves against the process working directory.
db = DB("experiments.db")
# NOTE(review): exp_agent is constructed but never referenced elsewhere in this
# file; confirm whether it is still needed.
exp_agent = ExperimentAgent(model, db)
expl_agent = ExplanationAgent()
def run_experiment(prompt, experiment_type, top_k, max_length):
    """Run one end-to-end experiment for the Gradio UI.

    Generates text from ``prompt`` with the module-level ``model``, computes
    per-layer importance scores, logs the run through ``db``, asks
    ``expl_agent`` for a textual explanation, and renders the scores as a
    one-row heatmap.

    Args:
        prompt: Input text to generate from and analyse.
        experiment_type: Experiment-type choice from the UI; forwarded
            unchanged to ``model.layer_importance``.
        top_k: Top-k value forwarded to ``model.generate_text``. Cast to
            ``int`` defensively, since Gradio's Slider is typed as producing
            floats and top-k sampling needs an integer count.
        max_length: Maximum generation length forwarded to
            ``model.generate_text``; also cast to ``int``.

    Returns:
        Tuple ``(generated_text, explanation, heatmap)``. ``heatmap`` is a
        NumPy image array decoded from the rendered PNG, which a ``gr.Image``
        output accepts. A raw ``BytesIO`` buffer is not accepted.
    """
    # 1) generate
    gen_text = model.generate_text(
        prompt, max_length=int(max_length), top_k=int(top_k)
    )
    # 2) run layer importance analysis (proxy for activation patching)
    layer_scores = model.layer_importance(prompt, experiment_type=experiment_type)
    # 3) save to DB. The returned experiment id is not surfaced in the UI,
    #    so it is not kept.
    db.save_experiment(prompt, gen_text, layer_scores)
    # 4) explanation
    explanation = expl_agent.explain_layer_importance(layer_scores)
    # 5) heatmap image
    heatmap = _render_layer_heatmap(layer_scores)
    return gen_text, explanation, heatmap


def _render_layer_heatmap(layer_scores):
    """Render per-layer scores as a one-row heatmap; return it as an image array.

    ``layer_scores`` is assumed to be a 1-D sequence of numbers (one per
    layer), since it is wrapped into a single row for ``imshow``.
    """
    fig = plt.figure(figsize=(6, 1.5))
    try:
        ax = fig.add_subplot(111)
        ax.imshow([layer_scores], aspect="auto")
        ax.set_yticks([])
        ax.set_xlabel("Layer")
        ax.set_title("Layer importance (proxy)")
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
        buf.seek(0)
        # Decode the PNG back into an array so Gradio's Image output can use it.
        return plt.imread(buf, format="png")
    finally:
        # pyplot tracks every figure until it is explicitly closed; without
        # this, each request would leak one figure.
        plt.close(fig)
# Gradio front end. The four input components are passed positionally to
# run_experiment's parameters (prompt, experiment_type, top_k, max_length).
# The three output components receive its (text, explanation, heatmap)
# return tuple.
demo = gr.Interface(
    title="Mechanistic Analysis Prototype (GPT-2 + Layer Importance)",
    description="Quick prototype: GPT-2 generation + layer importance (proxy for activation patching) + SQLite logging",
    fn=run_experiment,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="Enter a sentence or prompt...",
            lines=3,
        ),
        gr.Radio(
            label="Experiment type",
            value="story_continuation",
            choices=[
                "story_continuation",
                "sentence_completion",
                "token_prediction",
            ],
        ),
        gr.Slider(
            label="Top-k (generation)",
            value=10,
            minimum=1,
            maximum=50,
            step=1,
        ),
        gr.Slider(
            label="Max generation length",
            value=50,
            minimum=10,
            maximum=200,
            step=1,
        ),
    ],
    outputs=[
        gr.Textbox(label="Generated text"),
        gr.Textbox(label="Explanation"),
        gr.Image(label="Layer importance heatmap", type="pil"),
    ],
)
if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0), not just localhost, on
    # port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")