"""
UNIFIED CONTINUAL LEARNER
Combines: EWC + Experience Replay + Progressive Networks
Goal: <20% forgetting across multiple tasks
"""
import torch
import torch.nn as nn
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))

from continual_learner import ContinualLearner, create_task_data
from experience_replay import ExperienceBuffer, ReplayContinualLearner
from progressive_nets import ProgressiveNeuralNetwork, train_progressive_network
import matplotlib.pyplot as plt
import numpy as np

class UnifiedLearner:
    """
    Unified continual learning system.

    Benchmarks several continual-learning strategies (plain baseline, EWC,
    EWC + experience replay, progressive networks) on the same sequence of
    synthetic tasks and reports average accuracy / forgetting metrics,
    plus a comparison figure saved to RESULTS_PATH.
    """

    # Single source of truth for the output figure location
    # (previously duplicated as two string literals in visualize_results).
    RESULTS_PATH = 'real_capabilities/learning/continual/benchmark_results.png'

    def __init__(self, input_size=10, hidden_sizes=None, output_size=2):
        """
        Args:
            input_size: dimensionality of each input sample.
            hidden_sizes: list of hidden-layer widths. Defaults to [64, 32].
                FIX: uses None as sentinel instead of a mutable default list,
                which would have been shared across all instances.
            output_size: number of output classes.
        """
        self.input_size = input_size
        self.hidden_sizes = [64, 32] if hidden_sizes is None else hidden_sizes
        self.output_size = output_size

        # Reserved for caching strategy instances (currently unused).
        self.strategies = {}

    def create_strategy(self, name: str):
        """Instantiate a learning strategy by name.

        Args:
            name: one of "baseline", "ewc", "ewc+replay", "progressive".

        Returns:
            A fresh model (or replay wrapper) implementing the strategy.

        Raises:
            ValueError: if `name` is not a known strategy.
        """
        if name == "baseline":
            # Regular network: same architecture, EWC penalty disabled.
            model = ContinualLearner(self.input_size, self.hidden_sizes, self.output_size)
            model.ewc_lambda = 0  # Disable EWC
            return model

        if name == "ewc":
            # EWC only
            model = ContinualLearner(self.input_size, self.hidden_sizes, self.output_size)
            model.ewc_lambda = 1000
            return model

        if name == "ewc+replay":
            # EWC + Experience Replay (wrapper mixes 30% replayed samples in)
            model = ContinualLearner(self.input_size, self.hidden_sizes, self.output_size)
            model.ewc_lambda = 1000
            return ReplayContinualLearner(model, replay_ratio=0.3)

        if name == "progressive":
            # Progressive Neural Networks: one frozen column per task
            return ProgressiveNeuralNetwork(self.input_size, self.hidden_sizes, self.output_size)

        raise ValueError(f"Unknown strategy: {name}")

    def benchmark_all_strategies(self, n_tasks=5, epochs=15):
        """
        Run every strategy on the same task sequence and visualize results.

        Args:
            n_tasks: number of sequential tasks to learn.
            epochs: training epochs per task.

        Returns:
            Dict mapping strategy name -> {'performance': ..., 'metrics': ...}.
        """
        print("\n" + "="*70)
        print("COMPREHENSIVE CONTINUAL LEARNING BENCHMARK")
        print(f"Tasks: {n_tasks} | Epochs per task: {epochs}")
        print("="*70)

        strategies = ["baseline", "ewc", "ewc+replay", "progressive"]
        results = {}

        for strategy_name in strategies:
            print(f"\n{'='*70}")
            print(f"TESTING STRATEGY: {strategy_name.upper()}")
            print(f"{'='*70}")

            results[strategy_name] = self.run_strategy(strategy_name, n_tasks, epochs)

        # Side effect: saves comparison figure to RESULTS_PATH.
        self.visualize_results(results, n_tasks)

        return results

    def run_strategy(self, strategy_name: str, n_tasks: int, epochs: int):
        """Run a single strategy through the full task sequence.

        Returns:
            {'performance': list of per-task accuracy dicts (one entry per
             task learned, each mapping task_id -> accuracy %),
             'metrics': summary dict from compute_metrics}.
        """
        model = self.create_strategy(strategy_name)

        test_loaders = {}
        task_performance = []

        for task_id in range(n_tasks):
            # Fresh train/test split for this task.
            X_train, y_train = create_task_data(task_id, n_samples=1000)
            X_test, y_test = create_task_data(task_id, n_samples=200)

            train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
            test_dataset = torch.utils.data.TensorDataset(X_test, y_test)

            train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
            test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32)

            # Keep test loaders around so we can re-evaluate old tasks later.
            test_loaders[task_id] = test_loader

            # Each strategy exposes a different training entry point.
            if strategy_name == "progressive":
                model.add_column()
                if task_id > 0:
                    # Earlier columns must stay fixed — that is the whole point
                    # of progressive networks.
                    model.freeze_previous_columns()
                train_progressive_network(model, train_loader, test_loaders, task_id, epochs)

            elif strategy_name == "ewc+replay":
                model.train_with_replay(train_loader, task_id, epochs)

            else:
                # Baseline or EWC
                if hasattr(model, 'train_task'):
                    # HACK: continual_learner.train_task reads a module-level
                    # test_loaders_history for its in-training evaluation, so
                    # we inject it here. TODO: pass it as an argument instead.
                    import continual_learner
                    continual_learner.test_loaders_history = test_loaders
                    model.train_task(train_loader, test_loader, task_id, epochs)

            # Re-evaluate on every task seen so far to measure forgetting.
            task_accuracies = self.evaluate_all_tasks(model, test_loaders, task_id, strategy_name)
            task_performance.append(task_accuracies)

        final_metrics = self.compute_metrics(task_performance)

        print(f"\n{'='*70}")
        print(f"FINAL RESULTS - {strategy_name.upper()}")
        print(f"{'='*70}")
        print(f"Average Accuracy: {final_metrics['avg_accuracy']:.2f}%")
        print(f"Average Forgetting: {final_metrics['avg_forgetting']:.2f}%")
        print(f"Final Task Accuracy: {final_metrics['final_task_acc']:.2f}%")
        print(f"{'='*70}\n")

        return {
            'performance': task_performance,
            'metrics': final_metrics
        }

    def evaluate_all_tasks(self, model, test_loaders, current_task, strategy_name):
        """Evaluate the model on all tasks learned so far.

        Returns:
            Dict mapping task_id -> accuracy (%) for tasks 0..current_task.
        """
        # Unwrap ReplayContinualLearner, which holds the real net in .model.
        actual_model = model.model if hasattr(model, 'model') else model

        actual_model.eval()
        task_accuracies = {}

        with torch.no_grad():
            for task_id in range(current_task + 1):
                if task_id not in test_loaders:
                    continue
                correct = 0
                total = 0

                for inputs, targets in test_loaders[task_id]:
                    if strategy_name == "progressive":
                        # Progressive nets route through the column owned
                        # by this task.
                        outputs = actual_model(inputs, task_id=task_id)
                    else:
                        outputs = actual_model(inputs)

                    _, predicted = outputs.max(1)
                    total += targets.size(0)
                    correct += predicted.eq(targets).sum().item()

                # FIX: guard against an empty test loader (avoids
                # ZeroDivisionError); such a task is simply skipped.
                if total:
                    task_accuracies[task_id] = 100. * correct / total

        return task_accuracies

    def compute_metrics(self, task_performance):
        """Compute accuracy and forgetting summaries.

        Args:
            task_performance: list where entry t is the dict of per-task
                accuracies measured right after learning task t.

        Returns:
            {'avg_accuracy': mean final accuracy over all tasks,
             'avg_forgetting': mean (peak - final) accuracy drop over
                 all non-final tasks (0 if only one task),
             'final_task_acc': final accuracy on the last task}.
        """
        if len(task_performance) == 0:
            return {'avg_accuracy': 0, 'avg_forgetting': 0, 'final_task_acc': 0}

        # Accuracies measured after the last task was learned.
        final_accuracies = task_performance[-1]
        avg_accuracy = sum(final_accuracies.values()) / len(final_accuracies)

        # Forgetting = accuracy right after learning a task minus accuracy
        # at the end of the whole sequence.
        forgetting_list = []
        for task_id in range(len(task_performance) - 1):
            initial_acc = task_performance[task_id][task_id]
            final_acc = final_accuracies[task_id]
            forgetting_list.append(initial_acc - final_acc)

        avg_forgetting = sum(forgetting_list) / len(forgetting_list) if forgetting_list else 0

        # How well the model learned the most recent task.
        final_task_acc = final_accuracies[len(task_performance) - 1]

        return {
            'avg_accuracy': avg_accuracy,
            'avg_forgetting': avg_forgetting,
            'final_task_acc': final_task_acc
        }

    def visualize_results(self, results, n_tasks):
        """Create and save a 2x2 figure comparing all strategies.

        Returns:
            The matplotlib Figure (also saved to RESULTS_PATH as a side
            effect).
        """
        strategies = list(results.keys())

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Continual Learning: Strategy Comparison', fontsize=16)

        # Plot 1: Average Accuracy
        ax1 = axes[0, 0]
        accuracies = [results[s]['metrics']['avg_accuracy'] for s in strategies]
        bars = ax1.bar(strategies, accuracies, color=['red', 'orange', 'blue', 'green'])
        ax1.set_ylabel('Average Accuracy (%)')
        ax1.set_title('Average Accuracy Across All Tasks')
        ax1.axhline(y=80, color='gray', linestyle='--', alpha=0.5)
        for bar in bars:
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.1f}%', ha='center', va='bottom')

        # Plot 2: Forgetting
        ax2 = axes[0, 1]
        forgetting = [results[s]['metrics']['avg_forgetting'] for s in strategies]
        bars = ax2.bar(strategies, forgetting, color=['red', 'orange', 'blue', 'green'])
        ax2.set_ylabel('Average Forgetting (%)')
        ax2.set_title('Average Forgetting (Lower is Better)')
        ax2.axhline(y=20, color='green', linestyle='--', alpha=0.5, label='Goal: <20%')
        ax2.legend()
        for bar in bars:
            height = bar.get_height()
            ax2.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.1f}%', ha='center', va='bottom')

        # Plot 3: Final per-task accuracy for every strategy
        ax3 = axes[1, 0]
        for strategy in strategies:
            final_perf = results[strategy]['performance'][-1]
            tasks = sorted(final_perf.keys())
            accs = [final_perf[t] for t in tasks]
            ax3.plot(tasks, accs, marker='o', label=strategy)
        ax3.set_xlabel('Task ID')
        ax3.set_ylabel('Accuracy (%)')
        ax3.set_title('Final Accuracy on Each Task')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Plot 4: Summary table
        ax4 = axes[1, 1]
        ax4.axis('tight')
        ax4.axis('off')

        table_data = [['Strategy', 'Avg Acc', 'Forgetting', 'Goal']]
        for strategy in strategies:
            metrics = results[strategy]['metrics']
            goal_met = '✅' if metrics['avg_forgetting'] < 20 else '❌'
            table_data.append([
                strategy,
                f"{metrics['avg_accuracy']:.1f}%",
                f"{metrics['avg_forgetting']:.1f}%",
                goal_met
            ])

        table = ax4.table(cellText=table_data, cellLoc='center', loc='center',
                         colWidths=[0.3, 0.2, 0.2, 0.1])
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 2)

        # Style header row
        for i in range(4):
            table[(0, i)].set_facecolor('#40466e')
            table[(0, i)].set_text_props(weight='bold', color='white')

        plt.tight_layout()
        # FIX: create the output directory first; savefig raises
        # FileNotFoundError if it does not exist.
        out_path = Path(self.RESULTS_PATH)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(str(out_path), dpi=150)
        print(f"\n📊 Results saved to: {self.RESULTS_PATH}")

        return fig

if __name__ == "__main__":
    print("\n" + "="*70)
    print("UNIFIED CONTINUAL LEARNING SYSTEM")
    print("Testing: Baseline | EWC | EWC+Replay | Progressive Networks")
    print("="*70)
    
    learner = UnifiedLearner()
    results = learner.benchmark_all_strategies(n_tasks=5, epochs=15)
    
    print("\n" + "="*70)
    print("BENCHMARK COMPLETE!")
    print("="*70)
    print("\n🎯 GOAL: <20% Forgetting")
    
    for strategy, data in results.items():
        forgetting = data['metrics']['avg_forgetting']
        goal_met = "✅ ACHIEVED" if forgetting < 20 else "❌ NOT MET"
        print(f"{strategy:15} - Forgetting: {forgetting:5.2f}% - {goal_met}")
    
    print("\n✅ Week 2 Complete! Real continual learning system operational.")
