#!/usr/bin/env python3
"""
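ULTRA-SIMPLE ARITHMETIC
Sanity check: can a small MLP learn A + B = C for operands A, B in [0, 5]?
The sum is predicted as one of 11 classes (0-10).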
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

torch.manual_seed(42)
np.random.seed(42)

# Prefer the GPU when one is present, but fall back to the CPU so the script
# still runs on machines without CUDA (a hard-coded 'cuda' device fails as
# soon as a module or tensor is moved onto it).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class AdditionNet(nn.Module):
    """Two-hidden-layer MLP that scores every possible sum of an operand pair.

    Input: a float tensor whose last dimension is 2 (the operands ``a``, ``b``).
    Output: 11 raw logits per example, one for each sum 0..10.
    """

    def __init__(self):
        super().__init__()
        hidden = 64
        # Layers are created in the same order as a literal Sequential(...)
        # call, so parameter initialization and state_dict keys are unchanged.
        layers = [nn.Linear(2, hidden), nn.ReLU()]
        layers += [nn.Linear(hidden, hidden), nn.ReLU()]
        layers.append(nn.Linear(hidden, 11))  # one logit per sum 0..10
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalized logits (no softmax) for the batch ``x``."""
        logits = self.net(x)
        return logits

def create_addition_problem(batch_size=64, max_operand=5, target_device=None):
    """Sample a random batch of A + B problems.

    Each row holds two operands ``a`` and ``b`` drawn uniformly from
    ``[0, max_operand]`` using NumPy's global RNG; the target is ``a + b``,
    used as a class index.

    Args:
        batch_size: Number of (a, b) pairs to draw. Defaults to 64.
        max_operand: Largest operand value, inclusive. Defaults to 5, so sums
            fall in 0..10 and match the 11 output classes of ``AdditionNet``.
            Larger values need a model with ``2 * max_operand + 1`` outputs.
        target_device: Device for the returned tensors. ``None`` (default)
            uses the module-level ``device``.

    Returns:
        ``(X, Y)`` where ``X`` is a float32 tensor of shape ``(batch_size, 2)``
        holding the operands and ``Y`` is an int64 tensor of shape
        ``(batch_size,)`` holding their sums.
    """
    # Look up the module-level ``device`` only when the caller does not pick
    # one, so the function is usable without that global being defined.
    dev = device if target_device is None else target_device

    # One vectorized draw instead of a per-sample Python loop; row i is (a_i, b_i).
    operands = np.random.randint(0, max_operand + 1, size=(batch_size, 2))
    sums = operands.sum(axis=1)

    X = torch.as_tensor(operands, dtype=torch.float32, device=dev)
    # cross_entropy expects int64 class indices regardless of NumPy's default int.
    Y = torch.as_tensor(sums, dtype=torch.long, device=dev)
    return X, Y

divider = "=" * 70
print(divider)
print("SIMPLE ADDITION TEST: A + B")
print(divider)

# Model lives on the selected device; optimized with Adam at lr=0.01.
model = AdditionNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.01)

print("\nTraining (200 epochs)...")
# NOTE: each "epoch" is a single optimizer step on one freshly sampled batch
# from create_addition_problem(), not a pass over a fixed dataset.
for epoch in range(200):
    X, Y = create_addition_problem()

    # Forward pass and loss; `pred` holds logits and is reused below, so the
    # logged accuracy reflects the model *before* this step's update.
    pred = model(X)
    loss = F.cross_entropy(pred, Y)

    opt.zero_grad()
    loss.backward()
    opt.step()

    # Report only every 40th step (0, 40, 80, 120, 160).
    if epoch % 40 != 0:
        continue
    correct = pred.argmax(dim=1) == Y
    acc = correct.float().mean()
    print(f"  Epoch {epoch}: Loss={loss.item():.3f}, Acc={acc.item()*100:.1f}%")

# Test
# Evaluate on 10 freshly sampled batches. Operands come from the same
# [0, 5] x [0, 5] space used during training (only 36 distinct pairs), so
# this measures fit on that space rather than generalization to new operands.
print("\nTesting on new problems...")
test_accs = []
with torch.no_grad():  # inference only; no autograd graph needed
    for _ in range(10):
        X, Y = create_addition_problem()
        pred = model(X)
        hits = pred.argmax(1).eq(Y)
        acc = hits.float().mean().item()
        test_accs.append(acc)

# Every batch has the same size, so the mean of per-batch accuracies is the
# overall accuracy.
avg_acc = np.mean(test_accs)
print(f"\nAverage: {avg_acc*100:.1f}%")

# Verdict: >= 95% passes, >= 80% is borderline, anything lower fails.
if avg_acc >= 0.95:
    verdict = ["✅ Addition works! Can extend to more operations."]
elif avg_acc >= 0.80:
    verdict = ["⚠️ Close, needs more training"]
else:
    verdict = [
        "❌ Even addition is broken - numerical reasoning may not be learnable this way",
        "\nRecommendation: Mark reasoning as PARTIAL (2/3 types working)",
    ]
for line in verdict:
    print(line)
