#!/usr/bin/env python3
"""
FIX NUMERICAL REASONING
Simpler arithmetic problems
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Seed both RNGs so that NumPy-driven data sampling and PyTorch weight
# initialisation are reproducible from run to run.
torch.manual_seed(42)
np.random.seed(42)

# Prefer the GPU when one is present, but fall back to the CPU so the script
# still runs on machines without CUDA instead of failing at the first
# `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class ArithmeticNet(nn.Module):
    """Small MLP that classifies a 10-feature problem vector into 20 classes.

    The stack is three Linear+ReLU stages (10->128, 128->128, 128->64), with
    Dropout(0.1) after the first two, followed by a Linear head that emits
    20 logits (one per answer class 0-19).
    """

    def __init__(self):
        super().__init__()

        # (in_features, out_features, dropout probability or None)
        hidden_spec = (
            (10, 128, 0.1),
            (128, 128, 0.1),
            (128, 64, None),
        )

        stages = []
        for fan_in, fan_out, drop_p in hidden_spec:
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.ReLU())
            if drop_p is not None:
                stages.append(nn.Dropout(drop_p))

        # Classification head: one logit per answer in the 0-19 range.
        stages.append(nn.Linear(64, 20))

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the raw (unnormalised) class logits for input batch ``x``."""
        logits = self.net(x)
        return logits

def create_arithmetic_problem(batch_size=64, target_device=None):
    """Sample a batch of small arithmetic problems (A + B, A - B, A * B).

    Each sample draws an operation uniformly from add/sub/mul and two
    operands from 1..5 (``np.random.randint(1, 6)`` excludes the upper bound).

    Args:
        batch_size: Number of problems to generate (default 64). Zero yields
            empty tensors.
        target_device: Device the returned tensors are moved to. When
            ``None``, the module-level ``device`` is used.

    Returns:
        A tuple ``(X, Y)``:
        * ``X`` - float32 tensor of shape ``(batch_size, 10)``. Column 0 is
          ``a / 10``, column 1 is ``b / 10``, column 2 is ``op_code / 10``
          (add=0, sub=1, mul=2); the remaining columns are zero padding.
        * ``Y`` - int64 tensor of shape ``(batch_size,)`` with the answer as
          a class index in 0..19.

    Raises:
        ValueError: If ``batch_size`` is negative.

    Note:
        Subtraction is floored at 0, and results above 19 are clamped to 19,
        so the products 20 (4*5, 5*4) and 25 (5*5) all share label 19.
    """
    if batch_size < 0:
        raise ValueError(f"batch_size must be non-negative, got {batch_size}")
    if target_device is None:
        target_device = device  # fall back to the module-level default

    # Preallocate the whole batch as one float32 array; building a tensor
    # from a Python list of ndarrays is slow and triggers a PyTorch warning.
    features = np.zeros((batch_size, 10), dtype=np.float32)
    labels = []

    # The draw order (op, a, b) per sample is kept fixed so that seeded runs
    # produce the same problems.
    for row in features:
        # Random operation and small numbers
        op = np.random.choice(['add', 'sub', 'mul'])
        a = np.random.randint(1, 6)
        b = np.random.randint(1, 6)

        if op == 'add':
            result = a + b
            op_code = 0
        elif op == 'sub':
            result = max(0, a - b)  # No negatives
            op_code = 1
        else:  # mul
            result = a * b
            op_code = 2

        # Input: [a, b, op_code, padding...], each scaled by 1/10.
        # `row` is a view, so these writes land in `features`.
        row[0] = a / 10
        row[1] = b / 10
        row[2] = op_code / 10

        # Clamp result to the classifier's output range (0..19).
        labels.append(int(min(result, 19)))

    X = torch.from_numpy(features).to(target_device)
    Y = torch.tensor(labels, dtype=torch.long).to(target_device)
    return X, Y

print("="*70)
print("NUMERICAL REASONING (FIXED)")
print("="*70)

model = ArithmeticNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001)

# Each "epoch" here is a single optimisation step on one freshly sampled
# batch; there is no fixed dataset.
print("\nTraining (300 epochs)...")
model.train()  # make sure the Dropout layers are active while fitting
for epoch in range(300):
    X, Y = create_arithmetic_problem()

    pred = model(X)
    loss = F.cross_entropy(pred, Y)

    opt.zero_grad()
    loss.backward()
    opt.step()

    if epoch % 60 == 0:
        # Progress report only: accuracy on the current training batch,
        # computed from the same forward pass (dropout still active).
        acc = (pred.argmax(1) == Y).float().mean()
        print(f"  Epoch {epoch}: Loss={loss.item():.3f}, Acc={acc.item()*100:.1f}%")

# Test
print("\nTesting...")
# Switch to inference mode: without this the Dropout layers keep randomly
# zeroing activations, which makes the measured accuracy noisy and too low.
model.eval()
test_accs = []
with torch.no_grad():  # no gradients are needed for evaluation
    for _ in range(20):
        X, Y = create_arithmetic_problem()
        pred = model(X)
        acc = (pred.argmax(1) == Y).float().mean().item()
        test_accs.append(acc)

# Mean of the 20 per-batch accuracies.
avg_acc = np.mean(test_accs)
print(f"\nAverage Test Accuracy: {avg_acc*100:.1f}%")

if avg_acc >= 0.85:
    print("✅ NUMERICAL reasoning: FIXED!")
elif avg_acc >= 0.70:
    print("⚠️ Better, but needs more training")
else:
    print("❌ Still broken")
