#!/usr/bin/env python3
"""
EXACT CODE THAT GOT 87% - NO CHANGES
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy

class MetaNetFinal(nn.Module):
    """Small 5-way classifier head over 10-dim inputs: Linear -> ReLU -> Linear."""

    def __init__(self):
        super().__init__()
        # Keep the layers in a positional Sequential so state_dict keys
        # stay "net.0.*" / "net.2.*" for load_state_dict round-trips.
        layers = [
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 5),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits of shape (..., 5)."""
        logits = self.net(x)
        return logits

class MetaLearnerFinal:
    """First-order MAML trainer for synthetic 5-way linear classification tasks.

    Each task is a random noiseless linear map from R^10 to 5 logits; 100
    samples are drawn per task and split 70/30 into a support set (inner
    adaptation) and a query set (meta-objective).
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = MetaNetFinal().to(self.device)

    def create_good_task(self):
        """Sample one task.

        Returns:
            (X, Y): X is a (100, 10) float tensor, Y a (100,) long tensor of
            labels in [0, 5), both on self.device.
        """
        w = np.random.randn(10, 5) * 2
        b = np.random.randn(5)
        X = np.random.randn(100, 10)
        # Labels are the argmax of a noiseless linear model, so every task
        # is linearly separable by construction.
        Y = (X @ w + b).argmax(axis=1)
        X = torch.FloatTensor(X).to(self.device)
        Y = torch.LongTensor(Y).to(self.device)
        return X, Y

    def meta_train(self, n_epochs=400):
        """Meta-train self.model (10 tasks/epoch, 10 inner SGD steps per task).

        Bug fixed: the original summed query losses computed through
        deepcopy'd task networks and called backward() on the sum, so all
        gradients landed on the throwaway copies — self.model never received
        a gradient and meta_opt.step() was a no-op.  We now explicitly move
        each task's query-loss gradients from the adapted copy onto the meta
        parameters (the first-order MAML approximation) before stepping.
        """
        meta_opt = torch.optim.Adam(self.model.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(meta_opt, step_size=100, gamma=0.5)
        for epoch in range(n_epochs):
            epoch_loss = 0.0
            meta_opt.zero_grad()
            for _ in range(10):
                data, labels = self.create_good_task()
                support_x, support_y = data[:70], labels[:70]
                query_x, query_y = data[70:], labels[70:]
                # Adapt a disposable copy on the support set; the meta model
                # itself is only updated via the accumulated meta-gradients.
                adapted = deepcopy(self.model)
                inner_opt = torch.optim.SGD(adapted.parameters(), lr=0.01, momentum=0.9)
                for _ in range(10):
                    loss = F.cross_entropy(adapted(support_x), support_y)
                    inner_opt.zero_grad()
                    loss.backward()
                    inner_opt.step()
                # Meta-objective: post-adaptation loss on the held-out query set.
                meta_loss = F.cross_entropy(adapted(query_x), query_y)
                grads = torch.autograd.grad(meta_loss, list(adapted.parameters()))
                with torch.no_grad():
                    # First-order MAML: treat the adapted-parameter gradients
                    # as gradients for the corresponding meta parameters.
                    for p, g in zip(self.model.parameters(), grads):
                        p.grad = g.clone() if p.grad is None else p.grad + g
                epoch_loss += meta_loss.item()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            meta_opt.step()
            scheduler.step()
            if (epoch + 1) % 80 == 0:
                print(f"Epoch {epoch+1}: Loss: {epoch_loss:.4f}")
        print("✅ Meta-training complete")

    def test_adaptation(self, n_tests=10):
        """Adapt a fresh copy of the meta-model to n_tests new tasks.

        Returns:
            List of query-set accuracies (floats in [0, 1]), one per task,
            each after 25 SGD fine-tuning steps on that task's support set.
        """
        accs = []
        for _ in range(n_tests):
            data, labels = self.create_good_task()
            support_x, support_y = data[:70], labels[:70]
            query_x, query_y = data[70:], labels[70:]
            # Fine-tune a copy so the meta parameters stay untouched.
            adapted = MetaNetFinal().to(self.device)
            adapted.load_state_dict(self.model.state_dict())
            opt = torch.optim.SGD(adapted.parameters(), lr=0.01, momentum=0.9)
            for _ in range(25):
                loss = F.cross_entropy(adapted(support_x), support_y)
                opt.zero_grad()
                loss.backward()
                opt.step()
            with torch.no_grad():
                pred = adapted(query_x)
                accs.append((pred.argmax(dim=1) == query_y).float().mean().item())
        return accs

def main():
    """Meta-train, then report few-shot adaptation accuracy on 10 fresh tasks."""
    ml = MetaLearnerFinal()
    ml.meta_train(n_epochs=400)
    print("\nTesting on 10 new tasks...")
    accs = ml.test_adaptation(n_tests=10)
    for i, acc in enumerate(accs, 1):
        print(f"Task {i}: {acc*100:.1f}%")
    avg = np.mean(accs)
    std = np.std(accs)
    print(f"\nAverage: {avg*100:.1f}% (±{std*100:.1f}%)")


# Guard the entry point so importing this module does not kick off a
# 400-epoch training run as a side effect.
if __name__ == "__main__":
    main()
