#!/usr/bin/env python3
"""
COMMON SENSE REASONING
Understanding everyday physics, causality, temporal logic, and world knowledge
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Fixed seeds for the torch RNG (parameter init, dropout) and the numpy RNG
# (used by the synthetic task generators in this file).
torch.manual_seed(42)
np.random.seed(42)

# Use the GPU when one is actually usable; otherwise fall back to the CPU so
# the script still runs on machines without CUDA instead of failing on the
# first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class CommonSenseNet(nn.Module):
    """Shared situation encoder feeding three task-specific classifier heads.

    ``encoder`` maps a 30-feature situation vector to a 128-dimensional
    embedding.  ``physics_head`` classifies a single embedding into 2 classes
    (plausible/implausible).  ``causal_head`` (2 classes: causes/doesn't
    cause) and ``temporal_head`` (3 classes: before/after/simultaneous)
    classify the concatenation of two embeddings produced by the same encoder.
    """

    # Tasks whose head consumes the embeddings of two situations.
    _PAIR_TASKS = ('causal', 'temporal')

    def __init__(self):
        super().__init__()
        # Situation encoder
        self.encoder = nn.Sequential(
            nn.Linear(30, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128)
        )

        # Physics reasoning head
        self.physics_head = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 2)  # plausible/implausible
        )

        # Causal reasoning head (input: two concatenated embeddings)
        self.causal_head = nn.Sequential(
            nn.Linear(128 * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 2)  # causes/doesn't cause
        )

        # Temporal reasoning head (input: two concatenated embeddings)
        self.temporal_head = nn.Sequential(
            nn.Linear(128 * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 3)  # before/after/simultaneous
        )

    def forward(self, x, task='physics', x2=None):
        """Return the logits of the head selected by ``task``.

        Args:
            x: Situation features whose last dimension is 30.  For the
                pairwise tasks it must be 2-D ``(batch, 30)`` because the
                embeddings are concatenated along ``dim=1``.
            task: ``'physics'``, ``'causal'`` or ``'temporal'``.
            x2: Second situation, same layout as ``x``.  Required for the
                pairwise tasks and ignored for ``'physics'``.

        Returns:
            Logits with 2 columns for ``'physics'`` and ``'causal'`` and
            3 columns for ``'temporal'``.

        Raises:
            ValueError: If ``task`` is not one of the supported names, or if
                a pairwise task is requested without ``x2``.
        """
        if task == 'physics':
            return self.physics_head(self.encoder(x))

        # Fail loudly on a misspelled task instead of silently running the
        # temporal head on it.
        if task not in self._PAIR_TASKS:
            raise ValueError(
                f"Unknown task {task!r}; expected 'physics', 'causal' or 'temporal'"
            )
        if x2 is None:
            raise ValueError(f"task={task!r} requires a second input 'x2'")

        # Encode x before x2 so dropout masks are drawn in a fixed order.
        enc = self.encoder(x)
        enc2 = self.encoder(x2)
        combined = torch.cat([enc, enc2], dim=1)

        head = self.causal_head if task == 'causal' else self.temporal_head
        return head(combined)

def create_physics_task(batch_size=128, target_device=None):
    """Generate a random batch of physical-plausibility examples.

    Each example is a 30-feature vector describing one scenario, plus
    Gaussian noise (std 0.1) on every feature.  Feature layout:

        0 object            5 liquid           10 sticky surface
        1 gravity           6 through          11 round
        2 direction         7 solid barrier    12 slope
          (-1 down, +1 up)  8 elastic          13-29 unused (noise only)
        3 weight            9 collision
          (+1 heavy, -1 light)
        4 in water

    Plausible (label 1): objects fall, heavy objects sink, light objects
    float, water flows down, bouncing, sticking, rolling.
    Implausible (label 0): objects or water moving up under gravity, a solid
    passing through a solid barrier.

    Args:
        batch_size: Number of examples to generate.
        target_device: Device the tensors are moved to.  Defaults to the
            module-level ``device``.

    Returns:
        ``(X, labels)``: a float32 tensor of shape ``(batch_size, 30)`` and
        an int64 tensor of shape ``(batch_size,)``.
    """
    # scenario name -> ({feature index: value}, label)
    scenarios = {
        'fall_down': ({0: 1, 1: 1, 2: -1}, 1),
        'float_up': ({0: 1, 1: 1, 2: 1}, 0),            # spontaneous rise
        'heavy_sink': ({0: 1, 3: 1, 4: 1, 2: -1}, 1),
        'light_float': ({0: 1, 3: -1, 4: 1, 2: 1}, 1),
        'water_flow_down': ({5: 1, 1: 1, 2: -1}, 1),
        'water_flow_up': ({5: 1, 1: 1, 2: 1}, 0),       # upward without pump
        'solid_through_solid': ({0: 1, 6: 1, 7: 1}, 0),
        'bounce': ({0: 1, 8: 1, 9: 1}, 1),
        'stick': ({0: 1, 10: 1}, 1),
        'roll': ({0: 1, 11: 1, 12: 1}, 1),
    }
    names = list(scenarios)

    rows = []
    labels = []
    for _ in range(batch_size):
        active, label = scenarios[np.random.choice(names)]

        x = np.zeros(30)
        for index, value in active.items():
            x[index] = value

        # Add some noise
        rows.append(x + np.random.randn(30) * 0.1)
        labels.append(label)

    # Pack into a single float32 array before handing it to torch: building a
    # tensor from a list of ndarrays is slow, and the reshape keeps the
    # (0, 30) shape for an empty batch.
    features = np.asarray(rows, dtype=np.float32).reshape(-1, 30)

    dev = device if target_device is None else target_device
    return (torch.from_numpy(features).to(dev),
            torch.tensor(labels, dtype=torch.long).to(dev))

def create_causal_task(batch_size=128, target_device=None):
    """Generate a random batch of "does A cause B?" example pairs.

    Each example is a pair of 30-feature vectors ``(x1, x2)``: ``x1`` has the
    candidate cause set to 1 and ``x2`` the candidate effect set to 1, both
    with Gaussian noise (std 0.1) added.  Positive pairs (label 1):
    push -> movement, rain -> wet, fire -> hot, eating -> full,
    running -> tired, learning -> knowledge.  Negative pairs (label 0):
    hot -> cold, and a 'random' scenario in which both vectors are pure
    standard-normal noise.

    Args:
        batch_size: Number of example pairs to generate.
        target_device: Device the tensors are moved to.  Defaults to the
            module-level ``device``.

    Returns:
        ``(X1, X2, labels)``: two float32 tensors of shape
        ``(batch_size, 30)`` and an int64 tensor of shape ``(batch_size,)``.
    """
    # scenario name -> (cause index in x1, effect index in x2, label)
    pairs = {
        'push_move': (0, 1, 1),     # push force -> movement
        'rain_wet': (2, 3, 1),      # rain -> wet
        'fire_hot': (4, 5, 1),      # fire -> hot
        'hot_cold': (5, 6, 0),      # hot -> cold: doesn't cause
        'eat_full': (7, 8, 1),      # eating -> full
        'run_tired': (9, 10, 1),    # running -> tired
        'learn_smart': (11, 12, 1), # learning -> knowledge
    }
    names = list(pairs) + ['random']

    rows1, rows2 = [], []
    labels = []
    for _ in range(batch_size):
        name = np.random.choice(names)

        if name == 'random':
            # Unstructured pair: no causal relation.
            x1 = np.random.randn(30)
            x2 = np.random.randn(30)
            label = 0
        else:
            cause, effect, label = pairs[name]
            x1 = np.zeros(30)
            x2 = np.zeros(30)
            x1[cause] = 1
            x2[effect] = 1

        # Noise is drawn for x1 first, then x2.
        rows1.append(x1 + np.random.randn(30) * 0.1)
        rows2.append(x2 + np.random.randn(30) * 0.1)
        labels.append(label)

    # Pack each side into a single float32 array before handing it to torch:
    # building a tensor from a list of ndarrays is slow, and the reshape keeps
    # the (0, 30) shape for an empty batch.
    first = np.asarray(rows1, dtype=np.float32).reshape(-1, 30)
    second = np.asarray(rows2, dtype=np.float32).reshape(-1, 30)

    dev = device if target_device is None else target_device
    return (torch.from_numpy(first).to(dev),
            torch.from_numpy(second).to(dev),
            torch.tensor(labels, dtype=torch.long).to(dev))

def create_temporal_task(batch_size=128, target_device=None):
    """Generate a random batch of temporal-ordering example pairs.

    Each example is a pair of 30-feature vectors ``(x1, x2)`` encoding two
    events of one scenario (planting/growth, cooking/eating,
    waking/sleeping, buying/using), with Gaussian noise (std 0.1) added.
    The pair is emitted in its natural order with label 0 ("x1 before x2");
    with probability 0.3 the two events are swapped and the label is 1
    ("x1 after x2").  Label 2 is never produced by this generator.

    Args:
        batch_size: Number of example pairs to generate.
        target_device: Device the tensors are moved to.  Defaults to the
            module-level ``device``.

    Returns:
        ``(X1, X2, labels)``: two float32 tensors of shape
        ``(batch_size, 30)`` and an int64 tensor of shape ``(batch_size,)``.
    """
    # scenario name -> (index of the earlier event, index of the later event)
    ordered = {
        'plant_grow': (0, 1),  # planting, growth
        'cook_eat': (2, 3),    # cooking, eating
        'wake_sleep': (4, 5),  # waking, sleeping
        'buy_use': (6, 7),     # buying, using
    }
    names = list(ordered)

    rows1, rows2 = [], []
    labels = []
    for _ in range(batch_size):
        earlier, later = ordered[np.random.choice(names)]

        x1 = np.zeros(30)
        x2 = np.zeros(30)
        x1[earlier] = 1
        x2[later] = 1
        label = 0  # x1 before x2

        # Randomly flip 30% to test "after"
        if np.random.rand() < 0.3:
            x1, x2 = x2, x1
            label = 1  # x1 after x2

        # Noise is drawn for x1 first, then x2.
        rows1.append(x1 + np.random.randn(30) * 0.1)
        rows2.append(x2 + np.random.randn(30) * 0.1)
        labels.append(label)

    # Pack each side into a single float32 array before handing it to torch:
    # building a tensor from a list of ndarrays is slow, and the reshape keeps
    # the (0, 30) shape for an empty batch.
    first = np.asarray(rows1, dtype=np.float32).reshape(-1, 30)
    second = np.asarray(rows2, dtype=np.float32).reshape(-1, 30)

    dev = device if target_device is None else target_device
    return (torch.from_numpy(first).to(dev),
            torch.from_numpy(second).to(dev),
            torch.tensor(labels, dtype=torch.long).to(dev))

# ---- Banner -----------------------------------------------------------
print("=" * 70)
print("COMMON SENSE REASONING")
print("=" * 70)

# One network holding all three heads, optimised jointly with Adam.
model = CommonSenseNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001)

print("\nTraining (400 epochs)...")

for epoch in range(400):
    # Every step draws a fresh synthetic batch per task and sums the three
    # cross-entropy losses into a single objective.
    opt.zero_grad()

    # Physics: single-situation plausibility
    phys_x, phys_y = create_physics_task(batch_size=128)
    loss_phys = F.cross_entropy(model(phys_x, task='physics'), phys_y)

    # Causal: (cause, effect) pairs
    cause_x, effect_x, causal_y = create_causal_task(batch_size=128)
    loss_causal = F.cross_entropy(
        model(cause_x, task='causal', x2=effect_x), causal_y
    )

    # Temporal: (event, event) ordering pairs
    first_x, second_x, temporal_y = create_temporal_task(batch_size=128)
    loss_temp = F.cross_entropy(
        model(first_x, task='temporal', x2=second_x), temporal_y
    )

    total_loss = loss_phys + loss_causal + loss_temp
    total_loss.backward()
    opt.step()

    # Progress line every 50 epochs.
    if not epoch % 50:
        print(f"  Epoch {epoch}: Loss={total_loss.item():.3f}")

print("\n✅ Training complete")

# Test each type
print("\n" + "="*70)
print("TESTING")
print("="*70)

# Switch to evaluation mode: the encoder contains Dropout layers, which would
# otherwise keep randomly zeroing activations and distort the accuracies
# measured below.
model.eval()


def _mean_accuracy(make_batch, task, rounds=20, batch_size=200):
    """Return the mean accuracy of ``model`` over ``rounds`` fresh batches.

    ``make_batch`` is one of the task generators; it returns either
    ``(X, Y)`` or ``(X1, X2, Y)``.  The label tensor is always last.
    """
    scores = []
    for _ in range(rounds):
        *inputs, targets = make_batch(batch_size=batch_size)
        second = inputs[1] if len(inputs) > 1 else None
        with torch.no_grad():
            logits = model(inputs[0], task=task, x2=second)
        scores.append((logits.argmax(1) == targets).float().mean().item())
    return np.mean(scores)


def _status_icon(acc):
    """Map an accuracy in [0, 1] to the emoji printed next to it."""
    if acc >= 0.95:
        return "🎉"
    if acc >= 0.90:
        return "✅"
    return "⚠️"


# Physics
phys_acc = _mean_accuracy(create_physics_task, 'physics')
print(f"  {_status_icon(phys_acc)} Physical Intuition: {phys_acc*100:.1f}%")

# Causal
causal_acc = _mean_accuracy(create_causal_task, 'causal')
print(f"  {_status_icon(causal_acc)} Causal Reasoning: {causal_acc*100:.1f}%")

# Temporal
temp_acc = _mean_accuracy(create_temporal_task, 'temporal')
print(f"  {_status_icon(temp_acc)} Temporal Reasoning: {temp_acc*100:.1f}%")

overall = np.mean([phys_acc, causal_acc, temp_acc])
print(f"\n{'='*70}")
print(f"Overall: {overall*100:.1f}%")

# NOTE: the final branch is reached for every overall accuracy below 0.90.
if overall >= 0.95:
    print("🎉 EXCEPTIONAL COMMON SENSE!")
elif overall >= 0.90:
    print("✅ EXCELLENT!")
else:
    print("✅ Strong!")

# Persist only the learned parameters (state_dict), not the module object.
torch.save(model.state_dict(), 'common_sense.pth')
print("💾 Saved!")
