#!/usr/bin/env python3
"""
DIAGNOSTIC: Find out WHY meta-learning is broken
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys

# Environment setup: print banner, free cached GPU memory, fix RNG seeds,
# and select the compute device used by every test below.
print("=" * 70)
print("META-LEARNING DIAGNOSTIC")
print("=" * 70)

use_cuda = torch.cuda.is_available()

# Clear GPU memory completely before any allocation happens.
if use_cuda:
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    print("✅ Cleared CUDA cache")

# Seed every RNG the script relies on so runs are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if use_cuda:
    torch.cuda.manual_seed(SEED)
print("✅ Set random seeds")

device = torch.device('cuda' if use_cuda else 'cpu')
print(f"✅ Device: {device}")

# Test 1: build one synthetic 5-way classification task from a random
# linear map (labels are the argmax over 5 noisy linear scores) and check
# that all 5 classes actually appear.
print("\n" + "=" * 70)
print("TEST 1: Task Generation")
print("=" * 70)
w = np.random.randn(10, 5) * 2
b = np.random.randn(5)
X = np.random.randn(100, 10)
Y = (X @ w + b).argmax(axis=1)
print(f"Task shape: X={X.shape}, Y={Y.shape}")
print(f"Classes in Y: {np.unique(Y)}")
print(f"Class distribution: {np.bincount(Y)}")
n_classes = len(np.unique(Y))
if n_classes == 5:
    print("✅ Task generation working")
else:
    print("❌ PROBLEM: Task doesn't have all 5 classes!")
    sys.exit(1)

# Test 2 banner: single-task supervised learning sanity check (no meta-learning).
divider = "=" * 70
print("\n" + divider)
print("TEST 2: Basic Learning (no meta-learning)")
print(divider)

class SimpleNet(nn.Module):
    """Minimal MLP classifier: 10 inputs -> 64 hidden units (ReLU) -> 5 logits."""

    def __init__(self):
        super().__init__()
        hidden = 64
        self.net = nn.Sequential(
            nn.Linear(10, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 5),
        )

    def forward(self, x):
        """Return raw class logits for a (batch, 10) float input."""
        return self.net(x)

# Train a fresh SimpleNet on the Test 1 task with Adam; if even this fails,
# something fundamental (GPU, PyTorch, or task generation) is broken.
model = SimpleNet().to(device)
X_t = torch.FloatTensor(X).to(device)
Y_t = torch.LongTensor(Y).to(device)

opt = torch.optim.Adam(model.parameters(), lr=0.01)

print("Training for 50 steps...")
for step in range(50):
    opt.zero_grad()
    pred = model(X_t)
    loss = F.cross_entropy(pred, Y_t)
    loss.backward()
    opt.step()

    if step % 10 == 0:
        # NOTE: `pred` was computed before this step's update, so the
        # reported accuracy reflects the pre-update parameters.
        acc = (pred.argmax(dim=1) == Y_t).float().mean().item()
        print(f"  Step {step}: Loss={loss.item():.4f}, Acc={acc*100:.1f}%")

# Final accuracy also uses the last forward pass (pre-final-step parameters).
final_acc = (pred.argmax(dim=1) == Y_t).float().mean().item()
print(f"\nFinal accuracy: {final_acc*100:.1f}%")

if final_acc < 0.80:
    print("❌ PROBLEM: Simple model can't even learn one task!")
    print("   This means either:")
    print("   - GPU is broken")
    print("   - PyTorch is broken")
    print("   - Task generation is broken")
    sys.exit(1)
else:
    print("✅ Basic learning works")

# Test 3: Can meta-learning (first-order MAML) update the meta-model at all?
print("\n" + "="*70)
print("TEST 3: Meta-Learning on Single Task")
print("="*70)

from copy import deepcopy  # hoisted out of the epoch loop

meta_model = SimpleNet().to(device)
meta_opt = torch.optim.Adam(meta_model.parameters(), lr=0.001)

print("Meta-training for 20 epochs on ONE task...")
for epoch in range(20):
    # Same task every time (for testing)
    data, labels = X_t, Y_t
    support_x, support_y = data[:70], labels[:70]
    query_x, query_y = data[70:], labels[70:]

    # Inner loop: adapt a disposable copy of the meta-model on the support set.
    adapted = deepcopy(meta_model)
    inner_opt = torch.optim.SGD(adapted.parameters(), lr=0.01)

    for _ in range(5):
        pred = adapted(support_x)
        loss = F.cross_entropy(pred, support_y)
        inner_opt.zero_grad()
        loss.backward()
        inner_opt.step()

    # Clear the stale gradients left by the last inner SGD step, so the
    # meta backward below doesn't accumulate on top of them.
    for p in adapted.parameters():
        p.grad = None

    # Meta loss: query-set performance of the adapted copy.
    pred_query = adapted(query_x)
    meta_loss = F.cross_entropy(pred_query, query_y)
    meta_loss.backward()

    # BUG FIX: meta_loss.backward() populates gradients on `adapted` (the
    # deepcopy), NOT on meta_model, so meta_opt.step() was a silent no-op
    # and the meta-model never trained. First-order MAML: copy the adapted
    # copy's gradients back onto the meta-parameters, then step.
    meta_opt.zero_grad()
    for meta_p, fast_p in zip(meta_model.parameters(), adapted.parameters()):
        meta_p.grad = None if fast_p.grad is None else fast_p.grad.detach().clone()
    meta_opt.step()

    if epoch % 5 == 0:
        print(f"  Epoch {epoch}: Meta-loss={meta_loss.item():.4f}")

print(f"\nFinal meta-loss: {meta_loss.item():.4f}")

if meta_loss.item() > 2.0:
    print("❌ PROBLEM: Meta-learning not working even on one task!")
    print("   PyTorch installation might be corrupted")
else:
    print("✅ Meta-learning mechanism works")

# Closing summary banner.
bar = "=" * 70
print("\n" + bar)
print("DIAGNOSIS COMPLETE")
print(bar)
print("\nIf all tests passed, meta-learning SHOULD work.")
print("If any test failed, we know exactly what's broken.")
