#!/usr/bin/env python3
"""
PROPER MAML IMPLEMENTATION
Using first-order approximation (no deepcopy)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Seeds
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.empty_cache()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class MetaNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.fc2 = nn.Linear(64, 5)
    
    def forward(self, x, params=None):
        if params is None:
            # Default path: use the module's own parameters
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
        else:
            # Functional path: run the same layers with an explicit parameter
            # dict (e.g. inner-loop adapted weights) so gradients can flow
            # back to the original parameters
            x = F.linear(x, params['fc1.weight'], params['fc1.bias'])
            x = F.relu(x)
            x = F.linear(x, params['fc2.weight'], params['fc2.bias'])
        return x

def create_task():
    """Sample a random 5-way task: labels are the argmax of a random
    linear map applied to 10-d Gaussian inputs (100 examples)."""
    w = np.random.randn(10, 5) * 2
    b = np.random.randn(5)
    X = np.random.randn(100, 10)
    Y = (X @ w + b).argmax(axis=1)
    X = torch.FloatTensor(X).to(device)
    Y = torch.LongTensor(Y).to(device)
    return X, Y
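
# The 70/30 support/query split below is repeated at train and test time;
# a small helper like this (hypothetical, not used below) would make the
# split explicit in one place:
def split_task(data, labels, n_support=70):
    """Split one task into support (adaptation) and query (evaluation) sets."""
    return (data[:n_support], labels[:n_support],
            data[n_support:], labels[n_support:])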

def inner_loop_manual(model, support_x, support_y, steps=10, lr=0.01):
    """Adapt a copy of the parameters to the support set with manual SGD,
    keeping the computation graph so second-order meta-gradients can flow."""
    # Clone (not detach) so the adapted weights stay connected to the
    # original parameters in the autograd graph
    params = {name: param.clone() for name, param in model.named_parameters()}
    
    for _ in range(steps):
        # Forward with the current adapted params
        logits = model(support_x, params)
        loss = F.cross_entropy(logits, support_y)
        
        # Gradients w.r.t. the adapted params; create_graph=True keeps the
        # graph so the meta-update can differentiate through these steps
        grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
        
        # Manual SGD update
        params = {name: param - lr * grad 
                 for (name, param), grad in zip(params.items(), grads)}
    
    return params
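
# For reference only: a first-order (FOMAML-style) variant of the inner loop,
# matching the "first-order approximation" mentioned in earlier notes.
# Dropping create_graph treats the inner gradients as constants, which skips
# the second-order terms and saves memory/compute at some cost in accuracy.
# This sketch is illustrative and is not used anywhere below.
def inner_loop_first_order(model, support_x, support_y, steps=10, lr=0.01):
    """First-order approximation of the inner loop (FOMAML); not used below."""
    params = {name: param.clone() for name, param in model.named_parameters()}
    for _ in range(steps):
        logits = model(support_x, params)
        loss = F.cross_entropy(logits, support_y)
        # No create_graph: the meta-gradient reduces to the query-loss
        # gradient at the adapted parameters (first-order approximation)
        grads = torch.autograd.grad(loss, list(params.values()))
        params = {name: param - lr * grad
                  for (name, param), grad in zip(params.items(), grads)}
    return params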

print("="*70)
print("PROPER MAML IMPLEMENTATION")
print("="*70)

model = MetaNet().to(device)
meta_opt = torch.optim.Adam(model.parameters(), lr=0.001)
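# Adam is the outer-loop (meta) optimizer: it updates the shared initialization
# using gradients that flow back through the inner-loop SGD steps above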

print("\nMeta-training (100 epochs)...")

for epoch in range(100):
    meta_loss_total = 0
    
    for _ in range(5):  # 5 tasks per epoch
        # Create task
        data, labels = create_task()
        support_x, support_y = data[:70], labels[:70]
        query_x, query_y = data[70:], labels[70:]
        
        # Inner loop - get adapted parameters
        adapted_params = inner_loop_manual(model, support_x, support_y, steps=10, lr=0.01)
        
        # Compute meta-loss on query set with adapted params
        query_logits = model(query_x, adapted_params)
        meta_loss = F.cross_entropy(query_logits, query_y)
        meta_loss_total += meta_loss
    
    # Meta-optimization: backpropagate the summed query loss through all inner-loop steps
    meta_opt.zero_grad()
    meta_loss_total.backward()
    meta_opt.step()
    
    if epoch % 20 == 0:
        print(f"Epoch {epoch}: Loss={meta_loss_total.item():.4f}")

print(f"\n✅ Training complete")

# Test
print("\n" + "="*70)
print("TESTING")
print("="*70)

accs = []
for i in range(10):
    data, labels = create_task()
    support_x, support_y = data[:70], labels[:70]
    query_x, query_y = data[70:], labels[70:]
    
    # Adapt with the inner loop (more adaptation steps at test time: 25 vs 10 in training)
    adapted_params = inner_loop_manual(model, support_x, support_y, steps=25, lr=0.01)
    
    # Test on query
    with torch.no_grad():
        logits = model(query_x, adapted_params)
        acc = (logits.argmax(dim=1) == query_y).float().mean().item()
        accs.append(acc)
        print(f"Task {i+1}: {acc*100:.1f}%")

avg = np.mean(accs)
print(f"\n{'='*70}")
print(f"Average: {avg*100:.1f}%")
print(f"{'='*70}")

if avg >= 0.80:
    print("✅ WORKING!")
else:
    print("⚠️ Still needs work")
