#!/usr/bin/env python3
"""
Meta-Learning - Learn to Learn
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class MAMLModel(nn.Module):
    """Model-Agnostic Meta-Learning"""
    def __init__(self, input_dim=10, hidden_dim=64, output_dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)

class MetaLearner:
    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Initializing meta-learner...")
        self.model = MAMLModel().to(self.device)
        print("✅ Meta-learning ready!")
    
    def meta_train(self, tasks, epochs=20):
        """Train across tasks so the initialization itself becomes easy to adapt.

        Uses first-order MAML: the inner-loop gradients of each adapted copy
        are accumulated onto the meta-parameters (the second-order term that
        full MAML backpropagates through the inner loop is dropped).
        """
        meta_optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        
        print(f"\nMeta-training on {len(tasks)} tasks...")
        
        for epoch in range(epochs):
            meta_loss = 0.0
            meta_optimizer.zero_grad()
            
            for task_data, task_labels in tasks:
                # Inner loop: clone the meta-parameters and adapt the clone.
                # Note: load_state_dict copies values but breaks the autograd
                # graph, so meta_loss.backward() alone would never reach
                # self.model; the gradient copy below supplies the meta-gradient.
                task_model = MAMLModel().to(self.device)
                task_model.load_state_dict(self.model.state_dict())
                
                task_opt = torch.optim.SGD(task_model.parameters(), lr=0.01)
                
                # Few gradient steps on task
                for _ in range(5):
                    pred = task_model(task_data)
                    loss = F.cross_entropy(pred, task_labels)
                    task_opt.zero_grad()
                    loss.backward()
                    task_opt.step()
                
                # Meta-loss: how well does adapted model do?
                # (A full MAML setup would score a held-out query split;
                # here the task data is reused for simplicity.)
                task_opt.zero_grad()
                pred = task_model(task_data)
                loss = F.cross_entropy(pred, task_labels)
                loss.backward()
                meta_loss += loss.item()
                
                # First-order MAML: accumulate the adapted model's gradients
                # onto the meta-parameters.
                for meta_p, task_p in zip(self.model.parameters(),
                                          task_model.parameters()):
                    if meta_p.grad is None:
                        meta_p.grad = task_p.grad.clone()
                    else:
                        meta_p.grad += task_p.grad
            
            # Meta-update with the gradients accumulated across tasks
            meta_optimizer.step()
            
            if (epoch + 1) % 5 == 0:
                print(f"Epoch {epoch+1}: Meta-loss: {meta_loss:.4f}")
        
        print("✅ Meta-training complete")
    
    def adapt_to_new_task(self, data, labels, steps=10):
        """Quickly adapt to a new task with a few gradient steps."""
        adapted = MAMLModel().to(self.device)
        adapted.load_state_dict(self.model.state_dict())
        
        optimizer = torch.optim.SGD(adapted.parameters(), lr=0.01)
        
        for _ in range(steps):
            pred = adapted(data)
            loss = F.cross_entropy(pred, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        return adapted

def test_meta_learning():
    print("\n" + "="*70)
    print("TESTING META-LEARNING")
    print("="*70)
    
    ml = MetaLearner()
    
    # Generate synthetic binary-classification tasks. Each task gets its own
    # random linear decision boundary so there is structure to learn
    # (purely random labels would pin adaptation accuracy at chance level).
    tasks = []
    for _ in range(10):
        data = torch.randn(20, 10).to(ml.device)
        boundary = torch.randn(10).to(ml.device)
        labels = (data @ boundary > 0).long()
        tasks.append((data, labels))
    
    # Meta-train
    ml.meta_train(tasks, epochs=20)
    
    # Test adaptation on a held-out task with its own decision boundary:
    # adapt on the first 10 examples (support set), evaluate on the rest.
    print("\nTesting quick adaptation to new task...")
    test_data = torch.randn(20, 10).to(ml.device)
    test_boundary = torch.randn(10).to(ml.device)
    test_labels = (test_data @ test_boundary > 0).long()
    
    adapted = ml.adapt_to_new_task(test_data[:10], test_labels[:10])
    
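    # Inference only: evaluate the adapted model without tracking gradients.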
    with torch.no_grad():
        pred = adapted(test_data[10:])
        acc = (pred.argmax(dim=1) == test_labels[10:]).float().mean().item()
    
    print(f"Adaptation accuracy: {acc*100:.1f}%")
    
    if acc > 0.5:  # above chance for a binary task
        print("✅ Meta-learning working!")
        return True
    return False

def main():
    if test_meta_learning():
        print("\n✅ CAPABILITY #13 COMPLETE")

if __name__ == "__main__":
    main()
