"""
Experience Replay - Store and replay old examples to prevent forgetting
This is how DeepMind's DQN learned to play Atari games
"""
import torch
from collections import deque
import random
from typing import Optional

class ExperienceBuffer:
    """Store experiences from all tasks and sample them during training"""
    
    def __init__(self, max_size: int = 10000):
        self.max_size = max_size  # kept so per-task buffers can size themselves from it
        self.buffer = deque(maxlen=max_size)
        self.task_buffers = {}  # Separate buffer per task
        
    def add(self, task_id: int, inputs: torch.Tensor, targets: torch.Tensor):
        """Add experience to buffer"""
        for inp, targ in zip(inputs, targets):
            self.buffer.append({
                'task_id': task_id,
                'input': inp.clone(),
                'target': targ.clone()
            })
            
            # Also store in task-specific buffer
            if task_id not in self.task_buffers:
                self.task_buffers[task_id] = deque(maxlen=self.max_size // 10)  # each task gets a tenth of the global capacity
            self.task_buffers[task_id].append({
                'input': inp.clone(),
                'target': targ.clone()
            })
    
    def sample(self, batch_size: int, task_id: Optional[int] = None):
        """Sample a batch of experiences, optionally restricted to one task"""
        if task_id is not None and task_id in self.task_buffers:
            # Sample from specific task
            buffer = self.task_buffers[task_id]
        else:
            # Sample from all tasks
            buffer = self.buffer
        
        if len(buffer) == 0:
            raise ValueError("Cannot sample from an empty replay buffer")
        batch_size = min(batch_size, len(buffer))  # clamp to what's available
        
        experiences = random.sample(list(buffer), batch_size)
        
        inputs = torch.stack([exp['input'] for exp in experiences])
        targets = torch.stack([exp['target'] for exp in experiences])
        
        return inputs, targets
    
    def get_task_data(self, task_id: int):
        """Get all data for a specific task"""
        if task_id not in self.task_buffers:
            return None, None
        
        buffer = self.task_buffers[task_id]
        inputs = torch.stack([exp['input'] for exp in buffer])
        targets = torch.stack([exp['target'] for exp in buffer])
        
        return inputs, targets
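
# Example usage of ExperienceBuffer (a sketch; shapes and sizes are illustrative):
#
#   buffer = ExperienceBuffer(max_size=10000)
#   x = torch.randn(32, 784)                    # one batch of inputs
#   y = torch.randint(0, 10, (32,))             # matching class labels
#   buffer.add(task_id=0, inputs=x, targets=y)
#   mixed_x, mixed_y = buffer.sample(batch_size=16)           # across all tasks
#   task_x, task_y = buffer.sample(batch_size=16, task_id=0)  # task 0 only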

class ReplayContinualLearner:
    """
    Continual learner with experience replay
    Combines EWC + Replay for best results
    """
    
    def __init__(self, model, replay_ratio: float = 0.3):
        self.model = model
        self.replay_buffer = ExperienceBuffer()
        # Fraction of each batch size drawn as *extra* replay examples:
        # 0.3 appends ~30% more examples on top of every new batch
        self.replay_ratio = replay_ratio
        
    def train_with_replay(self, train_loader, task_id: int, epochs: int = 20):
        """Train with replay of old experiences"""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        
        print(f"\n{'='*60}")
        print(f"Training Task {task_id} WITH Experience Replay")
        print(f"Replay Ratio: {self.replay_ratio:.0%}")
        print(f"{'='*60}")
        
        for epoch in range(epochs):
            self.model.train()
            total_loss = 0
            correct = 0
            total = 0
            
            for inputs, targets in train_loader:
                optimizer.zero_grad()
                
                # Add current data to replay buffer
                self.replay_buffer.add(task_id, inputs, targets)
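                # NOTE: this runs every epoch, so the same examples are re-added
                # on each pass; the deque's maxlen bounds the duplication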
                
                # Mix current batch with replayed examples (sampled across all
                # tasks, including current-task data already in the buffer)
                if task_id > 0 and len(self.replay_buffer.buffer) > 0:
                    replay_size = int(len(inputs) * self.replay_ratio)
                    if replay_size > 0:
                        replay_inputs, replay_targets = self.replay_buffer.sample(replay_size)
                        
                        # Combine current + replay
                        inputs = torch.cat([inputs, replay_inputs])
                        targets = torch.cat([targets, replay_targets])
                
                # Forward pass
                outputs = self.model(inputs)
                
                # Task loss
                task_loss = torch.nn.functional.cross_entropy(outputs, targets)
                
                # EWC penalty
                ewc_penalty = self.model.ewc_loss() if self.model.tasks_learned > 0 else 0
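                # (assuming ewc_loss implements the standard quadratic penalty
                # (λ/2) · Σ_i F_i (θ_i − θ*_i)² from Kirkpatrick et al., 2017)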
                
                # Total loss
                loss = task_loss + ewc_penalty
                
                # Backward pass
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
            
            accuracy = 100. * correct / total
            avg_loss = total_loss / len(train_loader)
            
            if (epoch + 1) % 5 == 0:
                print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f} - Acc: {accuracy:.2f}%")
        
        # Compute Fisher information after training
        self.model.compute_fisher_information(train_loader, task_id)
        self.model.save_optimal_params(task_id)
        self.model.tasks_learned += 1
        
        print(f"✅ Task {task_id} training complete")
        print(f"   Buffer size: {len(self.replay_buffer.buffer)} experiences")
