#!/usr/bin/env python3
"""
Fix meta-learning by using DIVERSE tasks
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy

# Seed every RNG the script touches so runs are reproducible.
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.empty_cache()  # start from a clean CUDA allocator

# Pick the compute device once and announce it.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"Device: {device}\n")

class MetaNet(nn.Module):
    """Small MLP for 5-way classification: 10 features -> 64 hidden (ReLU) -> 5 logits."""

    def __init__(self):
        super().__init__()
        # Layers listed explicitly, then splatted into a Sequential;
        # creation order matches the seeded-initialization expectations.
        layers = [
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 5),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for a batch of 10-dim inputs."""
        return self.net(x)

def create_task(task_device=None):
    """Sample a fresh random 5-way linear classification task.

    Draws a random linear map (w, b) and labels 100 random 10-dim inputs by
    the argmax logit, so every call yields a new task from the same family.

    Args:
        task_device: torch device for the returned tensors. Defaults to the
            module-level ``device`` (backward compatible with the original
            zero-argument call).

    Returns:
        Tuple ``(X, Y)``: float32 tensor of shape (100, 10) and int64 label
        tensor of shape (100,) with values in [0, 4], both on ``task_device``.
    """
    if task_device is None:
        task_device = device  # fall back to the script-wide device
    w = np.random.randn(10, 5) * 2
    b = np.random.randn(5)
    X = np.random.randn(100, 10)
    Y = (X @ w + b).argmax(axis=1)
    return (
        torch.FloatTensor(X).to(task_device),
        torch.LongTensor(Y).to(task_device),
    )

print("="*70)
print("FIXED META-LEARNING TEST")
print("="*70)

model = MetaNet().to(device)
meta_opt = torch.optim.Adam(model.parameters(), lr=0.001)

print("\nMeta-training with DIFFERENT tasks each epoch...")
losses = []

# First-order MAML. BUG FIX: the original adapted a deepcopy of the model and
# called backward() on the summed meta-loss, which deposited gradients only on
# the throwaway copies — model.parameters() never received a .grad, so
# meta_opt.step() was a no-op and the meta-model never trained. We now run
# backward per task and transplant each copy's gradients onto the
# meta-parameters (this also frees each task's graph immediately instead of
# keeping all 5 alive until the end of the epoch).
for epoch in range(100):
    epoch_loss = 0.0  # plain float accumulator; no graph retained
    meta_opt.zero_grad()

    # 5 freshly sampled tasks per meta-update
    for _ in range(5):
        data, labels = create_task()
        support_x, support_y = data[:70], labels[:70]
        query_x, query_y = data[70:], labels[70:]

        # Inner loop: adapt a disposable copy on the support split.
        adapted = deepcopy(model)
        inner_opt = torch.optim.SGD(adapted.parameters(), lr=0.01)
        for _ in range(10):
            loss = F.cross_entropy(adapted(support_x), support_y)
            inner_opt.zero_grad()
            loss.backward()
            inner_opt.step()

        # Outer loss on the held-out query split, backpropagated through the
        # adapted copy only (first-order approximation).
        meta_loss = F.cross_entropy(adapted(query_x), query_y)
        meta_loss.backward()

        # Transplant the first-order gradients onto the meta-parameters,
        # accumulating across the 5 tasks of this epoch.
        for meta_p, task_p in zip(model.parameters(), adapted.parameters()):
            if task_p.grad is not None:
                meta_p.grad = (task_p.grad if meta_p.grad is None
                               else meta_p.grad + task_p.grad)

        epoch_loss += meta_loss.item()

    # Meta update with the gradients accumulated over all tasks
    meta_opt.step()

    avg_loss = epoch_loss / 5
    losses.append(avg_loss)

    if epoch % 20 == 0:
        print(f"Epoch {epoch}: Avg Loss={avg_loss:.4f}")

print(f"\nFinal loss: {losses[-1]:.4f}")
print(f"Starting loss: {losses[0]:.4f}")
print(f"Loss decreased by: {losses[0] - losses[-1]:.4f}")

# Evaluate how well the meta-trained weights adapt to brand-new tasks.
print("\n" + "="*70)
print("TESTING ADAPTATION")
print("="*70)

accuracies = []
for task_idx in range(10):
    xs, ys = create_task()
    support_x, support_y = xs[:70], ys[:70]
    query_x, query_y = xs[70:], ys[70:]

    # Clone the meta-trained weights into a fresh network for fine-tuning.
    learner = MetaNet().to(device)
    learner.load_state_dict(model.state_dict())

    # Fine-tune on the support split only.
    finetune_opt = torch.optim.SGD(learner.parameters(), lr=0.01)
    for _ in range(25):
        step_loss = F.cross_entropy(learner(support_x), support_y)
        finetune_opt.zero_grad()
        step_loss.backward()
        finetune_opt.step()

    # Score on the held-out query split.
    with torch.no_grad():
        acc = (learner(query_x).argmax(dim=1) == query_y).float().mean().item()
    accuracies.append(acc)
    print(f"Task {task_idx+1}: {acc*100:.1f}%")

avg = np.mean(accuracies)
print(f"\n{'='*70}")
print(f"Average: {avg*100:.1f}%")
print(f"{'='*70}")

# Verdict thresholds on mean query accuracy across the 10 test tasks.
if avg >= 0.85:
    print("✅ WORKING! Meta-learning is fixed!")
elif avg >= 0.75:
    print("⚠️ Better, but needs more training")
else:
    print("❌ Still broken")
