""" DistilBERT-based student agent with online learning and memory decay. Uses DistilBERT for Multiple Choice to answer reading comprehension tasks. Implements online learning (fine-tune on 1 example at a time). """ import torch from torch.optim import AdamW from transformers import ( DistilBertForMultipleChoice, DistilBertTokenizer, ) from typing import List, Dict import numpy as np from collections import defaultdict from interfaces import StudentAgentInterface, StudentState, Task from memory_decay import MemoryDecayModel class StudentAgent(StudentAgentInterface): """ DistilBERT-based student that learns reading comprehension. Features: - Online learning (1 example at a time) - Memory decay (Ebbinghaus forgetting) - Per-topic skill tracking - Gradient accumulation for stability """ def __init__( self, learning_rate: float = 5e-5, retention_constant: float = 80.0, device: str = 'cpu', max_length: int = 256, gradient_accumulation_steps: int = 4 ): """ Args: learning_rate: LM fine-tuning learning rate retention_constant: Forgetting speed (higher = slower forgetting) device: 'cpu' or 'cuda' max_length: Max tokens for passage + question + choices gradient_accumulation_steps: Accumulate gradients for stability """ self.device = device self.max_length = max_length self.gradient_accumulation_steps = gradient_accumulation_steps # Load DistilBERT for multiple choice # Allow silent mode for testing verbose = True # Can be overridden try: if verbose: print("Loading DistilBERT model...", end=" ", flush=True) self.model = DistilBertForMultipleChoice.from_pretrained( "distilbert-base-uncased" ).to(self.device) self.tokenizer = DistilBertTokenizer.from_pretrained( "distilbert-base-uncased" ) if verbose: print("✅") except Exception as e: if verbose: print(f"⚠️ (Model unavailable, using dummy mode)") self.model = None self.tokenizer = None # Optimizer if self.model: self.optimizer = AdamW(self.model.parameters(), lr=learning_rate) else: self.optimizer = None # Memory decay model self.memory = MemoryDecayModel(retention_constant=retention_constant) # Track per-topic base skills (before forgetting) self.topic_base_skills: Dict[str, float] = {} # Track learning history self.topic_attempts: Dict[str, int] = defaultdict(int) self.topic_correct: Dict[str, int] = defaultdict(int) # Gradient accumulation counter self.grad_step = 0 # Training mode flag if self.model: self.model.train() def answer(self, task: Task) -> int: """ Predict answer without updating weights. Prediction accuracy is modulated by memory decay. """ if not self.model: # Dummy model: random guessing return np.random.randint(0, 4) self.model.eval() # Prepare inputs inputs = self._prepare_inputs(task) with torch.no_grad(): outputs = self.model(**inputs) logits = outputs.logits predicted_idx = torch.argmax(logits, dim=-1).item() # Apply memory decay to prediction # If student has forgotten, prediction becomes more random effective_skill = self.memory.get_effective_skill(task.topic) # Probability of using learned answer vs random guess # MCQ baseline = 0.25 (random guessing) use_learned_prob = 0.25 + 0.75 * effective_skill if np.random.random() < use_learned_prob: return predicted_idx else: # Random guess return np.random.randint(0, 4) def learn(self, task: Task) -> bool: """ Fine-tune on a single task (online learning). 

    def learn(self, task: Task) -> bool:
        """
        Fine-tune on a single task (online learning).

        Returns:
            True if the prediction was correct, False otherwise
        """
        if not self.model:
            # Dummy learning: track statistics only
            predicted = np.random.randint(0, 4)
            was_correct = (predicted == task.answer)
            self._update_stats(task, was_correct)
            return was_correct

        # Get the prediction before learning. answer() switches the model
        # to eval mode, so switch back to train mode before fine-tuning.
        predicted = self.answer(task)
        was_correct = (predicted == task.answer)
        self.model.train()

        # Prepare inputs with the correct answer as the label
        inputs = self._prepare_inputs(task)
        inputs['labels'] = torch.tensor([task.answer], device=self.device)

        # Forward pass
        outputs = self.model(**inputs)
        loss = outputs.loss

        # Backward pass with gradient accumulation
        loss = loss / self.gradient_accumulation_steps
        loss.backward()
        self.grad_step += 1

        # Update weights every N steps
        if self.grad_step % self.gradient_accumulation_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()

        # Update statistics
        self._update_stats(task, was_correct)

        return was_correct

    def _update_stats(self, task: Task, was_correct: bool):
        """Update topic statistics and memory."""
        self.topic_attempts[task.topic] += 1
        if was_correct:
            self.topic_correct[task.topic] += 1

        # Compute base skill (accuracy without forgetting)
        base_skill = self.topic_correct[task.topic] / self.topic_attempts[task.topic]
        self.topic_base_skills[task.topic] = base_skill

        # Update memory (record practice)
        self.memory.update_practice(task.topic, base_skill)

    def evaluate(self, eval_tasks: List[Task]) -> float:
        """
        Evaluate on held-out tasks without updating weights.

        Returns:
            Accuracy (0.0-1.0)
        """
        if not eval_tasks:
            return 0.0

        if not self.model:
            # Dummy evaluation: chance-level accuracy
            return 0.25

        self.model.eval()

        correct = 0
        for task in eval_tasks:
            predicted = self.answer(task)
            if predicted == task.answer:
                correct += 1

        return correct / len(eval_tasks)

    def get_state(self) -> StudentState:
        """
        Get the current state for teacher observation.

        Returns per-topic accuracies accounting for forgetting.
        """
        topic_accuracies = {}
        time_since_practice = {}

        for topic in self.topic_base_skills:
            # Get effective skill (with forgetting)
            effective_skill = self.memory.get_effective_skill(topic)

            # Convert to expected accuracy on MCQ
            topic_accuracies[topic] = 0.25 + 0.75 * effective_skill

            # Time since last practice
            time_since_practice[topic] = self.memory.get_time_since_practice(topic)

        return StudentState(
            topic_accuracies=topic_accuracies,
            topic_attempts=dict(self.topic_attempts),
            time_since_practice=time_since_practice,
            total_timesteps=sum(self.topic_attempts.values()),
            current_time=self.memory.current_time
        )
""" if not self.tokenizer: return {} # Create 4 input sequences (one per choice) input_texts = [] for choice in task.choices: # Format: passage + question + choice text = f"{task.passage} {task.question} {choice}" input_texts.append(text) # Tokenize encoded = self.tokenizer( input_texts, padding=True, truncation=True, max_length=self.max_length, return_tensors='pt' ) # Reshape for multiple choice format # (batch_size=1, num_choices=4, seq_length) input_ids = encoded['input_ids'].unsqueeze(0).to(self.device) attention_mask = encoded['attention_mask'].unsqueeze(0).to(self.device) return { 'input_ids': input_ids, 'attention_mask': attention_mask } def advance_time(self, delta: float = 1.0): """Advance time for memory decay.""" self.memory.advance_time(delta) def save(self, path: str): """Save model checkpoint.""" if not self.model: print("⚠️ No model to save (using dummy model)") return torch.save({ 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict() if self.optimizer else None, 'topic_base_skills': self.topic_base_skills, 'topic_attempts': dict(self.topic_attempts), 'topic_correct': dict(self.topic_correct), 'memory': self.memory, 'grad_step': self.grad_step }, path) print(f"💾 Saved checkpoint to {path}") def load(self, path: str): """Load model checkpoint.""" checkpoint = torch.load(path, map_location=self.device) if self.model: self.model.load_state_dict(checkpoint['model_state_dict']) if self.optimizer and checkpoint.get('optimizer_state_dict'): self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) self.topic_base_skills = checkpoint['topic_base_skills'] self.topic_attempts = defaultdict(int, checkpoint['topic_attempts']) self.topic_correct = defaultdict(int, checkpoint['topic_correct']) self.memory = checkpoint['memory'] self.grad_step = checkpoint.get('grad_step', 0) print(f"✅ Loaded checkpoint from {path}")