#!/usr/bin/env python3
"""
ULTRA-SIMPLE META-LEARNING
No fancy tricks, just what works
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Fixed random seeds for reproducibility
torch.manual_seed(999)
np.random.seed(999)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 5)
        )
    def forward(self, x):
        return self.layers(x)
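
# Example (shapes only, not executed here): SimpleNet maps 10-d inputs to 5 class logits,
#   logits = SimpleNet().to(device)(torch.randn(4, 10, device=device))  # -> shape (4, 5)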

def make_task():
    """Generate a random 5-way linear classification task on 10-d Gaussian inputs."""
    w = np.random.randn(10, 5) * 2    # task-specific weight matrix
    b = np.random.randn(5)            # task-specific bias
    X = np.random.randn(100, 10)      # 100 samples per task
    Y = (X @ w + b).argmax(axis=1)    # label = argmax of the linear scores
    return torch.FloatTensor(X).to(device), torch.LongTensor(Y).to(device)
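
# Each call returns X of shape (100, 10) and Y of shape (100,) with labels 0-4;
# because (w, b) are redrawn every call, the label rule changes from task to task,
# so a single fixed model cannot do well on a new task without some adaptation.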

print("\n" + "="*70)
print("TRAINING")
print("="*70)

model = SimpleNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.005)

# Train on MANY tasks
print("\nTraining on 2000 tasks...")
for task_num in range(2000):
    X, Y = make_task()
    
    # Split into train/test
    X_train, Y_train = X[:80], Y[:80]
    X_test, Y_test = X[80:], Y[80:]
    
    # A few gradient steps on this task's train split ("inner loop"-style,
    # though the shared model and Adam state are updated directly)
    model.train()
    for _ in range(5):
        pred = model(X_train)
        loss = F.cross_entropy(pred, Y_train)
        opt.zero_grad()
        loss.backward()
        opt.step()
    
    if task_num % 400 == 0:
        # Periodic check: accuracy on this task's held-out split
        model.eval()
        with torch.no_grad():
            pred = model(X_test)
            acc = (pred.argmax(1) == Y_test).float().mean()
            print(f"  Task {task_num}: Test Acc={acc.item()*100:.1f}%")

print("\n✅ Training complete")

# Final test
print("\n" + "="*70)
print("TESTING ON NEW TASKS")
print("="*70)

model.eval()
test_accs = []
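
# Optional baseline (a sketch; not called below): fine-tune a randomly initialized
# SimpleNet on the same support set for the same 30 steps and compare its query
# accuracy with the trained model's, to check how much the pretrained init helps.
def query_acc_from_random_init(X_support, Y_support, X_query, Y_query, steps=30, lr=0.05):
    baseline = SimpleNet().to(device)                      # fresh, untrained weights
    opt_b = torch.optim.SGD(baseline.parameters(), lr=lr)
    for _ in range(steps):
        loss = F.cross_entropy(baseline(X_support), Y_support)
        opt_b.zero_grad()
        loss.backward()
        opt_b.step()
    with torch.no_grad():
        return (baseline(X_query).argmax(1) == Y_query).float().mean().item()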

for i in range(20):
    X, Y = make_task()
    X_support, Y_support = X[:80], Y[:80]
    X_query, Y_query = X[80:], Y[80:]
    
    # Fine-tune a copy of the trained weights on the support set,
    # so each test task adapts independently of the others
    temp_model = SimpleNet().to(device)
    temp_model.load_state_dict(model.state_dict())
    temp_opt = torch.optim.SGD(temp_model.parameters(), lr=0.05)
    
    for _ in range(30):
        pred = temp_model(X_support)
        loss = F.cross_entropy(pred, Y_support)
        temp_opt.zero_grad()
        loss.backward()
        temp_opt.step()
    
    # Test on query
    with torch.no_grad():
        pred = temp_model(X_query)
        acc = (pred.argmax(1) == Y_query).float().mean().item()
        test_accs.append(acc)
        print(f"Task {i+1}: {acc*100:.1f}%")

avg = np.mean(test_accs)
print(f"\n{'='*70}")
print(f"AVERAGE: {avg*100:.1f}%")
print(f"{'='*70}")

if avg >= 0.80:
    print("🎉 SUCCESS!")
else:
    print("⚠️ Below the 80% success threshold")
