File size: 1,223 Bytes
7796c8a
 
 
 
69d24e3
 
 
7796c8a
69d24e3
 
 
 
 
7796c8a
 
69d24e3
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# agents.py
from typing import List

class ExperimentAgent:
    """Run one experiment against a model and record it through a database object.

    Both collaborators are duck-typed: ``model`` must provide
    ``generate_text(prompt)`` and ``layer_importance(prompt, experiment_type=...)``,
    and ``db`` must provide ``save_experiment(prompt, generated, layer_scores)``.
    """

    def __init__(self, model, db):
        """Keep references to the model and database collaborators."""
        self.db = db
        self.model = model

    def run(self, prompt, experiment_type="story_continuation"):
        """Generate text for ``prompt``, score the layers, and save the result.

        Args:
            prompt: Passed unchanged to both model calls and to the database.
            experiment_type: Forwarded as a keyword argument to
                ``model.layer_importance``.

        Returns:
            A ``(experiment_id, generated_text, layer_scores)`` tuple, where
            ``experiment_id`` is whatever ``db.save_experiment`` returned.
        """
        model = self.model
        # Generation runs first, then layer scoring; the order is preserved in
        # case the model object keeps state between the two calls.
        text = model.generate_text(prompt)
        scores = model.layer_importance(prompt, experiment_type=experiment_type)
        record_id = self.db.save_experiment(prompt, text, scores)
        return record_id, text, scores

class ExplanationAgent:
    """Turn per-layer importance scores into a short natural-language summary."""

    def __init__(self):
        pass

    def explain_layer_importance(self, layer_scores: List[float], top_k: int = 3) -> str:
        """Summarize which layers have the highest importance scores.

        This is a simple heuristic: it reports the indices of the ``top_k``
        highest-scoring layers, highest first, followed by a fixed explanation.

        Args:
            layer_scores: One score per layer; the position in the sequence is
                used as the layer index. Nested input is flattened. ``None``
                or an empty sequence yields the "no scores" message.
            top_k: Maximum number of layers to report (default 3). Fewer are
                reported when fewer valid scores exist.

        Returns:
            The summary sentence, or ``"No layer scores available."`` when
            there is no non-NaN score to rank.

        Raises:
            ValueError: If ``top_k`` is smaller than 1.
        """
        import numpy as np

        if top_k < 1:
            raise ValueError("top_k must be at least 1")
        if layer_scores is None:
            return "No layer scores available."

        scores = np.asarray(layer_scores, dtype=float).ravel()
        # argsort places NaN last, so NaN entries would otherwise be reported
        # as the *most* important layers; rank only the real scores instead.
        valid_idx = np.flatnonzero(~np.isnan(scores))
        if valid_idx.size == 0:
            return "No layer scores available."

        # A stable sort keeps the result deterministic when scores tie (after
        # the reversal below, the higher layer index is listed first).
        ascending = np.argsort(scores[valid_idx], kind="stable")
        top_idx = valid_idx[ascending[-top_k:][::-1]]
        top_layers = ", ".join(str(int(i)) for i in top_idx)
        summary = f"Top influencing layers (proxy): {top_layers}. Layers with higher scores changed the model's next-token logits the most when ablated. This suggests they strongly affect immediate generation behavior for the provided prompt."
        return summary