#!/usr/bin/env python3
"""
Meta-Learning V2 - MAML-style meta-learner for a small MLP classifier, plus a smoke test
"""

import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImprovedMAML(nn.Module):
    """Small MLP classifier used as the meta-learned base model.

    Architecture: two hidden blocks of Linear -> LayerNorm -> ReLU -> Dropout,
    followed by a Linear layer that emits ``output_dim`` logits.

    Args:
        input_dim: number of input features (size of the input's last dim).
        hidden_dim: width of both hidden layers.
        output_dim: number of output logits (e.g. classes).
        dropout: dropout probability applied after each hidden block.
    """

    def __init__(self, input_dim=10, hidden_dim=128, output_dim=2, dropout=0.1):
        super().__init__()
        # Keep the layer order fixed: it determines the ``net.<i>`` keys in
        # state_dict, which checkpoints and load_state_dict copies rely on.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Return unnormalized logits; the last dim of ``x`` becomes ``output_dim``."""
        return self.net(x)

class MetaLearnerV2:
    """First-order MAML (FOMAML) meta-learner around a classifier module.

    ``meta_train`` updates ``self.model`` so that a few SGD steps adapt it well
    to each training task. ``adapt_to_new_task`` runs those SGD steps on an
    independent copy for an unseen task, leaving ``self.model`` untouched.

    Args:
        model: optional ``nn.Module`` to meta-learn; it is moved to
            ``self.device``. Defaults to a fresh ``ImprovedMAML()``.
    """

    def __init__(self, model=None):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Initializing improved meta-learner...")
        if model is None:
            model = ImprovedMAML()
        self.model = model.to(self.device)
        print("✅ Meta-learning V2 ready!")

    def _clone_and_adapt(self, data, labels, steps, lr):
        """Return a deep copy of ``self.model`` fine-tuned by ``steps`` SGD steps.

        The copy inherits ``self.model``'s train/eval mode. ``self.model``
        itself is never modified.
        """
        adapted = copy.deepcopy(self.model)
        optimizer = torch.optim.SGD(adapted.parameters(), lr=lr)
        for _ in range(steps):
            loss = F.cross_entropy(adapted(data), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return adapted

    def meta_train(self, tasks, epochs=50, inner_steps=10, inner_lr=0.01, meta_lr=0.001):
        """Meta-train ``self.model`` with first-order MAML.

        Args:
            tasks: iterable of ``(data, labels)`` tensor pairs, one per task.
                They are moved to ``self.device``.
            epochs: number of meta-updates (one per pass over all tasks).
            inner_steps: SGD steps used to adapt to each task.
            inner_lr: learning rate of the inner (per-task) SGD.
            meta_lr: learning rate of the outer Adam optimizer.
        """
        tasks = list(tasks)
        print(f"\nMeta-training on {len(tasks)} tasks...")
        if not tasks:
            print("⚠️ No tasks given - skipping meta-training")
            return

        meta_params = list(self.model.parameters())
        meta_optimizer = torch.optim.Adam(meta_params, lr=meta_lr)

        for epoch in range(epochs):
            meta_optimizer.zero_grad()
            meta_loss = 0.0

            for task_data, task_labels in tasks:
                task_data = task_data.to(self.device)
                task_labels = task_labels.to(self.device)

                # Inner loop: adapt an independent copy to this task.
                task_model = self._clone_and_adapt(
                    task_data, task_labels, inner_steps, inner_lr
                )

                # Outer objective: loss of the adapted copy.
                # NOTE(review): evaluated on the same examples used for
                # adaptation (no support/query split).
                task_loss = F.cross_entropy(task_model(task_data), task_labels)
                task_model.zero_grad()
                task_loss.backward()

                # The copy holds separate leaf tensors, so autograd never
                # reaches self.model. First-order MAML therefore transfers
                # dL/d(adapted params) onto the meta-parameters directly,
                # dropping second-order terms. deepcopy preserves parameter
                # order, so zip() pairs corresponding tensors. Gradients are
                # summed over tasks, matching a summed meta-loss.
                for meta_p, task_p in zip(meta_params, task_model.parameters()):
                    if task_p.grad is None:
                        continue
                    if meta_p.grad is None:
                        meta_p.grad = task_p.grad.detach().clone()
                    else:
                        meta_p.grad.add_(task_p.grad)

                # Store a float so each task's graph is freed right away.
                meta_loss += task_loss.item()

            torch.nn.utils.clip_grad_norm_(meta_params, 1.0)
            meta_optimizer.step()

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}: Meta-loss: {meta_loss:.4f}")

        print("✅ Meta-training complete")

    def adapt_to_new_task(self, data, labels, steps=20, lr=0.01):
        """Return a copy of the meta-model adapted to ``(data, labels)``.

        The returned module is in the same mode as ``self.model``. That is
        training mode unless the caller changed it, so Dropout is active; call
        ``.eval()`` on the result before inference.
        """
        data = data.to(self.device)
        labels = labels.to(self.device)
        return self._clone_and_adapt(data, labels, steps, lr)

def test_meta_v2():
    """Smoke-test MetaLearnerV2 on synthetic, learnable binary tasks.

    Each task labels 10-d Gaussian points by which side of a per-task random
    hyperplane (through the origin) they fall on. Half of a task is therefore
    predictable from the other half. The test meta-trains on 20 such tasks,
    adapts to a fresh task using 15 examples, and scores the remaining 15.

    Returns:
        True if held-out accuracy is at least 65%, otherwise False.
    """
    print("\n" + "="*70)
    print("TESTING META-LEARNING V2")
    print("="*70)

    ml = MetaLearnerV2()

    def make_task(n_samples=30, dim=10):
        # Labels must depend on the inputs. With labels drawn independently
        # of the data, held-out accuracy is chance by construction and the
        # thresholds below would be meaningless.
        data = torch.randn(n_samples, dim).to(ml.device)
        boundary = torch.randn(dim).to(ml.device)
        labels = (data @ boundary > 0).long()
        return data, labels

    tasks = [make_task() for _ in range(20)]
    ml.meta_train(tasks, epochs=50, inner_steps=10)

    # Test adaptation
    print("\nTesting quick adaptation...")
    test_data, test_labels = make_task()

    adapted = ml.adapt_to_new_task(test_data[:15], test_labels[:15], steps=20)

    # Dropout must be disabled when scoring; torch.no_grad() does not do that.
    adapted.eval()
    with torch.no_grad():
        pred = adapted(test_data[15:])
        acc = (pred.argmax(dim=1) == test_labels[15:]).float().mean().item()

    print(f"Adaptation accuracy: {acc*100:.1f}%")

    if acc >= 0.75:
        print("✅ EXCELLENT - Meta-learning working!")
        return True
    if acc >= 0.65:
        print("✅ GOOD - Strong meta-learning!")
        return True
    print("⚠️ Needs improvement")
    return False

def main():
    """Run the V2 smoke test and print a completion banner if it passes."""
    passed = test_meta_v2()
    if not passed:
        return
    print("\n✅ CAPABILITY #13 COMPLETE (IMPROVED)")

# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
