# model.py
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import numpy as np


class ModelWrapper:
    def __init__(self, model_name="gpt2", device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(
            model_name, output_hidden_states=True
        ).to(self.device)
        self.model.eval()
    # -----------------------------------------------------
    # TEXT GENERATION
    # -----------------------------------------------------
    def generate_text(self, prompt, max_length=50, top_k=10):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                max_length=len(inputs["input_ids"][0]) + max_length,
                do_sample=True,
                top_k=top_k,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)
    # -----------------------------------------------------
    # HIDDEN STATES
    # -----------------------------------------------------
    def _get_hidden_states(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs)
        return out.hidden_states
    # -----------------------------------------------------
    # ACTIVATION PATCHING (LAYER IMPORTANCE)
    # -----------------------------------------------------
    def layer_importance(self, prompt, experiment_type="story_continuation"):
        """
        Computes a simple proxy for activation patching.

        For each transformer block:
          - run GPT-2 normally,
          - run GPT-2 again with that block's hidden-state output zeroed,
          - compute the L1 difference in next-token logits.

        Returns a list of importance scores normalized to the range 0-1.
        (`experiment_type` is kept for API compatibility and is currently unused.)
        """
        # Tokenize input
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # Baseline forward pass
        with torch.no_grad():
            out = self.model(**inputs, output_hidden_states=True)
        baseline_logits = out.logits[0, -1, :].cpu().numpy()

        # Number of transformer blocks (12 for gpt2-base)
        n_layers = len(self.model.transformer.h)
        scores = []

        for layer_idx in range(n_layers):
            def hook(module, inp, outp):
                """
                A GPT-2 block returns a tuple whose first element is the
                hidden states, e.g. (hidden_states, present, ...).
                Zero the hidden states but keep the tuple structure intact.
                """
                hidden = outp[0]
                return (torch.zeros_like(hidden),) + outp[1:]

            # Register hook on the current block
            handle = self.model.transformer.h[layer_idx].register_forward_hook(hook)

            # Patched forward pass (always remove the hook, even on error)
            try:
                with torch.no_grad():
                    out2 = self.model(**inputs)
                logits2 = out2.logits[0, -1, :].cpu().numpy()
            finally:
                handle.remove()

            # L1 difference between baseline and patched next-token logits
            diff = np.sum(np.abs(baseline_logits - logits2))
            scores.append(float(diff))

        # Normalize to 0-1 (guard against identical scores to avoid division by zero)
        arr = np.array(scores)
        if arr.max() > arr.min():
            arr = (arr - arr.min()) / (arr.max() - arr.min())
        else:
            arr = arr * 0.0
        return arr.tolist()
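

# -----------------------------------------------------
# EXAMPLE USAGE (illustrative sketch)
# -----------------------------------------------------
# A minimal, hypothetical driver showing how ModelWrapper might be called
# directly; it is not part of the original module, and the prompt and
# parameter values below are placeholders.
if __name__ == "__main__":
    wrapper = ModelWrapper("gpt2")

    prompt = "Once upon a time"
    print(wrapper.generate_text(prompt, max_length=30, top_k=10))

    # Per-layer importance scores (normalized 0-1), one per GPT-2 block.
    scores = wrapper.layer_importance(prompt)
    for i, s in enumerate(scores):
        print(f"layer {i:2d}: {s:.3f}")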