"""
MAML Meta-Training - Learn the optimal initialization
"""
import torch
import torch.nn as nn
import numpy as np
from maml import MAML, MAMLModel, create_classification_task
import matplotlib.pyplot as plt

class MAMLTrainer:
    """Train MAML on a distribution of synthetic classification tasks.

    Wraps a ``MAMLModel``/``MAML`` pair and provides:
      * meta-training over randomly generated task batches,
      * evaluation on held-out (unseen) tasks,
      * a comparison against a randomly initialized baseline, and
      * visualization of the meta-training curves.
    """

    def __init__(
        self,
        input_size=10,
        hidden_sizes=None,
        output_size=2,
        inner_lr=0.01,
        outer_lr=0.001,
        inner_steps=5
    ):
        """
        Args:
            input_size: Dimensionality of task inputs.
            hidden_sizes: Hidden-layer widths; defaults to [40, 40].
                (A ``None`` sentinel is used instead of a literal list to
                avoid the mutable-default-argument pitfall.)
            output_size: Number of output classes.
            inner_lr: Learning rate for task-level (inner-loop) adaptation.
            outer_lr: Learning rate for the meta (outer-loop) update.
            inner_steps: Gradient steps taken in the inner loop.
        """
        if hidden_sizes is None:
            hidden_sizes = [40, 40]

        # Remember the configuration so compare_with_baseline() can build
        # baseline models with the *same* architecture/hyperparameters
        # instead of hard-coding them.
        self.input_size = input_size
        self.hidden_sizes = list(hidden_sizes)
        self.output_size = output_size
        self.inner_lr = inner_lr
        self.inner_steps = inner_steps

        self.model = MAMLModel(input_size, hidden_sizes, output_size)
        self.maml = MAML(
            self.model,
            inner_lr=inner_lr,
            outer_lr=outer_lr,
            inner_steps=inner_steps
        )

        # One dict per meta-iteration: {'iteration', 'loss', 'accuracy'}.
        self.train_history = []

    def generate_task_batch(self, n_tasks, task_offset=0):
        """Generate a batch of tasks for meta-training.

        Args:
            n_tasks: Number of tasks in the batch.
            task_offset: Starting task id; ids ``task_offset .. task_offset
                + n_tasks - 1`` are used so successive batches see new tasks.

        Returns:
            List of ``(support_x, support_y, query_x, query_y)`` tuples.
        """
        return [
            create_classification_task(task_offset + i, n_samples=40)
            for i in range(n_tasks)
        ]

    def meta_train(self, n_iterations=1000, tasks_per_batch=4, eval_every=100):
        """Meta-train MAML on a distribution of tasks.

        Args:
            n_iterations: Number of outer-loop (meta) updates.
            tasks_per_batch: Tasks sampled per meta-update.
            eval_every: Print progress and run a held-out evaluation
                every this many iterations.
        """
        print("\n" + "="*70)
        print("META-TRAINING MAML")
        print(f"Iterations: {n_iterations} | Tasks per batch: {tasks_per_batch}")
        print("="*70)

        for iteration in range(n_iterations):
            # Advance the task-id offset each iteration so every batch
            # consists of previously unseen tasks.
            task_offset = iteration * tasks_per_batch
            task_batch = self.generate_task_batch(tasks_per_batch, task_offset)

            # One outer-loop (meta) update over the batch.
            loss, acc = self.maml.meta_train_step(task_batch)

            self.train_history.append({
                'iteration': iteration,
                'loss': loss,
                'accuracy': acc
            })

            # Periodic progress report + held-out evaluation.
            if (iteration + 1) % eval_every == 0:
                print(f"Iteration {iteration+1}/{n_iterations} - "
                      f"Loss: {loss:.4f} - Acc: {acc*100:.2f}%")

                eval_acc = self.evaluate_on_new_tasks(n_tasks=20)
                print(f"  Eval on new tasks: {eval_acc*100:.2f}%")

        print("\n✅ Meta-training complete!")

    def evaluate_on_new_tasks(self, n_tasks=20):
        """Evaluate post-adaptation accuracy on completely new tasks.

        Uses task ids starting at 10000, which do not overlap training ids
        as long as ``n_iterations * tasks_per_batch <= 10000``.

        Returns:
            Mean query-set accuracy after inner-loop adaptation.
        """
        total_acc = 0

        for i in range(n_tasks):
            task_id = 10000 + i
            support_x, support_y, query_x, query_y = create_classification_task(
                task_id, n_samples=40
            )

            # Only the post-adaptation accuracy matters here.
            _, acc_after = self.maml.evaluate_task(
                support_x, support_y, query_x, query_y
            )
            total_acc += acc_after

        return total_acc / n_tasks

    def compare_with_baseline(self, n_test_tasks=50):
        """Compare MAML against a non-meta-learned baseline.

        For each test task, measures (a) the meta-trained model before and
        after inner-loop adaptation, and (b) a freshly initialized model of
        the *same* architecture fine-tuned with the same inner loop.

        Args:
            n_test_tasks: Number of held-out tasks (ids start at 20000).

        Returns:
            Dict with keys 'maml_before', 'maml_after', 'baseline',
            'improvement' — all percentages.
        """
        print("\n" + "="*70)
        print("COMPARING MAML VS BASELINE")
        print("="*70)

        maml_before = []
        maml_after = []
        baseline_after = []

        for i in range(n_test_tasks):
            task_id = 20000 + i
            support_x, support_y, query_x, query_y = create_classification_task(
                task_id, n_samples=40
            )

            # MAML performance (before and after adaptation).
            acc_before, acc_after = self.maml.evaluate_task(
                support_x, support_y, query_x, query_y
            )
            maml_before.append(acc_before)
            maml_after.append(acc_after)

            # Baseline: random init + same fine-tuning procedure.
            # Reuse the trainer's stored configuration so the comparison
            # is architecture-for-architecture fair (previously hard-coded).
            baseline_model = MAMLModel(
                self.input_size, self.hidden_sizes, self.output_size
            )
            baseline_maml = MAML(
                baseline_model,
                inner_lr=self.inner_lr,
                inner_steps=self.inner_steps
            )
            _, baseline_acc = baseline_maml.evaluate_task(
                support_x, support_y, query_x, query_y
            )
            baseline_after.append(baseline_acc)

        # Aggregate to percentages.
        maml_before_avg = np.mean(maml_before) * 100
        maml_after_avg = np.mean(maml_after) * 100
        baseline_after_avg = np.mean(baseline_after) * 100

        improvement = maml_after_avg - maml_before_avg
        vs_baseline = maml_after_avg - baseline_after_avg

        print(f"\nResults on {n_test_tasks} new tasks:")
        print(f"{'='*70}")
        print(f"MAML (before adaptation):  {maml_before_avg:.2f}%")
        print(f"MAML (after 5 steps):      {maml_after_avg:.2f}% (+{improvement:.2f}%)")
        print(f"Baseline (after 5 steps):  {baseline_after_avg:.2f}%")
        print(f"{'='*70}")
        print(f"MAML vs Baseline:          +{vs_baseline:.2f}%")

        # Report whether the project's accuracy/improvement goals are met.
        if maml_after_avg >= 85 and improvement >= 30:
            print(f"\n🎯 GOAL ACHIEVED!")
            print(f"   ✅ >85% accuracy after 5 steps")
            print(f"   ✅ >30% improvement from adaptation")
        else:
            print(f"\n⚠️  Goal not yet met")
            print(f"   Need: >85% accuracy, >30% improvement")

        return {
            'maml_before': maml_before_avg,
            'maml_after': maml_after_avg,
            'baseline': baseline_after_avg,
            'improvement': improvement
        }

    def visualize_training(self, save_path='real_capabilities/learning/meta/maml_training.png'):
        """Plot meta-training accuracy/loss curves and save them to disk.

        Args:
            save_path: Destination PNG path. Parent directories are created
                if missing (previously ``savefig`` would crash if the
                directory did not exist).
        """
        import os  # local import: only needed for directory creation here

        iterations = [h['iteration'] for h in self.train_history]
        accuracies = [h['accuracy'] * 100 for h in self.train_history]
        losses = [h['loss'] for h in self.train_history]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

        # Accuracy curve with the 85% goal line for reference.
        ax1.plot(iterations, accuracies, color='blue', linewidth=2)
        ax1.set_xlabel('Meta-Training Iteration')
        ax1.set_ylabel('Query Set Accuracy (%)')
        ax1.set_title('MAML: Meta-Training Progress (Accuracy)')
        ax1.grid(True, alpha=0.3)
        ax1.axhline(y=85, color='green', linestyle='--', alpha=0.5, label='Goal: 85%')
        ax1.legend()

        # Loss curve.
        ax2.plot(iterations, losses, color='red', linewidth=2)
        ax2.set_xlabel('Meta-Training Iteration')
        ax2.set_ylabel('Query Set Loss')
        ax2.set_title('MAML: Meta-Training Progress (Loss)')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        # Ensure the output directory exists before saving.
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        plt.savefig(save_path, dpi=150)
        # Close the figure to release memory on repeated calls.
        plt.close(fig)
        print(f"\n📊 Training plot saved to: {save_path}")


if __name__ == "__main__":
    # Reusable separator for the console banners.
    sep = "=" * 70

    print("\n" + sep)
    print("MAML META-LEARNING: FULL TRAINING & EVALUATION")
    print(sep)

    # Build the meta-learner with the standard configuration.
    trainer = MAMLTrainer(
        input_size=10,
        hidden_sizes=[40, 40],
        output_size=2,
        inner_lr=0.01,
        outer_lr=0.001,
        inner_steps=5
    )

    # Phase 1: meta-train across 2000 batches of 4 tasks each.
    trainer.meta_train(n_iterations=2000, tasks_per_batch=4, eval_every=200)

    # Phase 2: save the accuracy/loss training curves.
    trainer.visualize_training()

    # Phase 3: head-to-head against a randomly initialized baseline.
    results = trainer.compare_with_baseline(n_test_tasks=50)

    print("\n" + sep)
    print("WEEK 3 COMPLETE!")
    print(sep)
    print("\nMeta-Learning Capability: ✅")
    print(f"Fast adaptation in 5 steps: {results['maml_after']:.1f}%")
    print(f"Improvement: {results['improvement']:.1f}%")
    print("\nTotal Capabilities: 3/30 (10%)")
