#!/usr/bin/env python3
"""
ADVANCED REASONING
Multi-step inference, chain-of-thought, logical deduction
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Seed both random number generators up front so model initialisation
# (torch) and batch sampling (numpy) are reproducible run to run.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(42)

# Prefer the GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}\n")

class ReasoningNet(nn.Module):
    """Encode a problem vector, refine it with repeated GRU steps, then classify.

    The same encoded problem is fed into a single ``nn.GRUCell`` on every
    reasoning step, so ``n_steps`` sets how many times the hidden state is
    updated before it is decoded into class logits.

    Args:
        input_dim: Width of the input problem vector.
        embed_dim: Width of the encoder output (the GRU cell's input size).
        hidden_dim: Width of the GRU hidden state.
        n_classes: Number of logits produced by the decoder.

    The defaults reproduce the original fixed architecture
    (20 -> 128 -> 64 -> GRU(128) -> 64 -> 10), with modules created in the
    same order, so existing ``ReasoningNet()`` calls and checkpoints still work.
    """

    def __init__(self, input_dim=20, embed_dim=64, hidden_dim=128, n_classes=10):
        super().__init__()
        # Stored so forward() sizes the initial state from the same value the
        # GRU cell was built with, instead of a second hard-coded literal.
        self.hidden_dim = hidden_dim

        # Encoder for problem representation
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, embed_dim),
        )

        # Reasoning steps (recurrent processing)
        self.reasoning_cell = nn.GRUCell(embed_dim, hidden_dim)

        # Decoder for answer
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, n_classes),
        )

    def forward(self, x, n_steps=5):
        """Return class logits for a batch of problems.

        Args:
            x: Problem batch; the first dimension is used as the batch size,
               so presumably shape ``(batch, input_dim)`` — confirm for callers.
            n_steps: Number of GRU updates applied before decoding. With 0
               the decoder sees the all-zero initial state.

        Returns:
            Logits with one row per problem and ``n_classes`` columns.
        """
        encoded = self.encoder(x)

        # Zero initial state matching the encoder output's dtype and device,
        # so the model also runs after .double() / .half().
        hidden = encoded.new_zeros(x.size(0), self.hidden_dim)
        for _ in range(n_steps):
            hidden = self.reasoning_cell(encoded, hidden)

        return self.decoder(hidden)

def _transitive_example():
    """Return ``(features, label)`` for one transitive-comparison problem."""
    a, b, c = np.random.randint(1, 10, 3)
    a, b, c = sorted([a, b, c], reverse=True)  # so A >= B >= C

    # Input slots: premises (A, B) and (B, C), then the query (A, C).
    x = np.zeros(20)
    x[0], x[1] = a, b  # A ? B
    x[5], x[6] = b, c  # B ? C
    x[10], x[11] = a, c  # Query: A ? C

    # Answer: 0 = greater, 1 = less, 2 = equal. Values are drawn with
    # replacement, so A == C is possible (when A == B == C); derive the label
    # from the values instead of assuming A > C. A < C cannot occur because
    # of the descending sort.
    # NOTE(review): because of that sort the label is 0 except on ties, so
    # this task is nearly constant-label — consider randomising query order.
    y = 0 if a > c else 2
    return x, y


def _logical_example():
    """Return ``(features, label)`` for one boolean-combination problem."""
    a, b, c = np.random.randint(0, 2, 3)

    x = np.zeros(20)
    x[0], x[1], x[2] = a, b, c
    x[5] = a & b  # A AND B
    x[6] = b | c  # B OR C

    return x, a & c  # Answer: A AND C


def _numerical_example():
    """Return ``(features, label)`` for one multi-step arithmetic problem."""
    a, b, c, d = np.random.randint(1, 10, 4)

    x = np.zeros(20)
    x[0], x[1], x[2], x[3] = a, b, c, d

    # (A + B) * C - D can be negative (e.g. (1 + 1) * 1 - 9); numpy's % takes
    # the divisor's sign, so the label still lands in 0..9.
    return x, ((a + b) * c - d) % 10


# Dispatch table: problem type name -> single-example generator.
_PROBLEM_GENERATORS = {
    'transitive': _transitive_example,
    'logical': _logical_example,
    'numerical': _numerical_example,
}


def create_reasoning_problem(problem_type='transitive', batch_size=64, *, target_device=None):
    """Sample a batch of toy reasoning problems.

    Problem types:
    - ``'transitive'``: given A >= B >= C, classify the relation A ? C
      (0 = greater, 1 = less, 2 = equal).
    - ``'logical'``: bits A, B, C plus A AND B and B OR C; label is A AND C.
    - ``'numerical'``: digits A, B, C, D; label is ((A + B) * C - D) mod 10.

    Args:
        problem_type: One of ``'transitive'``, ``'logical'``, ``'numerical'``.
        batch_size: Number of problems to sample (default 64).
        target_device: Device for the returned tensors; defaults to the
            module-level ``device``.

    Returns:
        ``(X, Y)``: a float32 tensor of shape ``(batch_size, 20)`` and an
        int64 tensor of shape ``(batch_size,)``.

    Raises:
        ValueError: if ``problem_type`` is not recognised.
    """
    try:
        make_example = _PROBLEM_GENERATORS[problem_type]
    except KeyError:
        raise ValueError(
            f"unknown problem_type {problem_type!r}; "
            f"expected one of {sorted(_PROBLEM_GENERATORS)}"
        ) from None

    if target_device is None:
        target_device = device

    # Fill preallocated arrays rather than converting a Python list of
    # ndarrays, which torch handles slowly.
    X = np.zeros((batch_size, 20), dtype=np.float32)
    Y = np.zeros(batch_size, dtype=np.int64)
    for i in range(batch_size):
        X[i], Y[i] = make_example()

    return torch.from_numpy(X).to(target_device), torch.from_numpy(Y).to(target_device)

def _print_banner(title, lead=""):
    """Print *title* between two 70-char '=' rules; *lead* prefixes the first rule."""
    rule = "=" * 70
    print(lead + rule)
    print(title)
    print(rule)


_print_banner("ADVANCED REASONING")

# Train a fresh model for each task family, then report held-out accuracy.
for problem_type in ('transitive', 'logical', 'numerical'):
    _print_banner(f"TRAINING: {problem_type.upper()} REASONING", lead="\n")

    model = ReasoningNet().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training: one freshly sampled batch per epoch.
    for epoch in range(200):
        X, Y = create_reasoning_problem(problem_type)
        opt.zero_grad()
        pred = model(X, n_steps=5)
        loss = F.cross_entropy(pred, Y)
        loss.backward()
        opt.step()

        # Only log every 50th epoch.
        if epoch % 50:
            continue
        acc = (pred.argmax(1) == Y).float().mean()
        print(f"  Epoch {epoch}: Loss={loss.item():.3f}, Acc={acc.item()*100:.1f}%")

    # Evaluation: 10 new batches with autograd disabled.
    print(f"\nTesting {problem_type} reasoning...")
    test_accs = []
    with torch.no_grad():
        for _ in range(10):
            X, Y = create_reasoning_problem(problem_type)
            pred = model(X, n_steps=5)
            acc = (pred.argmax(1) == Y).float().mean().item()
            test_accs.append(acc)

    avg_acc = np.mean(test_accs)
    print(f"Average Test Accuracy: {avg_acc*100:.1f}%")

    label = problem_type.upper()
    if avg_acc >= 0.80:
        print(f"✅ {label} reasoning: WORKING!")
    else:
        print(f"⚠️ {label} reasoning: needs improvement")

_print_banner("ADVANCED REASONING TEST COMPLETE", lead="\n")
