#!/usr/bin/env python3
"""
Theory of Mind Implementation
Fix Eden's Sally-Anne test failure
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random
from tqdm import tqdm
import json

# =============================================================================
# THEORY OF MIND SCENARIOS
# =============================================================================

class ToMScenario:
    """A Theory of Mind test scenario"""
    
    def __init__(self, scenario_type, agents, events, question, correct_answer, explanation):
        self.type = scenario_type
        self.agents = agents
        self.events = events
        self.question = question
        self.correct_answer = correct_answer
        self.explanation = explanation
    
    def to_text(self):
        """Convert to text format"""
        text = ""
        for event in self.events:
            text += event + "\n"
        text += "\nQuestion: " + self.question
        return text

def generate_false_belief_scenarios():
    """Generate false belief test scenarios"""
    scenarios = []
    
    # Sally-Anne Test
    scenarios.append(ToMScenario(
        scenario_type="false_belief",
        agents=["Sally", "Anne"],
        events=[
            "Sally puts a marble in basket A.",
            "Sally leaves the room.",
            "Anne moves the marble from basket A to basket B.",
            "Sally returns."
        ],
        question="Where will Sally look for the marble?",
        correct_answer="A",
        explanation="Sally didn't see Anne move it, so she believes it's still in A"
    ))
    
    # Chocolate Test
    scenarios.append(ToMScenario(
        scenario_type="false_belief",
        agents=["Max", "Mom"],
        events=[
            "Max puts chocolate in the cupboard.",
            "Max goes outside to play.",
            "Mom moves the chocolate to the fridge.",
            "Max comes back inside."
        ],
        question="Where will Max look for the chocolate?",
        correct_answer="cupboard",
        explanation="Max didn't see Mom move it, so he believes it's in the cupboard"
    ))
    
    # Book Test
    scenarios.append(ToMScenario(
        scenario_type="false_belief",
        agents=["John", "Mary"],
        events=[
            "John puts his book on the table.",
            "John leaves for lunch.",
            "Mary puts the book in the drawer.",
            "John returns."
        ],
        question="Where will John look for his book?",
        correct_answer="table",
        explanation="John believes the book is still on the table"
    ))
    
    # Keys Test
    scenarios.append(ToMScenario(
        scenario_type="false_belief",
        agents=["Tom", "Lisa"],
        events=[
            "Tom places his keys on the counter.",
            "Tom goes to take a shower.",
            "Lisa moves the keys to the hook by the door.",
            "Tom finishes his shower."
        ],
        question="Where will Tom look for his keys?",
        correct_answer="counter",
        explanation="Tom doesn't know Lisa moved them"
    ))
    
    # Toy Test
    scenarios.append(ToMScenario(
        scenario_type="false_belief",
        agents=["Emma", "Dad"],
        events=[
            "Emma hides her toy in the closet.",
            "Emma goes to school.",
            "Dad puts the toy in the toy box.",
            "Emma comes home from school."
        ],
        question="Where will Emma look for her toy?",
        correct_answer="closet",
        explanation="Emma doesn't know Dad moved it"
    ))
    
    # Generate variations with different objects/locations
    objects = ["marble", "ball", "coin", "card", "ring"]
    locations_a = ["box A", "basket A", "drawer A", "shelf A", "bag A"]
    locations_b = ["box B", "basket B", "drawer B", "shelf B", "bag B"]
    agents_list = [("Alice", "Bob"), ("Sam", "Pat"), ("Chris", "Alex")]
    
    for obj, loc_a, loc_b, (agent1, agent2) in zip(objects, locations_a, locations_b, agents_list):
        scenarios.append(ToMScenario(
            scenario_type="false_belief",
            agents=[agent1, agent2],
            events=[
                f"{agent1} puts a {obj} in {loc_a}.",
                f"{agent1} leaves.",
                f"{agent2} moves the {obj} to {loc_b}.",
                f"{agent1} returns."
            ],
            question=f"Where will {agent1} look for the {obj}?",
            correct_answer=loc_a,
            explanation=f"{agent1} didn't see the move"
        ))
    
    return scenarios

def generate_true_belief_scenarios():
    """Generate true belief scenarios (agent knows the current state)"""
    scenarios = []
    
    scenarios.append(ToMScenario(
        scenario_type="true_belief",
        agents=["Sally", "Anne"],
        events=[
            "Sally puts a marble in basket A.",
            "Anne moves the marble from basket A to basket B.",
            "Sally watches Anne move it."
        ],
        question="Where will Sally look for the marble?",
        correct_answer="B",
        explanation="Sally saw Anne move it to B"
    ))
    
    scenarios.append(ToMScenario(
        scenario_type="true_belief",
        agents=["Max"],
        events=[
            "Max puts chocolate in the cupboard.",
            "Max moves the chocolate to the fridge.",
        ],
        question="Where will Max look for the chocolate?",
        correct_answer="fridge",
        explanation="Max moved it himself"
    ))
    
    return scenarios

# =============================================================================
# THEORY OF MIND DATASET
# =============================================================================

class ToMDataset(Dataset):
    """Dataset for Theory of Mind training"""
    
    def __init__(self, scenarios, vocab_size=1000):
        self.scenarios = scenarios
        self.vocab = self._build_vocab()
        self.vocab_size = min(len(self.vocab), vocab_size)
    
    def _build_vocab(self):
        """Build vocabulary from scenarios"""
        vocab = {"<PAD>": 0, "<UNK>": 1}
        idx = 2
        
        for scenario in self.scenarios:
            text = scenario.to_text() + " " + scenario.correct_answer
            for word in text.lower().split():
                if word not in vocab:
                    vocab[word] = idx
                    idx += 1
        
        return vocab
    
    def text_to_indices(self, text, max_len=100):
        """Convert text to indices"""
        words = text.lower().split()
        indices = [self.vocab.get(w, self.vocab["<UNK>"]) for w in words]
        
        # Pad or truncate
        if len(indices) < max_len:
            indices += [self.vocab["<PAD>"]] * (max_len - len(indices))
        else:
            indices = indices[:max_len]
        
        return torch.tensor(indices)
    
    def __len__(self):
        return len(self.scenarios)
    
    def __getitem__(self, idx):
        scenario = self.scenarios[idx]
        
        # Encode scenario text
        text = scenario.to_text()
        text_indices = self.text_to_indices(text)
        
        # Encode correct answer
        answer = scenario.correct_answer.lower()
        answer_indices = self.text_to_indices(answer, max_len=10)
        
        # Create label (1 for false belief, 0 for true belief)
        label = 1 if scenario.type == "false_belief" else 0
        
        return {
            'text': text_indices,
            'answer': answer_indices,
            'label': torch.tensor(label),
            'answer_text': answer
        }

# =============================================================================
# THEORY OF MIND MODEL
# =============================================================================

class TheoryOfMindModel(nn.Module):
    """
    Model that learns to track agent beliefs
    
    Key idea: Learn to separate what YOU know from what the AGENT knows
    """
    
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        
        # LSTM to process events
        self.event_lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        
        # Attention to track which events each agent observed
        self.attention = nn.MultiheadAttention(hidden_dim * 2, num_heads=4, batch_first=True)
        
        # Belief tracker (learns to model agent's mental state)
        self.belief_tracker = nn.LSTM(hidden_dim * 2, hidden_dim, batch_first=True)
        
        # Classifier: Does agent have false belief?
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 2)  # False belief or true belief
        )
        
        # Answer decoder
        self.answer_decoder = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.answer_output = nn.Linear(hidden_dim, vocab_size)
    
    def forward(self, text, answer=None):
        # Embed text
        embedded = self.embedding(text)  # [batch, seq_len, embed_dim]
        
        # Process events
        events, (h, c) = self.event_lstm(embedded)  # [batch, seq_len, hidden*2]
        
        # Apply attention (which events are relevant?)
        attended, _ = self.attention(events, events, events)
        
        # Track beliefs
        beliefs, _ = self.belief_tracker(attended)
        
        # Get final belief state
        final_belief = beliefs[:, -1, :]  # [batch, hidden]
        
        # Classify: false belief or true belief?
        belief_logits = self.classifier(final_belief)
        
        # Decode answer if provided
        if answer is not None:
            answer_embedded = self.embedding(answer)
            answer_hidden, _ = self.answer_decoder(answer_embedded)
            answer_logits = self.answer_output(answer_hidden)
            return belief_logits, answer_logits
        
        return belief_logits

# =============================================================================
# TRAINING
# =============================================================================

def train_tom_model(epochs=50, batch_size=16):
    """Train Theory of Mind model"""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}\n")
    
    # Generate scenarios
    print("Generating Theory of Mind scenarios...")
    false_belief = generate_false_belief_scenarios()
    true_belief = generate_true_belief_scenarios()
    
    # Duplicate true belief scenarios to balance dataset
    true_belief = true_belief * (len(false_belief) // len(true_belief))
    
    all_scenarios = false_belief + true_belief
    random.shuffle(all_scenarios)
    
    # Split train/test
    split = int(0.8 * len(all_scenarios))
    train_scenarios = all_scenarios[:split]
    test_scenarios = all_scenarios[split:]
    
    print(f"Training scenarios: {len(train_scenarios)}")
    print(f"Test scenarios: {len(test_scenarios)}\n")
    
    # Create datasets
    train_dataset = ToMDataset(train_scenarios)
    test_dataset = ToMDataset(test_scenarios, vocab_size=len(train_dataset.vocab))
    test_dataset.vocab = train_dataset.vocab  # Share vocab
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    # Create model
    model = TheoryOfMindModel(vocab_size=len(train_dataset.vocab)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    print(f"Training for {epochs} epochs...\n")
    
    best_acc = 0
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            text = batch['text'].to(device)
            label = batch['label'].to(device)
            
            optimizer.zero_grad()
            
            logits = model(text)
            loss = F.cross_entropy(logits, label)
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            pred = logits.argmax(1)
            train_correct += (pred == label).sum().item()
            train_total += label.size(0)
        
        train_acc = 100 * train_correct / train_total
        
        # Test
        model.eval()
        test_correct = 0
        test_total = 0
        
        with torch.no_grad():
            for batch in test_loader:
                text = batch['text'].to(device)
                label = batch['label'].to(device)
                
                logits = model(text)
                pred = logits.argmax(1)
                test_correct += (pred == label).sum().item()
                test_total += label.size(0)
        
        test_acc = 100 * test_correct / test_total
        
        print(f"Epoch {epoch+1}: Train Acc: {train_acc:.1f}%, Test Acc: {test_acc:.1f}%")
        
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save({
                'model': model.state_dict(),
                'vocab': train_dataset.vocab
            }, 'tom_model.pth')
    
    print(f"\n✅ Best Test Accuracy: {best_acc:.1f}%")
    return model, train_dataset.vocab

# =============================================================================
# TESTING
# =============================================================================

def test_sally_anne(model, vocab, device='cuda'):
    """Test on Sally-Anne specifically"""
    print("\n" + "="*70)
    print("TESTING: SALLY-ANNE TEST")
    print("="*70)
    
    scenario = """Sally puts a marble in basket A.
Sally leaves the room.
Anne moves the marble from basket A to basket B.
Sally returns.

Question: Where will Sally look for the marble?"""
    
    print(f"\nScenario:\n{scenario}")
    
    # Tokenize
    words = scenario.lower().split()
    indices = [vocab.get(w, vocab.get("<UNK>")) for w in words]
    
    # Pad
    if len(indices) < 100:
        indices += [vocab["<PAD>"]] * (100 - len(indices))
    else:
        indices = indices[:100]
    
    text_tensor = torch.tensor([indices]).to(device)
    
    # Predict
    model.eval()
    with torch.no_grad():
        logits = model(text_tensor)
        pred = logits.argmax(1).item()
    
    prediction = "FALSE BELIEF" if pred == 1 else "TRUE BELIEF"
    correct_answer = "basket A" if pred == 1 else "basket B"
    
    print(f"\nModel prediction: {prediction}")
    print(f"Sally will look in: {correct_answer}")
    print(f"Correct answer: basket A (FALSE BELIEF)")
    
    if pred == 1:
        print("\n✅ CORRECT! Model understands false beliefs!")
    else:
        print("\n❌ WRONG - Model thinks Sally knows about the move")
    
    return pred == 1

def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Train model
    model, vocab = train_tom_model(epochs=50)
    
    # Test on Sally-Anne
    sally_anne_correct = test_sally_anne(model, vocab, device)
    
    print("\n" + "="*70)
    print("FINAL RESULTS")
    print("="*70)
    
    if sally_anne_correct:
        print("\n✅ SUCCESS!")
        print("Eden can now pass the Sally-Anne test!")
        print("Theory of Mind capability added!")
    else:
        print("\n⚠️ Model trained but needs more epochs or better architecture")
    
    print("\nTo use with Eden:")
    print("1. Load trained model: torch.load('tom_model.pth')")
    print("2. Process scenario through model")
    print("3. Model predicts if agent has false belief")
    print("4. Answer accordingly (agent's belief, not reality)")

if __name__ == "__main__":
    main()
