#!/usr/bin/env python3
"""
Using learn2learn library - proper MAML implementation
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import learn2learn as l2l

# Seeds for reproducibility (both torch and numpy draw random numbers below)
torch.manual_seed(42)
np.random.seed(42)

# Fall back to CPU when no GPU is present — the original hard-coded 'cuda',
# which raises on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class Net(nn.Module):
    """Two-layer MLP: 10 input features -> 64 hidden units -> 5 class logits."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.fc2 = nn.Linear(64, 5)

    def forward(self, x):
        # Single hidden layer with ReLU; raw logits out (loss applies softmax).
        hidden = F.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return logits

def create_task(n_samples=100, n_features=10, n_classes=5):
    """Generate one synthetic classification task.

    Draws a random linear model (weights scaled by 2 to separate the classes)
    and labels Gaussian inputs by the argmax of the linear scores. Dimensions
    were previously hard-coded; defaults preserve the original behavior.

    Args:
        n_samples: number of examples to generate (default 100).
        n_features: input dimensionality (default 10).
        n_classes: number of output classes (default 5).

    Returns:
        (X, Y): float tensor of shape (n_samples, n_features) and long tensor
        of shape (n_samples,), both moved to the module-level `device`.
    """
    w = np.random.randn(n_features, n_classes) * 2
    b = np.random.randn(n_classes)
    X = np.random.randn(n_samples, n_features)
    Y = (X @ w + b).argmax(axis=1)
    X = torch.FloatTensor(X).to(device)
    Y = torch.LongTensor(Y).to(device)
    return X, Y

print("="*70)
print("META-LEARNING WITH LEARN2LEARN LIBRARY")
print("="*70)

# Create model and wrap with MAML
model = Net().to(device)
maml = l2l.algorithms.MAML(model, lr=0.01, first_order=True)
meta_opt = torch.optim.Adam(maml.parameters(), lr=0.001)

print("\nMeta-training (200 epochs)...")

# Outer (meta) loop: accumulate query-set loss across several tasks, then take
# one Adam step on the shared initialization. With first_order=True on the
# MAML wrapper above this is FOMAML (second-order gradients are dropped).
for epoch in range(200):
    meta_loss = 0  # summed (not averaged) over the 5 tasks below
    
    for _ in range(5):  # 5 tasks per epoch
        # Clone model for this task; adapting the clone keeps the computation
        # graph connected to the shared meta-parameters for the backward pass.
        learner = maml.clone()
        
        # Create task; 70/30 split into support (adapt) / query (evaluate)
        data, labels = create_task()
        support_x, support_y = data[:70], labels[:70]
        query_x, query_y = data[70:], labels[70:]
        
        # Inner loop adaptation (10 steps): each adapt() takes one gradient
        # step on the clone at the inner lr set on the wrapper
        for _ in range(10):
            pred = learner(support_x)
            loss = F.cross_entropy(pred, support_y)
            learner.adapt(loss)
        
        # Compute meta-loss on query set (post-adaptation performance)
        pred = learner(query_x)
        loss = F.cross_entropy(pred, query_y)
        meta_loss += loss
    
    # Meta-optimization: single optimizer step on the accumulated query loss
    meta_opt.zero_grad()
    meta_loss.backward()
    meta_opt.step()
    
    if epoch % 40 == 0:
        print(f"Epoch {epoch}: Loss={meta_loss.item():.4f}")

print("✅ Training complete\n")

# Test
print("="*70)
print("TESTING")
print("="*70)

accs = []
for i in range(20):
    learner = maml.clone()
    
    data, labels = create_task()
    support_x, support_y = data[:70], labels[:70]
    query_x, query_y = data[70:], labels[70:]
    
    # Adapt
    for _ in range(25):
        pred = learner(support_x)
        loss = F.cross_entropy(pred, support_y)
        learner.adapt(loss)
    
    # Test
    with torch.no_grad():
        pred = learner(query_x)
        acc = (pred.argmax(1) == query_y).float().mean().item()
        accs.append(acc)
        status = "✅" if acc >= 0.8 else "⚠️" if acc >= 0.6 else "❌"
        print(f"{status} Task {i+1}: {acc*100:.1f}%")

avg = np.mean(accs)
sep = "=" * 70
print(f"\n{sep}")
print(f"Average: {avg*100:.1f}%")
print(f"{sep}")

# Map the mean accuracy to a verdict, highest threshold first.
for threshold, verdict in ((0.85, "🎉 EXCELLENT!"),
                           (0.75, "✅ Good!"),
                           (0.60, "⚠️ Needs work")):
    if avg >= threshold:
        print(verdict)
        break
else:
    print("❌ Still broken")
