#!/usr/bin/env python3
"""
Meta-Learning V5 - Optimized hyperparameters
"""

import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class MetaNetV5(nn.Module):
    """Three-layer MLP: 10 input features -> 256 -> 128 -> 2 output logits.

    Both hidden linear layers are followed by 1-D batch normalisation and a
    ReLU; the final linear layer returns raw logits.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, 2)

        # Overwrite the default weight init of every linear layer with
        # Xavier-uniform; biases keep their default initialisation.
        for linear in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_uniform_(linear.weight)

    def forward(self, x):
        """Return the 2-way logits for a batch ``x`` with 10 features per row."""
        hidden = self.fc1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)

        hidden = self.fc2(hidden)
        hidden = self.bn2(hidden)
        hidden = F.relu(hidden)

        return self.fc3(hidden)

class MetaLearnerV5:
    """Meta-learner that trains ``self.model`` as an initialisation for fast adaptation.

    Tasks are synthetic binary classification problems over 10 features
    (see ``create_task``). For each task a copy of the model is adapted with
    a few SGD steps on a support split, and the loss of the adapted copy on
    a query split drives the update of the shared initialisation.

    The meta-update uses a first-order approximation: the gradient of the
    query loss with respect to the *adapted* weights is applied directly to
    the initial weights, so no second-order terms are computed.
    """

    # Samples drawn per task and how many of them form the support split.
    _TASK_SIZE = 100
    _N_SUPPORT = 70
    # Upper bound on resampling attempts in create_task.
    _MAX_TASK_ATTEMPTS = 100

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Initializing meta-learner V5...")
        self.model = MetaNetV5().to(self.device)
        print("✅ Meta-learning V5 ready!")

    def create_task(self, n_samples=100):
        """Sample a random linearly separable binary task with balanced classes.

        A random unit-norm weight vector and a small random bias define the
        decision boundary; samples are standard-normal points labelled by the
        sign of their score. Draws are repeated until the positive class
        makes up between 30% and 70% of the samples.

        Args:
            n_samples: Number of points to draw.

        Returns:
            ``(data, labels)``: a ``(n_samples, 10)`` float tensor and a
            ``(n_samples,)`` int64 tensor of 0/1 labels, both on ``self.device``.

        Raises:
            RuntimeError: If no balanced task is found within the attempt
                budget (e.g. ``n_samples=1`` can never be balanced).
        """
        lower, upper = 0.3 * n_samples, 0.7 * n_samples

        # Bounded loop instead of recursion: an unsatisfiable n_samples fails
        # with a clear error rather than exhausting the call stack.
        for _ in range(self._MAX_TASK_ATTEMPTS):
            weight = torch.randn(10, device=self.device)
            weight = weight / weight.norm()
            bias = torch.randn(1, device=self.device) * 0.5

            data = torch.randn(n_samples, 10, device=self.device)
            scores = data @ weight + bias  # shape (n_samples,)
            labels = (scores > 0).long()

            positives = int(labels.sum())
            if lower <= positives <= upper:
                return data, labels

        raise RuntimeError(
            f"could not sample a balanced task with n_samples={n_samples} "
            f"in {self._MAX_TASK_ATTEMPTS} attempts"
        )

    def _sample_split(self):
        """Draw one task and split it into (support, query) pairs of (x, y)."""
        data, labels = self.create_task(self._TASK_SIZE)
        cut = self._N_SUPPORT
        return (data[:cut], labels[:cut]), (data[cut:], labels[cut:])

    def _clone_model(self):
        """Return an independent copy of ``self.model`` in training mode."""
        clone = copy.deepcopy(self.model)
        # Adaptation always runs in training mode, whatever mode self.model is in.
        clone.train()
        return clone

    def _adapt(self, support_x, support_y, n_steps, lr):
        """Fine-tune a copy of the meta-model on one support set and return it."""
        adapted = self._clone_model()
        inner_opt = torch.optim.SGD(adapted.parameters(), lr=lr)

        for _ in range(n_steps):
            loss = F.cross_entropy(adapted(support_x), support_y)
            inner_opt.zero_grad()
            loss.backward()
            inner_opt.step()

        return adapted

    def meta_train(self, n_epochs=300, tasks_per_epoch=15, inner_steps=10, inner_lr=0.01):
        """Meta-train ``self.model`` with a first-order meta-gradient.

        Each epoch samples ``tasks_per_epoch`` tasks. For every task a copy of
        the model is adapted on the support split; the gradient of its query
        loss (taken at the adapted weights) is summed into the ``.grad`` of
        the corresponding parameter of ``self.model``. The summed gradient is
        clipped to norm 1.0 and applied with Adam under a cosine schedule.

        Args:
            n_epochs: Number of meta-updates.
            tasks_per_epoch: Tasks whose gradients are summed per meta-update.
            inner_steps: SGD steps used to adapt to each task.
            inner_lr: Learning rate of the inner SGD optimiser.
        """
        meta_params = list(self.model.parameters())
        meta_opt = torch.optim.Adam(meta_params, lr=0.001)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(meta_opt, T_max=n_epochs)

        print(f"\nMeta-training for {n_epochs} epochs...")

        for epoch in range(n_epochs):
            epoch_loss = 0.0
            meta_opt.zero_grad()

            for _ in range(tasks_per_epoch):
                (support_x, support_y), (query_x, query_y) = self._sample_split()
                adapted = self._adapt(support_x, support_y, inner_steps, inner_lr)

                meta_loss = F.cross_entropy(adapted(query_x), query_y)

                # The adapted copy shares no autograd graph with self.model, so
                # calling backward() on the query loss would never reach the
                # meta-parameters. Transfer the gradients explicitly instead.
                task_grads = torch.autograd.grad(
                    meta_loss, list(adapted.parameters()), allow_unused=True
                )
                for param, grad in zip(meta_params, task_grads):
                    if grad is None:
                        continue
                    if param.grad is None:
                        param.grad = grad.detach().clone()
                    else:
                        param.grad.add_(grad)

                # Accumulate a plain float so no per-task graph stays alive.
                epoch_loss += meta_loss.item()

            torch.nn.utils.clip_grad_norm_(meta_params, 1.0)
            meta_opt.step()
            scheduler.step()

            if (epoch + 1) % 50 == 0:
                print(f"Epoch {epoch+1}: Loss: {epoch_loss:.4f}")

        print("✅ Meta-training complete")

    def test_adaptation(self, n_tests=10, inner_steps=20, inner_lr=0.01):
        """Measure query accuracy after adapting to freshly sampled tasks.

        ``self.model`` itself is not modified; every task adapts its own copy.

        Args:
            n_tests: Number of tasks to evaluate.
            inner_steps: SGD steps used to adapt to each task.
            inner_lr: Learning rate of the inner SGD optimiser.

        Returns:
            A list of ``n_tests`` accuracies in ``[0, 1]``, one per task.
        """
        accs = []

        for _ in range(n_tests):
            (support_x, support_y), (query_x, query_y) = self._sample_split()
            adapted = self._adapt(support_x, support_y, inner_steps, inner_lr)

            # The copy stays in training mode, so any batch-norm layers
            # normalise the query batch with its own statistics.
            with torch.no_grad():
                pred = adapted(query_x)
            accs.append((pred.argmax(dim=1) == query_y).float().mean().item())

        return accs

def test_meta_v5():
    """Train a fresh MetaLearnerV5, evaluate it on 10 tasks and print a report.

    Returns:
        True when the mean query accuracy is at least 0.70, otherwise False.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TESTING META-LEARNING V5")
    print(rule)

    ml = MetaLearnerV5()
    ml.meta_train(n_epochs=300)

    print("\nTesting on 10 new tasks...")
    accs = ml.test_adaptation(n_tests=10)

    for task_no, task_acc in enumerate(accs, start=1):
        print(f"Task {task_no}: {task_acc*100:.1f}%")

    avg, std = np.mean(accs), np.std(accs)
    print(f"\nAverage: {avg*100:.1f}% (±{std*100:.1f}%)")

    # Success tiers, highest first; the first floor reached decides the message.
    verdicts = (
        (0.80, "✅ EXCELLENT - Meta-learning working!"),
        (0.75, "✅ GOOD - Strong adaptation!"),
        (0.70, "✅ DECENT - Working meta-learning!"),
    )
    for floor, message in verdicts:
        if avg >= floor:
            print(message)
            return True

    print("⚠️ Partial success")
    return False

def main():
    """Run the V5 meta-learning check and announce completion if it passes."""
    passed = test_meta_v5()
    if not passed:
        return
    print("\n✅ CAPABILITY #13 COMPLETE (V5)")

# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
