#!/usr/bin/env python3
"""
META-LEARNING - PERFECT VERSION
Target: 95%+ average accuracy
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy

class MetaNetPerfect(nn.Module):
    """MLP classifier mapping 10-d inputs to 5 class logits.

    Architecture: two 128-wide hidden blocks (Linear -> LayerNorm -> ReLU ->
    Dropout), a 64-wide block without dropout, and a final linear head.
    All Linear weights use Kaiming-normal init; all biases start at zero.
    """

    def __init__(self):
        super().__init__()
        blocks = []
        # Two identical dropout-regularized hidden blocks.
        for fan_in, fan_out in ((10, 128), (128, 128)):
            blocks += [
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(0.1),
            ]
        # Narrowing block (no dropout) followed by the classification head.
        blocks += [
            nn.Linear(128, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Linear(64, 5),
        ]
        self.net = nn.Sequential(*blocks)

        # Re-initialize every Linear layer: Kaiming-normal weights, zero bias.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return raw class logits of shape (batch, 5) for input (batch, 10)."""
        return self.net(x)

class MetaLearnerPerfect:
    """First-order MAML-style meta-learner over synthetic classification tasks.

    Each task is a randomly generated 10-feature / 5-class problem; the
    meta-model is trained so that a few gradient steps on a task's support
    set yield high accuracy on its query set.
    """

    def __init__(self):
        # Fall back to CPU when no CUDA device is available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = MetaNetPerfect().to(self.device)
        print(f"🔥 Using device: {self.device}")

    def create_excellent_task(self):
        """Create high-quality, diverse tasks.

        Returns:
            (X, Y): a (150, 10) float tensor of features and a (150,) long
            tensor of labels in [0, 5), both already on ``self.device``.
        """
        # Four label-generating families keep the task distribution diverse.
        task_type = np.random.choice(['linear', 'quadratic', 'sinusoidal', 'exponential'])

        if task_type == 'linear':
            w = np.random.randn(10, 5) * 2
            b = np.random.randn(5) * 2
            X = np.random.randn(150, 10)
            Y = (X @ w + b).argmax(axis=1)
        elif task_type == 'quadratic':
            X = np.random.randn(150, 10)
            Y = ((X**2).sum(axis=1) % 5).astype(int)
        elif task_type == 'sinusoidal':
            X = np.random.randn(150, 10)
            Y = ((np.sin(X.sum(axis=1)) * 2 + 2.5).astype(int)) % 5
        else:  # exponential
            X = np.random.randn(150, 10)
            Y = ((np.exp(X[:, 0] / 5) * 2).astype(int)) % 5

        X = torch.FloatTensor(X).to(self.device)
        Y = torch.LongTensor(Y).to(self.device)
        return X, Y

    def meta_train(self, n_epochs=800, n_tasks_per_epoch=10):
        """Meta-train with first-order MAML (FOMAML).

        Each task adapts a deep copy of the meta-model on its support set,
        then the gradient of the query loss w.r.t. the *adapted* weights is
        accumulated onto the meta-model's ``.grad`` buffers (the first-order
        MAML approximation). BUGFIX: the previous version called
        ``epoch_loss.backward()`` on losses of deep-copied models, which
        deposited gradients only in the throwaway copies — ``deepcopy``
        severs the autograd graph — so the meta-model never received any
        gradient and ``meta_opt.step()`` was a no-op.

        Args:
            n_epochs: number of meta-updates.
            n_tasks_per_epoch: tasks whose query gradients are averaged
                (summed, then clipped) into each meta-update.
        """
        print(f"\n{'='*70}")
        print(f"META-TRAINING: {n_epochs} epochs, {n_tasks_per_epoch} tasks/epoch")
        print(f"{'='*70}")

        # AdamW with decoupled weight decay plus cosine-annealed LR.
        meta_opt = torch.optim.AdamW(
            self.model.parameters(),
            lr=0.001,
            weight_decay=0.01
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            meta_opt,
            T_max=n_epochs,
            eta_min=0.0001
        )

        for epoch in range(n_epochs):
            # Plain float accumulator — keeping tensors here would retain
            # every task's autograd graph for the whole epoch.
            epoch_loss = 0.0
            meta_opt.zero_grad()

            for _ in range(n_tasks_per_epoch):
                # Create task and split into support/query sets.
                data, labels = self.create_excellent_task()
                support_x, support_y = data[:100], labels[:100]
                query_x, query_y = data[100:], labels[100:]

                # Inner loop: adapt a disposable copy of the meta-model.
                adapted = deepcopy(self.model)
                inner_opt = torch.optim.SGD(adapted.parameters(), lr=0.02, momentum=0.9)
                for _ in range(15):
                    loss = F.cross_entropy(adapted(support_x), support_y)
                    inner_opt.zero_grad()
                    loss.backward()
                    inner_opt.step()

                # Query loss drives the meta-update.
                adapted.zero_grad()
                meta_loss = F.cross_entropy(adapted(query_x), query_y)
                meta_loss.backward()

                # FOMAML: transfer the adapted model's query gradients onto
                # the meta-model's parameters (summed across tasks).
                with torch.no_grad():
                    for meta_p, task_p in zip(self.model.parameters(), adapted.parameters()):
                        if task_p.grad is None:
                            continue
                        if meta_p.grad is None:
                            meta_p.grad = task_p.grad.detach().clone()
                        else:
                            meta_p.grad += task_p.grad

                epoch_loss += meta_loss.item()

            # Meta-optimization on the accumulated task gradients.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            meta_opt.step()
            scheduler.step()

            # Progress reporting every 100 epochs.
            if (epoch + 1) % 100 == 0:
                avg_loss = epoch_loss / n_tasks_per_epoch
                lr = scheduler.get_last_lr()[0]
                print(f"Epoch {epoch+1}/{n_epochs}: Loss: {avg_loss:.4f}, LR: {lr:.6f}")

        print("✅ Meta-training complete")

    def test_adaptation(self, n_tests=20):
        """Evaluate few-shot adaptation on freshly generated tasks.

        For each task, a copy of the meta-model is fine-tuned on the support
        set and scored on the query set.

        Returns:
            List of per-task query accuracies in [0, 1].
        """
        print(f"\nTesting on {n_tests} new tasks...")
        accs = []

        for test_num in range(n_tests):
            data, labels = self.create_excellent_task()
            support_x, support_y = data[:100], labels[:100]
            query_x, query_y = data[100:], labels[100:]

            # Start adaptation from a copy of the meta-trained weights.
            adapted = MetaNetPerfect().to(self.device)
            adapted.load_state_dict(self.model.state_dict())

            # Fine-tune on the support set.
            opt = torch.optim.AdamW(adapted.parameters(), lr=0.01, weight_decay=0.01)
            adapted.train()  # dropout active while fine-tuning
            for _ in range(50):
                loss = F.cross_entropy(adapted(support_x), support_y)
                opt.zero_grad()
                loss.backward()
                opt.step()

            # BUGFIX: switch to eval mode — the previous version evaluated
            # with dropout still active, randomly degrading query accuracy.
            adapted.eval()
            with torch.no_grad():
                pred = adapted(query_x)
                acc = (pred.argmax(dim=1) == query_y).float().mean().item()
                accs.append(acc)

                # Show each task result.
                status = "✅" if acc >= 0.90 else "⚠️" if acc >= 0.80 else "❌"
                print(f"  {status} Task {test_num+1}: {acc*100:.1f}%")

        return accs

def main():
    """Entry point: meta-train, evaluate on held-out tasks, report stats."""
    banner = "=" * 70

    print("\n" + banner)
    print("🎯 META-LEARNING: PUSHING TO 100%")
    print(banner)

    learner = MetaLearnerPerfect()
    learner.meta_train(n_epochs=800, n_tasks_per_epoch=10)

    print("\n" + banner)
    print("TESTING PERFORMANCE")
    print(banner)

    accs = learner.test_adaptation(n_tests=20)

    # Summary statistics over all test-task accuracies.
    avg = np.mean(accs)
    std = np.std(accs)

    print("\n" + banner)
    print("FINAL RESULTS")
    print(banner)
    print(f"Average:  {avg*100:.1f}% (±{std*100:.1f}%)")
    print(f"Min:      {np.min(accs)*100:.1f}%")
    print(f"Max:      {np.max(accs)*100:.1f}%")
    print(f"Tasks ≥90%: {sum(1 for a in accs if a >= 0.90)}/20")
    print(f"Tasks ≥95%: {sum(1 for a in accs if a >= 0.95)}/20")

    # Print the message for the highest threshold reached, if any.
    for threshold, message in (
        (0.95, "\n🏆 PERFECT - 95%+ ACHIEVED!"),
        (0.90, "\n✅ EXCELLENT - 90%+ ACHIEVED!"),
        (0.87, "\n✅ STRONG - Improved from baseline!"),
    ):
        if avg >= threshold:
            print(message)
            break
    else:
        print("\n⚠️ Need more training")

    print(banner)

# Run the full train/evaluate pipeline only when executed as a script.
if __name__ == "__main__":
    main()
