#!/usr/bin/env python3
"""
Meta-learning on visual pattern tasks instead of random linear tasks
Much simpler and more stable
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Reproducibility: fix RNG seeds for both torch and numpy.
torch.manual_seed(42)
np.random.seed(42)

# Fall back to CPU when CUDA is unavailable so the script runs anywhere;
# the original hard-coded 'cuda' and crashed on CPU-only machines at .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class ConvNet(nn.Module):
    """Small CNN classifier for 8x8 single-channel pattern images.

    One conv + pool stage followed by a two-layer MLP head producing
    5 class logits. Layer order (and hence state_dict keys) matches the
    checkpoint layout used throughout this script.
    """

    def __init__(self):
        super().__init__()
        # 8x8 input -> conv (same padding) -> 2x2 pool -> 32 maps of 4x4,
        # flattened to 32*4*4 features for the classifier head.
        layers = [
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(32 * 4 * 4, 64),
            nn.ReLU(),
            nn.Linear(64, 5),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits of shape (batch, 5) for input (batch, 1, 8, 8)."""
        return self.net(x)

def create_visual_task(n_per_class=20, noise=0.3):
    """Build a 5-way classification task of noisy 8x8 shape images.

    Classes: 0=horizontal line, 1=vertical line, 2=main diagonal,
    3=filled square, 4=cross. Each image is a clean binary pattern plus
    additive Gaussian noise, so every call yields a fresh "task".

    Args:
        n_per_class: examples generated per class (default 20, i.e. 100 total).
        noise: std-dev of the additive Gaussian noise (default 0.3).

    Returns:
        (X, Y): shuffled tensors on the module-level ``device``;
        X has shape (5*n_per_class, 1, 8, 8), Y holds class indices.
    """
    images = []
    labels = []

    for cls in range(5):
        for _ in range(n_per_class):
            img = np.zeros((8, 8))
            # Different pattern for each class
            if cls == 0:        # Horizontal line
                img[4, :] = 1
            elif cls == 1:      # Vertical line
                img[:, 4] = 1
            elif cls == 2:      # Diagonal
                np.fill_diagonal(img, 1)
            elif cls == 3:      # Square
                img[2:6, 2:6] = 1
            else:               # Cross
                img[4, :] = 1
                img[:, 4] = 1

            # Add noise so examples within a class differ.
            img += np.random.randn(8, 8) * noise
            images.append(img)
            labels.append(cls)

    # Stack into one ndarray first: constructing a tensor directly from a
    # list of numpy arrays is extremely slow and warns in recent torch.
    X = torch.from_numpy(np.stack(images)).float().unsqueeze(1).to(device)
    Y = torch.tensor(labels, dtype=torch.long, device=device)

    # Shuffle so support/query splits mix all classes.
    perm = torch.randperm(len(Y))
    return X[perm], Y[perm]

print("="*70)
print("META-LEARNING ON VISUAL PATTERNS")
print("="*70)

# Single shared model and optimizer reused across all sampled tasks.
# NOTE(review): there is no inner/outer (MAML-style) optimization split
# here — each "epoch" simply continues Adam updates on a freshly sampled
# task's support set, so this is closer to sequential multi-task training
# than true meta-learning; confirm this matches the intended experiment.
model = ConvNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001)

print("\nTraining (200 epochs)...")
for epoch in range(200):
    # Sample a fresh task (new noise realizations of the 5 patterns).
    X, Y = create_visual_task()
    
    # Split the 100 examples 80/20 into support (train) / query (held-out).
    X_support, Y_support = X[:80], Y[:80]
    X_query, Y_query = X[80:], Y[80:]
    
    # Train on support: 10 full-batch gradient steps on this task.
    for _ in range(10):
        pred = model(X_support)
        loss = F.cross_entropy(pred, Y_support)
        opt.zero_grad()
        loss.backward()
        opt.step()
    
    if epoch % 40 == 0:
        # Periodically report accuracy on the current task's query split.
        with torch.no_grad():
            pred = model(X_query)
            acc = (pred.argmax(1) == Y_query).float().mean()
            print(f"Epoch {epoch}: Acc={acc.item()*100:.1f}%")

print("\n✅ Training complete")

# Test: evaluate few-shot adaptation from the trained weights on 10 new tasks.
print("\n" + "="*70)
print("TESTING")
print("="*70)

accs = []
for i in range(10):
    # Fresh task, same 80/20 support/query split as during training.
    X, Y = create_visual_task()
    X_support, Y_support = X[:80], Y[:80]
    X_query, Y_query = X[80:], Y[80:]
    
    # Quick adaptation: copy the trained weights into a scratch model so
    # the original `model` is left untouched between tasks.
    test_model = ConvNet().to(device)
    test_model.load_state_dict(model.state_dict())
    # Higher LR than training (0.01 vs 0.001) for fast adaptation in few steps.
    test_opt = torch.optim.Adam(test_model.parameters(), lr=0.01)
    
    # 20 full-batch gradient steps on the support set.
    for _ in range(20):
        pred = test_model(X_support)
        loss = F.cross_entropy(pred, Y_support)
        test_opt.zero_grad()
        loss.backward()
        test_opt.step()
    
    # Measure post-adaptation accuracy on the held-out query set.
    with torch.no_grad():
        pred = test_model(X_query)
        acc = (pred.argmax(1) == Y_query).float().mean().item()
        accs.append(acc)
        print(f"Task {i+1}: {acc*100:.1f}%")

# Summarize across tasks with a simple threshold-based verdict.
avg = np.mean(accs)
print(f"\n{'='*70}")
print(f"AVERAGE: {avg*100:.1f}%")
if avg >= 0.90:
    print("🎉 EXCELLENT!")
elif avg >= 0.70:
    print("✅ Good!")
else:
    print("⚠️ Needs work")
