File size: 3,660 Bytes
78eadb7
e365038
 
 
 
 
 
 
 
7882888
 
 
e365038
 
7882888
 
 
e365038
 
7882888
e365038
7882888
 
 
 
 
 
 
 
e365038
 
7882888
 
 
e365038
 
 
 
 
 
7882888
 
 
e365038
 
7882888
 
 
 
 
 
e365038
7882888
 
e365038
7882888
 
e365038
 
7882888
e365038
 
7882888
 
e365038
 
 
7882888
 
 
 
e365038
7882888
 
 
 
 
 
 
 
e365038
7882888
e365038
7882888
 
e365038
 
7882888
e365038
7882888
 
e365038
 
7882888
e365038
 
7882888
e365038
 
 
 
 
 
7882888
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# model.py
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import numpy as np

class ModelWrapper:
    """Thin wrapper around GPT-2 for sampling text and zero-ablation layer analysis."""

    def __init__(self, model_name="gpt2", device=None):
        """Load tokenizer and model in eval mode.

        Args:
            model_name: Hugging Face model identifier (default "gpt2").
            device: explicit torch device string; autodetects CUDA when None.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
        # output_hidden_states=True so forward passes expose per-layer activations
        # (used by _get_hidden_states).
        self.model = GPT2LMHeadModel.from_pretrained(
            model_name, output_hidden_states=True
        ).to(self.device)
        self.model.eval()

    # -----------------------------------------------------
    # TEXT GENERATION
    # -----------------------------------------------------
    def generate_text(self, prompt, max_length=50, top_k=10):
        """Sample a continuation of ``prompt`` with top-k sampling.

        Args:
            prompt: input text.
            max_length: number of NEW tokens to sample beyond the prompt.
            top_k: top-k sampling cutoff.

        Returns:
            Prompt plus continuation as one decoded string (special tokens stripped).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                # generate()'s max_length counts the prompt, so extend it by
                # the prompt's token length to get `max_length` new tokens.
                max_length=inputs["input_ids"].shape[1] + max_length,
                do_sample=True,
                top_k=top_k,
                # GPT-2 has no pad token; reuse EOS to silence the warning.
                pad_token_id=self.tokenizer.eos_token_id,
            )

        return self.tokenizer.decode(output[0], skip_special_tokens=True)

    # -----------------------------------------------------
    # HIDDEN STATES
    # -----------------------------------------------------
    def _get_hidden_states(self, prompt):
        """Return the model's hidden-states tuple for ``prompt``.

        The tuple contains the embedding output plus one tensor per block
        (available because the model was loaded with output_hidden_states=True).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs)
        return out.hidden_states

    # -----------------------------------------------------
    # ACTIVATION PATCHING (LAYER IMPORTANCE)
    # -----------------------------------------------------
    def layer_importance(self, prompt, experiment_type="story_continuation"):
        """
        Zero-ablation proxy for activation patching.

        For each transformer block:
            - run GPT-2 normally,
            - run it again with that block's hidden-state output zeroed,
            - score the L1 difference in next-token logits.

        Args:
            prompt: input text to analyze.
            experiment_type: tag kept for callers; currently unused by the
                computation itself.

        Returns:
            Per-layer importance scores min-max normalized to [0, 1] as a
            list of floats (all zeros when every layer scores identically).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # Baseline next-token logits (last position of the sequence).
        with torch.no_grad():
            out = self.model(**inputs, output_hidden_states=True)
        baseline_logits = out.logits[0, -1, :].cpu().numpy()

        n_layers = len(self.model.transformer.h)
        scores = []

        def hook(module, inp, outp):
            """Zero the block's hidden states while preserving tuple structure.

            GPT-2 blocks return (hidden_states, present[, attentions, ...]);
            only element 0 is replaced so cache/attention outputs pass through
            untouched regardless of the tuple's length.
            """
            return (torch.zeros_like(outp[0]),) + tuple(outp[1:])

        for layer_idx in range(n_layers):
            handle = self.model.transformer.h[layer_idx].register_forward_hook(hook)
            try:
                # Patched forward pass with this layer's output ablated.
                with torch.no_grad():
                    out2 = self.model(**inputs)
            finally:
                # Always detach the hook — a leaked hook would corrupt every
                # subsequent forward pass.
                handle.remove()

            logits2 = out2.logits[0, -1, :].cpu().numpy()
            scores.append(float(np.sum(np.abs(baseline_logits - logits2))))

        # Min-max normalize to [0, 1]. Guard on the spread (not just max > 0):
        # a constant positive score vector would otherwise divide by zero and
        # return NaNs.
        arr = np.array(scores, dtype=float)
        spread = (arr.max() - arr.min()) if arr.size else 0.0
        if spread > 0:
            arr = (arr - arr.min()) / spread
        else:
            arr = arr * 0.0

        return arr.tolist()