#!/usr/bin/env python3
"""
COMPLETE ARITHMETIC REASONING
Addition, Subtraction, Multiplication with small numbers
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Seed both random sources used by this script: torch and numpy.
torch.manual_seed(42)
np.random.seed(42)

# Prefer the GPU, but fall back to the CPU when CUDA is unavailable. A
# hard-coded 'cuda' device makes every later `.to(device)` call raise on
# CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class ArithmeticNet(nn.Module):
    """Small fully connected classifier for ``a op b`` problems.

    The input is a 3-feature row ``[a, b, operation]`` and the output is a
    vector of 26 logits, one per possible result 0-25.
    """

    def __init__(self):
        super().__init__()
        # Feature widths from input to output: 3 -> 128 -> 128 -> 26.
        widths = (3, 128, 128, 26)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # The output layer emits raw logits, so drop the trailing activation.
        stack.pop()
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the 26 class logits for each row of ``x``."""
        logits = self.net(x)
        return logits

def create_arithmetic_problem(batch_size=64, target_device=None):
    """Sample a random batch of small-integer arithmetic problems.

    Each problem is ``a op b`` where ``a`` and ``b`` are drawn from 0..5 and
    ``op`` is drawn from 0 (add), 1 (subtract) or 2 (multiply). Subtraction
    is clamped at zero, so every label lies in 0..25.

    Args:
        batch_size: Number of problems to sample. Defaults to 64.
        target_device: Device for the returned tensors. When ``None`` the
            module-level ``device`` is used.

    Returns:
        A tuple ``(X, Y)``. ``X`` is a float32 tensor of shape
        ``(batch_size, 3)`` holding ``[a, b, op]`` rows; ``Y`` is an int64
        tensor of shape ``(batch_size,)`` holding the results.

    Raises:
        ValueError: If ``batch_size`` is negative (raised by the NumPy
            allocation).
    """
    if target_device is None:
        target_device = device  # fall back to the script-wide device

    # Pre-allocate so that an empty batch still has shape (0, 3) / (0,).
    operands = np.zeros((batch_size, 3), dtype=np.int64)
    results = np.zeros(batch_size, dtype=np.int64)

    for row in range(batch_size):
        # Draw order (a, b, op) is kept per sample so that a seeded NumPy
        # RNG yields the same sequence of problems.
        a = np.random.randint(0, 6)
        b = np.random.randint(0, 6)
        op = np.random.randint(0, 3)  # 0=add, 1=sub, 2=mul

        if op == 0:  # Addition
            result = a + b
        elif op == 1:  # Subtraction
            result = max(0, a - b)  # No negatives
        else:  # Multiplication
            result = a * b

        operands[row] = (a, b, op)
        results[row] = result

    X = torch.tensor(operands, dtype=torch.float32, device=target_device)
    Y = torch.tensor(results, dtype=torch.long, device=target_device)
    return X, Y

# Banner: rule, title, rule, each on its own line.
print("=" * 70, "COMPLETE ARITHMETIC: +, -, ×", "=" * 70, sep="\n")

# Build the classifier, move it to the selected device, and optimise all of
# its parameters with Adam.
model = ArithmeticNet()
model = model.to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.005)

# Each "epoch" is a single optimisation step on one freshly sampled batch.
print("\nTraining (300 epochs)...")
for epoch in range(300):
    X, Y = create_arithmetic_problem()

    # Clear stale gradients, then forward / backward / update.
    opt.zero_grad()
    pred = model(X)
    loss = F.cross_entropy(pred, Y)
    loss.backward()
    opt.step()

    # Progress is reported only on every 50th step.
    if epoch % 50 != 0:
        continue
    # Accuracy of the pre-update predictions on this training batch.
    acc = pred.argmax(1).eq(Y).float().mean()
    print(f"  Epoch {epoch}: Loss={loss.item():.3f}, Acc={acc.item()*100:.1f}%")

# Evaluate on 20 freshly sampled batches and average the per-batch accuracy.
print("\nTesting...")
test_accs = []
with torch.no_grad():  # inference only: no autograd graph is needed
    for _ in range(20):
        X, Y = create_arithmetic_problem()
        pred = model(X)
        acc = pred.argmax(1).eq(Y).float().mean().item()
        test_accs.append(acc)

avg_acc = np.mean(test_accs)
print(f"\nAverage: {avg_acc*100:.1f}%")

# Verdict tiers: >= 90% is complete, >= 80% is close, anything lower needs work.
print(
    "✅ NUMERICAL REASONING: COMPLETE!" if avg_acc >= 0.90
    else "✅ Good! Close to target." if avg_acc >= 0.80
    else "⚠️ Needs more work"
)

# Persist the weights only when the average accuracy reaches the 85% bar.
if avg_acc >= 0.85:
    torch.save(model.state_dict(), 'arithmetic_model.pth')
    print("💾 Model saved!")
