# model.py
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import numpy as np


class ModelWrapper:
    def __init__(self, model_name="gpt2", device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(
            model_name, output_hidden_states=True
        ).to(self.device)
        self.model.eval()

    # -----------------------------------------------------
    # TEXT GENERATION
    # -----------------------------------------------------
    def generate_text(self, prompt, max_length=50, top_k=10):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                # Generate up to `max_length` new tokens beyond the prompt
                max_length=len(inputs["input_ids"][0]) + max_length,
                do_sample=True,
                top_k=top_k,
                pad_token_id=self.tokenizer.eos_token_id,  # GPT-2 has no pad token
            )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)

    # -----------------------------------------------------
    # HIDDEN STATES
    # -----------------------------------------------------
    def _get_hidden_states(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs)
        # Tuple of (n_layers + 1) tensors: the embedding output plus one per block
        return out.hidden_states

    # -----------------------------------------------------
    # ACTIVATION PATCHING (LAYER IMPORTANCE)
    # -----------------------------------------------------
    def layer_importance(self, prompt, experiment_type="story_continuation"):
        """
        Computes a simple proxy for activation patching.

        For each transformer block:
          - Run GPT-2 normally.
          - Run GPT-2 again with that block's hidden-state output zeroed.
          - Compute the L1 difference in next-token logits.

        Returns a list of importance scores normalized to 0-1. The
        `experiment_type` argument is kept for API compatibility but is not
        used in the computation.
        """
        # Tokenize input
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # Baseline forward pass
        with torch.no_grad():
            out = self.model(**inputs, output_hidden_states=True)
        baseline_logits = out.logits[0, -1, :].cpu().numpy()

        # Number of transformer blocks (12 for gpt2-base)
        n_layers = len(self.model.transformer.h)
        scores = []

        for layer_idx in range(n_layers):
            def hook(module, inp, outp):
                """
                A GPT-2 block returns a tuple whose first element is the
                hidden states (the remaining elements hold the KV cache and,
                optionally, attention weights). Zero the hidden states and
                pass the rest of the tuple through unchanged.
                """
                return (torch.zeros_like(outp[0]),) + outp[1:]

            # Patch only this block
            handle = self.model.transformer.h[layer_idx].register_forward_hook(hook)

            # Patched forward pass
            with torch.no_grad():
                out2 = self.model(**inputs)
            logits2 = out2.logits[0, -1, :].cpu().numpy()

            # L1 difference between baseline and patched next-token logits
            diff = np.sum(np.abs(baseline_logits - logits2))
            scores.append(float(diff))

            handle.remove()

        # Min-max normalize to 0-1; guard against a constant score vector
        arr = np.array(scores)
        if arr.max() > arr.min():
            arr = (arr - arr.min()) / (arr.max() - arr.min())
        else:
            arr = np.zeros_like(arr)
        return arr.tolist()
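

# ---------------------------------------------------------
# USAGE SKETCH (illustrative only, not part of the original API)
# ---------------------------------------------------------
# A minimal example of how the wrapper might be exercised; the prompt text
# and generation settings below are assumptions chosen for demonstration.
if __name__ == "__main__":
    wrapper = ModelWrapper("gpt2")

    # Sampled continuation of a short prompt
    print(wrapper.generate_text("Once upon a time", max_length=30, top_k=10))

    # Per-layer importance scores (normalized 0-1), one per transformer block
    scores = wrapper.layer_importance("Once upon a time")
    for i, s in enumerate(scores):
        print(f"layer {i:2d}: {s:.3f}")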