"""
MAML Hyperparameter Tuning
Find optimal settings to achieve >85% accuracy
"""
import torch
import numpy as np
from maml_trainer import MAMLTrainer
import itertools

def quick_eval(trainer, n_tasks=20):
    """Evaluate a meta-trained model on fresh, held-out tasks.

    Args:
        trainer: MAMLTrainer whose ``maml.evaluate_task`` adapts to a task's
            support set and returns (loss/acc before, accuracy after) —
            only the post-adaptation accuracy is used here.
        n_tasks: Number of evaluation tasks to average over.

    Returns:
        Mean post-adaptation accuracy over ``n_tasks`` tasks, as a
        fraction (callers multiply by 100 for a percentage).
    """
    # Hoisted out of the loop: the original re-ran this import statement
    # on every iteration.
    from maml import create_classification_task

    total_acc = 0.0
    for i in range(n_tasks):
        # Task ids start at 30000, presumably to keep evaluation tasks
        # disjoint from those sampled during meta-training — confirm
        # against the task generator's seeding in maml.py.
        task_id = 30000 + i
        support_x, support_y, query_x, query_y = create_classification_task(
            task_id, n_samples=40
        )
        _, acc_after = trainer.maml.evaluate_task(
            support_x, support_y, query_x, query_y
        )
        total_acc += acc_after
    return total_acc / n_tasks

def tune_hyperparameters():
    """Grid-search MAML hyperparameters and report the best configuration.

    Trains one MAMLTrainer per candidate configuration, scores each on 30
    held-out tasks via ``quick_eval``, prints a ranked summary table, and
    announces whether the 85% accuracy goal was met.

    Returns:
        dict for the best configuration with keys: ``config_id``,
        ``inner_lr``, ``outer_lr``, ``inner_steps``, ``iterations``,
        ``accuracy`` (percentage), and ``trainer`` (the trained model).
    """
    print("\n" + "="*70)
    print("MAML HYPERPARAMETER TUNING")
    print("Testing different configurations to achieve >85% accuracy")
    print("="*70)

    # Hyperparameter grid
    configs = [
        # (inner_lr, outer_lr, inner_steps, iterations)
        (0.01, 0.001, 5, 2000),    # Original
        (0.05, 0.001, 5, 2000),    # Higher inner LR
        (0.01, 0.001, 10, 2000),   # More inner steps
        (0.05, 0.001, 10, 2000),   # Both higher
        (0.1, 0.001, 10, 2000),    # Even higher inner LR
        (0.05, 0.003, 10, 3000),   # Higher both + more iterations
    ]

    results = []

    for idx, (inner_lr, outer_lr, inner_steps, iterations) in enumerate(configs):
        print(f"\n{'='*70}")
        print(f"CONFIG {idx+1}/{len(configs)}: inner_lr={inner_lr}, outer_lr={outer_lr}, "
              f"inner_steps={inner_steps}, iterations={iterations}")
        print(f"{'='*70}")

        # Train with this configuration
        trainer = MAMLTrainer(
            input_size=10,
            hidden_sizes=[40, 40],
            output_size=2,
            inner_lr=inner_lr,
            outer_lr=outer_lr,
            inner_steps=inner_steps
        )

        trainer.meta_train(
            n_iterations=iterations,
            tasks_per_batch=4,
            eval_every=iterations // 5  # Eval 5 times
        )

        # Quick evaluation on 30 fresh tasks; convert fraction -> percent.
        final_acc = quick_eval(trainer, n_tasks=30) * 100

        print(f"\n✅ Config {idx+1} completed - Final accuracy: {final_acc:.2f}%")

        results.append({
            'config_id': idx + 1,
            'inner_lr': inner_lr,
            'outer_lr': outer_lr,
            'inner_steps': inner_steps,
            'iterations': iterations,
            'accuracy': final_acc,
            'trainer': trainer
        })

    # Rank configurations best-first by held-out accuracy.
    print("\n" + "="*70)
    print("TUNING RESULTS")
    print("="*70)

    results.sort(key=lambda x: x['accuracy'], reverse=True)

    # Fix: the table now includes Outer LR — the grid varies it (config 6
    # uses 0.003), so omitting it made some rows indistinguishable.
    print(f"\n{'Config':<8} {'Inner LR':<10} {'Outer LR':<10} {'Steps':<8} {'Iters':<8} {'Accuracy':<10}")
    print("-" * 70)

    for r in results:
        print(f"{r['config_id']:<8} {r['inner_lr']:<10.3f} {r['outer_lr']:<10.3f} "
              f"{r['inner_steps']:<8} {r['iterations']:<8} {r['accuracy']:<10.2f}%")

    best = results[0]
    print(f"\n{'='*70}")
    print(f"🏆 BEST CONFIGURATION:")
    print(f"{'='*70}")
    print(f"  Inner LR: {best['inner_lr']}")
    print(f"  Outer LR: {best['outer_lr']}")
    print(f"  Inner Steps: {best['inner_steps']}")
    print(f"  Iterations: {best['iterations']}")
    print(f"  Accuracy: {best['accuracy']:.2f}%")

    if best['accuracy'] >= 85:
        print(f"\n🎯 GOAL ACHIEVED! >85% accuracy")
    else:
        print(f"\n📈 Best so far: {best['accuracy']:.2f}% (goal: 85%)")
        print(f"   Gap: {85 - best['accuracy']:.2f}%")

    return best

def train_best_longer(best_config):
    """Extend training of the winning configuration when the goal is unmet.

    If ``best_config`` already reaches 85% accuracy, it is returned
    unchanged. Otherwise a fresh MAMLTrainer with the same hyperparameters
    is meta-trained for 5000 iterations, re-evaluated on 50 held-out
    tasks, and a new result dict (same keys, updated ``accuracy`` and
    ``trainer``) is returned.
    """
    # Guard: nothing to do if the tuning pass already hit the target.
    if best_config['accuracy'] >= 85:
        print("\nGoal already achieved! No additional training needed.")
        return best_config

    banner = "=" * 70
    print("\n" + banner)
    print("EXTENDED TRAINING WITH BEST CONFIG")
    print("Training for 5000 iterations to reach 85%+")
    print(banner)

    # Rebuild the model from scratch with the winning hyperparameters.
    extended_trainer = MAMLTrainer(
        input_size=10,
        hidden_sizes=[40, 40],
        output_size=2,
        inner_lr=best_config['inner_lr'],
        outer_lr=best_config['outer_lr'],
        inner_steps=best_config['inner_steps']
    )

    extended_trainer.meta_train(
        n_iterations=5000,
        tasks_per_batch=4,
        eval_every=500
    )

    # Score on 50 fresh tasks and convert fraction -> percent.
    accuracy_pct = quick_eval(extended_trainer, n_tasks=50) * 100

    print(f"\n{banner}")
    print(f"FINAL RESULT: {accuracy_pct:.2f}%")
    print(f"{banner}")

    if accuracy_pct >= 85:
        print(f"🎯 GOAL ACHIEVED!")

    return {**best_config, 'accuracy': accuracy_pct, 'trainer': extended_trainer}

if __name__ == "__main__":
    # Tune hyperparameters
    best = tune_hyperparameters()
    
    # Train best config longer if needed
    final = train_best_longer(best)
    
    # Full comparison
    print("\n" + "="*70)
    print("FINAL BENCHMARK WITH TUNED HYPERPARAMETERS")
    print("="*70)
    
    results = final['trainer'].compare_with_baseline(n_test_tasks=50)
    
    print("\n✅ WEEK 3 HYPERPARAMETER TUNING COMPLETE!")
