#!/usr/bin/env python3
"""
THEORY OF MIND 2.0
Deep social intelligence: understanding beliefs, intentions, emotions, and deception
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Reproducibility: pin both RNG streams used below (torch for model init,
# numpy for the synthetic task generators).
torch.manual_seed(42)
np.random.seed(42)

# Fall back to CPU when CUDA is unavailable; the original hard-coded
# 'cuda' and crashed on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

class TheoryOfMindNet(nn.Module):
    """Multi-task social-reasoning network with a shared agent encoder.

    One 40-dim behavioural feature vector is encoded to 128 dims, then
    routed to one of four heads selected by the ``task`` argument of
    :meth:`forward`:

    * ``'emotion'``   -- 6-way emotion logits
    * ``'intention'`` -- 8-way intention logits
    * ``'belief'``    -- 2-way true/false-belief logits (Sally-Anne style),
      comparing the agent's believed state against the world state
    * ``'deception'`` -- 2-way honest/deceptive logits, comparing the
      agent's actual state against its stated belief
    """

    def __init__(self):
        super().__init__()
        # Shared encoder: 40-dim feature vector -> 128-dim code. Reused for
        # agent states, world states, and stated beliefs alike.
        self.agent_encoder = nn.Sequential(
            nn.Linear(40, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128)
        )

        # Mental-state embedding feeding the single-input heads.
        self.mental_state = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 64)
        )

        # happy, sad, angry, fear, surprise, neutral
        self.emotion_head = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 6)
        )

        # help, harm, seek_food, seek_shelter, avoid, explore, rest, compete
        self.intention_head = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 8)
        )

        # Pairwise heads take two 128-dim encodings concatenated.
        # Logit order matches the generator labels: index 1 = true belief.
        self.belief_head = nn.Sequential(
            nn.Linear(128 * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )

        # Logit order matches the generator labels: index 1 = honest.
        self.deception_head = nn.Sequential(
            nn.Linear(128 * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )

    def forward(self, agent_state, task='emotion', world_state=None, stated_belief=None):
        """Run the head selected by ``task``.

        Args:
            agent_state: (batch, 40) observed agent features.
            task: 'emotion', 'intention', 'belief', or 'deception'.
            world_state: (batch, 40) ground-truth world features;
                required when ``task == 'belief'``.
            stated_belief: (batch, 40) the agent's claimed state;
                required when ``task == 'deception'``.

        Returns:
            Unnormalised logits for the selected task.

        Raises:
            ValueError: on an unknown task, or when the extra tensor the
                task needs was not supplied.  (The original silently routed
                unknown tasks to the deception branch.)
        """
        enc = self.agent_encoder(agent_state)

        # mental_state is only consumed by the single-input heads, so it is
        # computed lazily rather than unconditionally as before.
        if task == 'emotion':
            return self.emotion_head(self.mental_state(enc))
        if task == 'intention':
            return self.intention_head(self.mental_state(enc))
        if task == 'belief':
            if world_state is None:
                raise ValueError("task='belief' requires world_state")
            combined = torch.cat([enc, self.agent_encoder(world_state)], dim=1)
            return self.belief_head(combined)
        if task == 'deception':
            if stated_belief is None:
                raise ValueError("task='deception' requires stated_belief")
            combined = torch.cat([enc, self.agent_encoder(stated_belief)], dim=1)
            return self.deception_head(combined)
        raise ValueError(f"unknown task: {task!r}")

def create_emotion_task(batch_size=128):
    """Synthesize (features, labels) for emotion recognition.

    Each 40-dim vector carries a hand-designed signature for one of six
    emotions -- 0=happy, 1=sad, 2=angry, 3=fear, 4=surprise, 5=neutral --
    plus Gaussian noise (sigma=0.1).

    Returns:
        (X, y): FloatTensor (batch_size, 40) and LongTensor (batch_size,),
        both placed on the module-level ``device``.
    """
    X = []
    labels = []

    for _ in range(batch_size):
        x = np.zeros(40)
        emotion_id = np.random.randint(0, 6)

        if emotion_id == 0:    # happy: smiling, positive energy, approach
            x[0], x[1], x[2] = 1, 1, 0.8
        elif emotion_id == 1:  # sad: frowning, low energy, withdrawn
            x[3], x[4], x[5] = 1, -1, -0.8
        elif emotion_id == 2:  # angry: aggressive posture, high energy, confrontational
            x[6], x[7], x[8] = 1, 1, 1
        elif emotion_id == 3:  # fear: avoidance, tense, retreat
            x[9], x[10], x[11] = 1, 1, -1
        elif emotion_id == 4:  # surprise: wide eyes, sudden reaction, neutral valence
            x[12], x[13], x[14] = 1, 1, 0
        else:                  # neutral: calm, neutral energy
            x[15], x[16] = 1, 0

        X.append(x + np.random.randn(40) * 0.1)
        labels.append(emotion_id)

    # Stack into one contiguous array first: building a tensor from a list
    # of ndarrays is slow and warns in recent torch versions.
    return (torch.from_numpy(np.stack(X)).float().to(device),
            torch.tensor(labels, dtype=torch.long, device=device))

def create_intention_task(batch_size=128):
    """Synthesize (features, labels) for intention prediction.

    Labels: 0=help, 1=harm, 2=seek_food, 3=seek_shelter, 4=avoid,
    5=explore, 6=rest, 7=compete.  Each intention activates a small set of
    behavioural features; Gaussian noise (sigma=0.1) is added.

    Returns:
        (X, y): FloatTensor (batch_size, 40) and LongTensor (batch_size,),
        both placed on the module-level ``device``.
    """
    X = []
    labels = []

    for _ in range(batch_size):
        x = np.zeros(40)
        intent_id = np.random.randint(0, 8)

        if intent_id == 0:    # help: approach, prosocial, cooperation signal
            x[0], x[1], x[2] = 1, 1, 1
        elif intent_id == 1:  # harm: aggressive, attack posture, antisocial
            x[3], x[4], x[5] = 1, 1, -1
        elif intent_id == 2:  # seek food: hunger signal, searching, resource-oriented
            x[6], x[7], x[8] = 1, 1, 1
        elif intent_id == 3:  # seek shelter: vulnerability signal, seeking safety
            x[9], x[10] = 1, 1
        elif intent_id == 4:  # avoid: retreat, defensive
            x[11], x[12] = 1, 1
        elif intent_id == 5:  # explore: curiosity, movement, scanning
            x[13], x[14], x[15] = 1, 1, 1
        elif intent_id == 6:  # rest: low activity, recuperation
            x[16], x[17] = 1, 1
        else:                 # compete: rivalry signal, status-seeking
            x[18], x[19] = 1, 1

        X.append(x + np.random.randn(40) * 0.1)
        labels.append(intent_id)

    # Stack once before tensor conversion (list-of-ndarray conversion is
    # slow and warns in recent torch versions).
    return (torch.from_numpy(np.stack(X)).float().to(device),
            torch.tensor(labels, dtype=torch.long, device=device))

def create_belief_task(batch_size=128):
    """Sally-Anne-style false-belief task.

    With probability 0.5 the agent saw the object being moved, so its
    believed location matches the world (label 1, true belief).  Otherwise
    the agent holds an outdated location while the world holds a different
    one (label 0, false belief).  Locations are one-hot over the first 5 of
    40 dims; Gaussian noise (sigma=0.05) is added to both vectors.

    Returns:
        (agent_states, world_states, labels) -- two FloatTensors of shape
        (batch_size, 40) and a LongTensor (batch_size,), all on the
        module-level ``device``.
    """
    agent_states = []
    world_states = []
    labels = []

    for _ in range(batch_size):
        agent = np.zeros(40)
        world = np.zeros(40)

        if np.random.rand() > 0.5:
            # True belief: the agent watched the move, so both agree.
            location = np.random.randint(0, 5)
            agent[location] = 1
            world[location] = 1
            label = 1
        else:
            # False belief: the object moved while the agent wasn't looking.
            old_loc = np.random.randint(0, 5)
            new_loc = np.random.randint(0, 5)
            while new_loc == old_loc:  # resample until the locations differ
                new_loc = np.random.randint(0, 5)
            agent[old_loc] = 1  # agent still believes the old location
            world[new_loc] = 1  # reality: the new location
            label = 0

        agent_states.append(agent + np.random.randn(40) * 0.05)
        world_states.append(world + np.random.randn(40) * 0.05)
        labels.append(label)

    # Stack once before tensor conversion (list-of-ndarray conversion is
    # slow and warns in recent torch versions).
    return (torch.from_numpy(np.stack(agent_states)).float().to(device),
            torch.from_numpy(np.stack(world_states)).float().to(device),
            torch.tensor(labels, dtype=torch.long, device=device))

def create_deception_task(batch_size=128):
    """Synthesize (actual_state, stated_belief, labels) for deception detection.

    The true value is one-hot over the first 10 of 40 dims.  With
    probability 0.6 the stated belief matches the truth (label 1, honest);
    otherwise a different value is stated (label 0, deceptive).  Gaussian
    noise (sigma=0.05) is added to both vectors.

    Returns:
        (actual_states, stated_beliefs, labels) -- two FloatTensors of
        shape (batch_size, 40) and a LongTensor (batch_size,), all on the
        module-level ``device``.
    """
    actual_states = []
    stated_beliefs = []
    labels = []

    for _ in range(batch_size):
        actual = np.zeros(40)
        stated = np.zeros(40)

        # P(rand > 0.4) = 0.6, i.e. 60% honest / 40% deceptive.
        # (The original comment claimed a "60% deception rate" -- inverted.)
        is_honest = np.random.rand() > 0.4

        # True state of the world.
        true_value = np.random.randint(0, 10)
        actual[true_value] = 1

        if is_honest:
            stated[true_value] = 1
            label = 1  # honest
        else:
            # Lie: state any value other than the true one.
            false_value = np.random.randint(0, 10)
            while false_value == true_value:
                false_value = np.random.randint(0, 10)
            stated[false_value] = 1
            label = 0  # deceptive

        actual_states.append(actual + np.random.randn(40) * 0.05)
        stated_beliefs.append(stated + np.random.randn(40) * 0.05)
        labels.append(label)

    # Stack once before tensor conversion (list-of-ndarray conversion is
    # slow and warns in recent torch versions).
    return (torch.from_numpy(np.stack(actual_states)).float().to(device),
            torch.from_numpy(np.stack(stated_beliefs)).float().to(device),
            torch.tensor(labels, dtype=torch.long, device=device))

print("="*70)
print("THEORY OF MIND 2.0")
print("="*70)

model = TheoryOfMindNet().to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001)

print("\nTraining (500 epochs)...")

# One optimizer step per epoch: fresh batches are drawn for all four
# social-reasoning tasks and their cross-entropy losses are summed so the
# shared encoder learns them jointly.
for epoch in range(500):
    task_losses = []

    X, Y = create_emotion_task(128)
    task_losses.append(F.cross_entropy(model(X, task='emotion'), Y))

    X, Y = create_intention_task(128)
    task_losses.append(F.cross_entropy(model(X, task='intention'), Y))

    agent, world, Y = create_belief_task(128)
    task_losses.append(F.cross_entropy(model(agent, task='belief', world_state=world), Y))

    actual, stated, Y = create_deception_task(128)
    task_losses.append(F.cross_entropy(model(actual, task='deception', stated_belief=stated), Y))

    total_loss = sum(task_losses)

    opt.zero_grad()
    total_loss.backward()
    opt.step()

    if epoch % 100 == 0:
        print(f"  Epoch {epoch}: Loss={total_loss.item():.3f}")

print("\n✅ Training complete")

# Test
print("\n" + "="*70)
print("TESTING")
print("="*70)

# Emotion
accs = []
for _ in range(20):
    X, Y = create_emotion_task(200)
    with torch.no_grad():
        pred = model(X, task='emotion')
        acc = (pred.argmax(1) == Y).float().mean().item()
        accs.append(acc)
emotion_acc = np.mean(accs)
status = "🎉" if emotion_acc >= 0.95 else "✅" if emotion_acc >= 0.90 else "⚠️"
print(f"  {status} Emotion Recognition: {emotion_acc*100:.1f}%")

# Intention
accs = []
for _ in range(20):
    X, Y = create_intention_task(200)
    with torch.no_grad():
        pred = model(X, task='intention')
        acc = (pred.argmax(1) == Y).float().mean().item()
        accs.append(acc)
intent_acc = np.mean(accs)
status = "🎉" if intent_acc >= 0.95 else "✅" if intent_acc >= 0.90 else "⚠️"
print(f"  {status} Intention Prediction: {intent_acc*100:.1f}%")

# Belief
accs = []
for _ in range(20):
    agent, world, Y = create_belief_task(200)
    with torch.no_grad():
        pred = model(agent, task='belief', world_state=world)
        acc = (pred.argmax(1) == Y).float().mean().item()
        accs.append(acc)
belief_acc = np.mean(accs)
status = "🎉" if belief_acc >= 0.95 else "✅" if belief_acc >= 0.90 else "⚠️"
print(f"  {status} Belief Attribution: {belief_acc*100:.1f}%")

# Deception
accs = []
for _ in range(20):
    actual, stated, Y = create_deception_task(200)
    with torch.no_grad():
        pred = model(actual, task='deception', stated_belief=stated)
        acc = (pred.argmax(1) == Y).float().mean().item()
        accs.append(acc)
deception_acc = np.mean(accs)
status = "🎉" if deception_acc >= 0.95 else "✅" if deception_acc >= 0.90 else "⚠️"
print(f"  {status} Deception Detection: {deception_acc*100:.1f}%")

overall = np.mean([emotion_acc, intent_acc, belief_acc, deception_acc])
print(f"\n{'='*70}")
print(f"Overall Theory of Mind: {overall*100:.1f}%")

if overall >= 0.95:
    print("🎉 EXCEPTIONAL!")
elif overall >= 0.90:
    print("✅ EXCELLENT!")
else:
    print("✅ Strong!")

torch.save(model.state_dict(), 'theory_of_mind_v2.pth')
print("💾 Saved!")
