#!/usr/bin/env python3
"""
Meta-Learning V3 - Structured tasks with clear patterns
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class MetaModel(nn.Module):
    """MLP classifier: input_dim features -> 2 class logits.

    Two hidden layers with batch norm, ReLU, and dropout after the first
    hidden layer.
    """

    def __init__(self, input_dim=10, hidden_dim=256):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape (batch, 2) for input of shape (batch, input_dim)."""
        return self.net(x)

class MetaLearnerV3:
    """Meta-learner (first-order MAML) over synthetic linear tasks.

    Each task is a random linear binary classifier over 10-d Gaussian data;
    meta-training seeks an initialization of `MetaModel` that adapts to a new
    task in a few SGD steps.
    """

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Initializing meta-learner V3...")
        # Shared initialization that meta-training optimizes.
        self.model = MetaModel().to(self.device)
        print("✅ Meta-learning V3 ready!")

    def generate_structured_task(self, n_samples=40):
        """Generate task with clear pattern.

        Samples a random linear separator (w, b) and labels Gaussian points by
        the sign of their score, yielding a linearly separable binary task.

        Returns:
            (data, labels): data is (n_samples, 10) float32, labels is
            (n_samples,) int64 in {0, 1}.
        """
        # Random linear separator
        w = torch.randn(10, 1, device=self.device)
        b = torch.randn(1, device=self.device)

        data = torch.randn(n_samples, 10, device=self.device)
        scores = (data @ w + b).squeeze()
        labels = (scores > 0).long()

        return data, labels

    def meta_train(self, n_tasks=30, epochs=100):
        """Meta-train `self.model` with first-order MAML.

        Args:
            n_tasks: tasks sampled per meta-epoch.
            epochs: number of meta-epochs.
        """
        meta_optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

        print(f"\nMeta-training on {n_tasks} structured tasks...")

        for epoch in range(epochs):
            meta_optimizer.zero_grad()
            meta_loss = 0.0

            for _ in range(n_tasks):
                # Generate structured task (30 support + 10 query samples)
                task_data, task_labels = self.generate_structured_task()

                # Clone the meta-model for the inner loop
                task_model = MetaModel().to(self.device)
                task_model.load_state_dict(self.model.state_dict())
                task_opt = torch.optim.SGD(task_model.parameters(), lr=0.01)

                # Inner loop: adapt the clone on the support set
                for _ in range(10):
                    pred = task_model(task_data[:30])
                    loss = F.cross_entropy(pred, task_labels[:30])
                    task_opt.zero_grad()
                    loss.backward()
                    task_opt.step()

                # Meta-loss on the held-out query samples
                task_opt.zero_grad()
                pred = task_model(task_data[30:])
                loss = F.cross_entropy(pred, task_labels[30:])
                loss.backward()

                # BUG FIX: the original backpropagated the summed query loss
                # into the clones only, so self.model's grads stayed empty and
                # meta_optimizer.step() was a no-op (no meta-learning happened).
                # First-order MAML: accumulate the clone's query gradients into
                # the meta-model's .grad buffers.
                with torch.no_grad():
                    for meta_p, task_p in zip(self.model.parameters(),
                                              task_model.parameters()):
                        if task_p.grad is None:
                            continue
                        if meta_p.grad is None:
                            meta_p.grad = task_p.grad.detach().clone()
                        else:
                            meta_p.grad += task_p.grad

                meta_loss += loss.item()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            meta_optimizer.step()

            if (epoch + 1) % 20 == 0:
                print(f"Epoch {epoch+1}: Meta-loss: {meta_loss:.4f}")

        print("✅ Meta-training complete")

    def adapt_and_test(self):
        """Generate new task, adapt, and test.

        Returns:
            Query-set accuracy in [0, 1] after 20 adaptation steps on a fresh
            task (25 support / 25 query samples).
        """
        # New task
        task_data, task_labels = self.generate_structured_task(50)

        # Adapt a fresh clone on the support set
        adapted = MetaModel().to(self.device)
        adapted.load_state_dict(self.model.state_dict())
        opt = torch.optim.SGD(adapted.parameters(), lr=0.01)

        adapted.train()
        for _ in range(20):
            pred = adapted(task_data[:25])
            loss = F.cross_entropy(pred, task_labels[:25])
            opt.zero_grad()
            loss.backward()
            opt.step()

        # BUG FIX: switch to eval mode so Dropout is disabled and BatchNorm
        # uses running statistics; the original evaluated in train mode, which
        # makes the measured accuracy noisy and batch-dependent.
        adapted.eval()
        with torch.no_grad():
            pred = adapted(task_data[25:])
            acc = (pred.argmax(dim=1) == task_labels[25:]).float().mean().item()

        return acc

def test_meta_v3():
    """Run the full V3 pipeline and report adaptation accuracy on new tasks.

    Returns True when the average accuracy over 5 fresh tasks is >= 70%.
    """
    print("\n" + "="*70)
    print("TESTING META-LEARNING V3")
    print("="*70)

    learner = MetaLearnerV3()
    learner.meta_train(n_tasks=30, epochs=100)

    # Evaluate adaptation on several independently sampled tasks.
    print("\nTesting adaptation on 5 new tasks...")
    accs = []
    for task_idx in range(5):
        score = learner.adapt_and_test()
        accs.append(score)
        print(f"Task {task_idx+1}: {score*100:.1f}%")

    avg_acc = np.mean(accs)
    print(f"\nAverage: {avg_acc*100:.1f}%")

    # Guard clauses instead of an if/elif chain.
    if avg_acc >= 0.8:
        print("✅ EXCELLENT - Meta-learning working!")
        return True
    if avg_acc >= 0.7:
        print("✅ GOOD - Strong adaptation!")
        return True
    print("⚠️ Partial success")
    return False

def main():
    """Script entry point: run the V3 test and announce completion on success."""
    success = test_meta_v3()
    if success:
        print("\n✅ CAPABILITY #13 COMPLETE (V3)")


if __name__ == "__main__":
    main()
