#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Seed the NumPy and PyTorch RNGs so repeated runs draw the same weights and data.
np.random.seed(123)
torch.manual_seed(123)

# Prefer the GPU, but fall back to CPU so this check still runs on machines
# without CUDA. Constructing torch.device('cuda') alone would succeed there,
# and the first .to(device) would then fail.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Testing on {device}\n")

# Small MLP: 10 input features -> 32 ReLU hidden units -> 5 class logits.
# Layers are built in the same order as before, so parameter initialisation
# consumes the seeded torch RNG identically.
net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 5))
net = net.to(device)

# Synthetic single task: 100 Gaussian samples, each labelled with a class in
# [0, 5) drawn independently of its features. The features are drawn before
# the labels, which keeps the RNG order unchanged.
X, y = (
    torch.randn(100, 10, device=device),
    torch.randint(0, 5, (100,), device=device),
)

# Plain full-batch training with Adam. Every 20th step reports the loss and
# accuracy of that step's forward pass (i.e. before its parameter update).
opt = torch.optim.Adam(net.parameters(), lr=0.01)

print("Training on one task for 100 steps:")
for step in range(100):
    opt.zero_grad()
    logits = net(X)
    loss = F.cross_entropy(logits, y)
    loss.backward()
    opt.step()

    if step % 20:
        continue
    acc = (logits.argmax(1) == y).float().mean()
    print(f"  Step {step}: Loss={loss.item():.3f}, Acc={acc.item()*100:.1f}%")

# Final check: accuracy on the training data after optimisation. No gradients
# are needed here, so skip building an autograd graph for this forward pass.
with torch.no_grad():
    final_acc = (net(X).argmax(1) == y).float().mean().item()
print(f"\nFinal: {final_acc*100:.1f}%")

# NOTE(review): the labels are random, so reaching >90% means the network has
# memorised the training set. Falling short could be due to too few steps or
# too little capacity rather than a broken install. Treat the verdict below
# as a heuristic.
if final_acc > 0.9:
    print("✅ Basic learning works")
else:
    print("❌ Even basic learning is broken - PyTorch installation issue")
