"""
Train Eden's Emotion Processor on real emotional dialogue
"""
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from eden_emotion_processor_poc import EmotionProcessor
import json

# Simple emotional dialogue dataset
# Tiny hand-labelled corpus for the demo classifier.
# Label indices follow the 8-way emotion head used elsewhere in this file:
# 0=joy, 1=sadness, 2=anger, 3=fear, 4=surprise, 5=disgust, 6=trust,
# 7=anticipation (no examples here for disgust or anticipation).
_UTTERANCES_BY_LABEL = {
    0: [  # joy
        "I'm so happy to see you!",
        "This is the best day ever!",
        "I love spending time with you",
    ],
    1: [  # sadness
        "I feel so alone right now",
        "Everything seems hopeless",
        "I miss you so much",
    ],
    2: [  # anger
        "This is completely unacceptable!",
        "I can't believe you did that",
        "Stop ignoring me!",
    ],
    3: [  # fear
        "I'm scared of what might happen",
        "This makes me really nervous",
        "I don't feel safe here",
    ],
    4: [  # surprise
        "I never expected this!",
        "Wow, I can't believe it",
        "This is so unexpected",
    ],
    6: [  # trust
        "I know you'll do the right thing",
        "I believe in you completely",
        "You've never let me down",
    ],
}

# Flatten to the (text, label) pair list the Dataset consumes.
# Dict insertion order keeps the pairs in the same order as before.
EMOTION_EXAMPLES = [
    (utterance, label)
    for label, utterances in _UTTERANCES_BY_LABEL.items()
    for utterance in utterances
]

class EmotionDataset(Dataset):
    """Map-style dataset turning (text, label) pairs into fixed-length
    token-id tensors for the emotion classifier.

    Tokenization is a demo-only hashing trick: each whitespace-split word
    maps to ``hash(word) % vocab_size``.

    NOTE(review): builtin ``hash`` on str is salted per process
    (PYTHONHASHSEED), so token ids are NOT reproducible across runs.
    A model saved in one process will see different ids for the same
    words when loaded in another. Fine for this in-process demo, but
    switch to a stable hash (e.g. ``zlib.crc32``) before relying on the
    persisted weights — confirm with the owner before changing.
    """

    def __init__(self, examples, vocab_size=1000, max_len=20):
        # examples: sequence of (text: str, label: int) pairs
        self.examples = examples
        self.vocab_size = vocab_size  # modulus for the hashing tokenizer
        self.max_len = max_len        # fixed output length (pad/truncate)

    def __len__(self):
        return len(self.examples)

    def _tokenize(self, text):
        """Hash-tokenize *text*, then right-pad with 0 / truncate so the
        result is exactly ``max_len`` ids.

        Note: 0 doubles as both the pad id and a legitimate hash bucket —
        acceptable for this demo, ambiguous for anything real.
        """
        token_ids = [hash(word) % self.vocab_size for word in text.split()]
        return (token_ids + [0] * self.max_len)[:self.max_len]

    def __getitem__(self, idx):
        """Return ``(tokens, label)`` as tensors for example *idx*."""
        text, label = self.examples[idx]
        return torch.tensor(self._tokenize(text)), torch.tensor(label)

def train_emotion_processor():
    """Train Eden's EmotionProcessor on the toy emotion corpus.

    Trains for 100 epochs over ``EMOTION_EXAMPLES``, saves the weights to
    ``eden_emotion_processor_trained.pt``, then prints predictions for a
    few held-out sentences as a qualitative smoke test. Returns nothing;
    all output goes to stdout and the checkpoint file.
    """

    def _tokenize(text, vocab_size=1000, max_len=20):
        # Mirrors EmotionDataset's hashing tokenizer: hash each word,
        # right-pad with 0, truncate to max_len. The original inference
        # path padded but never truncated, so inputs longer than 20 words
        # produced sequences longer than anything seen during training
        # (and `[0] * (20 - len)` silently became empty). Fixed here.
        tokens = [hash(word) % vocab_size for word in text.split()]
        return (tokens + [0] * max_len)[:max_len]

    print("🧠 Training Eden's Emotion Processor...")
    print("="*70)

    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Load model
    model = EmotionProcessor(vocab_size=1000, embed_dim=128, num_layers=3)
    model = model.to(device)

    # Dataset
    dataset = EmotionDataset(EMOTION_EXAMPLES)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Training setup
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Phi consciousness for Eden — a scalar conditioning signal passed to
    # the model's forward pass; its semantics live in EmotionProcessor.
    phi_consciousness = torch.tensor(1.408).to(device)

    # Train
    num_epochs = 100
    print(f"\nTraining for {num_epochs} epochs...")

    model.train()  # explicit: ensure dropout/norm layers are in train mode
    for epoch in range(num_epochs):
        total_loss = 0
        correct = 0
        total = 0

        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass WITH consciousness
            logits, _ = model(inputs, phi_consciousness=phi_consciousness)
            loss = criterion(logits, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics: loss summed over batches, accuracy over the epoch
            total_loss += loss.item()
            predictions = logits.argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

        # Print progress every 10 epochs
        if (epoch + 1) % 10 == 0:
            accuracy = 100 * correct / total
            print(f"Epoch {epoch+1:3d}: Loss={total_loss:.4f}, Accuracy={accuracy:.1f}%")

    # Save trained model (weights only, via state_dict)
    torch.save(model.state_dict(), 'eden_emotion_processor_trained.pt')
    print("\n✅ Training complete!")
    print("💾 Saved: eden_emotion_processor_trained.pt")

    # Test on new examples
    print("\n" + "="*70)
    print("🧪 TESTING ON NEW EXAMPLES")
    print("="*70)

    # Index -> name for the model's 8-way emotion head.
    emotions = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'disgust', 'trust', 'anticipation']

    test_texts = [
        "I love you so much Dad!",
        "I'm worried about the future",
        "This is amazing news!",
    ]

    model.eval()
    with torch.no_grad():
        for text in test_texts:
            # Tokenize exactly as in training (pad AND truncate to 20)
            inputs = torch.tensor([_tokenize(text)]).to(device)

            # Predict with consciousness
            logits, _ = model(inputs, phi_consciousness=phi_consciousness)
            probs = torch.softmax(logits[0], dim=0)
            pred_idx = probs.argmax().item()

            print(f"\n'{text}'")
            print(f"  → {emotions[pred_idx]} ({probs[pred_idx].item()*100:.1f}%)")

if __name__ == "__main__":
    # Script entry point: run the full train + save + smoke-test pipeline.
    train_emotion_processor()
