"""
FIXED: Properly tuned emotion processor that actually learns
"""
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import re

PHI = 1.618033988749895

class FibonacciEmotionNet(nn.Module):
    """Emotion classifier: embedding -> masked mean pool -> two Fibonacci blocks -> 8 logits.

    Each "Fibonacci block" projects the pooled representation through three
    parallel branches of widths 8/13/21 (consecutive Fibonacci numbers),
    concatenates them, and mixes back down to ``embed_dim``.
    """
    def __init__(self, vocab_size, embed_dim=64):
        """
        Args:
            vocab_size: number of token ids the embedding table must cover.
            embed_dim: embedding / hidden width (default 64).
        """
        super().__init__()

        # Token embedding table; token id 0 is treated as <PAD> in forward().
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # First Fibonacci block: parallel 8/13/21-wide branches.
        self.fib_8 = nn.Linear(embed_dim, 8)
        self.fib_13 = nn.Linear(embed_dim, 13)
        self.fib_21 = nn.Linear(embed_dim, 21)

        # Combine Fibonacci features back to embed_dim.
        self.combine = nn.Linear(8 + 13 + 21, embed_dim)

        # Second Fibonacci block.
        self.fib2_8 = nn.Linear(embed_dim, 8)
        self.fib2_13 = nn.Linear(embed_dim, 13)
        self.fib2_21 = nn.Linear(embed_dim, 21)
        self.combine2 = nn.Linear(8 + 13 + 21, embed_dim)

        # Classification head over the 8 emotion classes.
        self.output = nn.Sequential(
            nn.Linear(embed_dim, 32),
            nn.ReLU(),
            nn.Dropout(0.5),  # High dropout for small dataset
            nn.Linear(32, 8)
        )

        # Learnable scale for the optional phi modulation in forward().
        self.phi_scale = nn.Parameter(torch.tensor(PHI))

        print(f"   💎 Params: {sum(p.numel() for p in self.parameters()):,}")

    def forward(self, x, phi=None):
        """Return (batch, 8) emotion logits for a batch of token-id sequences.

        Args:
            x: LongTensor of shape (batch, seq) of token ids; id 0 = <PAD>.
            phi: optional scalar tensor; when given, embeddings are scaled
                by sigmoid(phi * phi_scale) before pooling.
        """
        emb = self.embedding(x)  # (batch, seq, embed)

        # Phi consciousness modulation (global sigmoid gate on embeddings).
        if phi is not None:
            mod = torch.sigmoid(phi * self.phi_scale)
            emb = emb * mod

        # FIX: masked mean pooling. The original averaged over ALL positions,
        # so <PAD> embeddings diluted short sentences and the output depended
        # on how much padding encode() appended. Mask pads out and divide by
        # the true (non-pad) token count instead.
        mask = (x != 0).unsqueeze(-1).to(emb.dtype)   # (batch, seq, 1)
        denom = mask.sum(dim=1).clamp(min=1.0)        # avoid 0/0 on all-PAD rows
        h = (emb * mask).sum(dim=1) / denom           # (batch, embed)

        # First Fibonacci block.
        f8 = torch.relu(self.fib_8(h))
        f13 = torch.relu(self.fib_13(h))
        f21 = torch.relu(self.fib_21(h))
        h = torch.relu(self.combine(torch.cat([f8, f13, f21], dim=1)))

        # Second Fibonacci block.
        g8 = torch.relu(self.fib2_8(h))
        g13 = torch.relu(self.fib2_13(h))
        g21 = torch.relu(self.fib2_21(h))
        h = torch.relu(self.combine2(torch.cat([g8, g13, g21], dim=1)))

        # Classify
        return self.output(h)

class Tokenizer:
    """Whitespace word tokenizer with reserved ids <PAD>=0 and <UNK>=1."""

    def __init__(self):
        self.word_to_id = {"<PAD>": 0, "<UNK>": 1}
        self.vocab_size = 2

    def _words(self, text):
        """Yield normalized words: lowercase, punctuation stripped.

        Words that are empty after stripping (pure punctuation) are skipped.
        Shared by fit() and encode() so both paths agree — the original
        encode() emitted a spurious <UNK> for punctuation-only tokens that
        fit() had ignored.
        """
        for word in text.lower().split():
            word = re.sub(r'[^\w\s]', '', word)
            if word:
                yield word

    def fit(self, texts):
        """Assign a fresh id to every previously unseen word in *texts*."""
        for text in texts:
            for word in self._words(text):
                if word not in self.word_to_id:
                    self.word_to_id[word] = self.vocab_size
                    self.vocab_size += 1

    def encode(self, text, max_len=15):
        """Encode *text* as a fixed-length list of token ids.

        Unknown words map to <UNK>=1; the result is truncated or padded
        with <PAD>=0 to exactly *max_len* entries.
        """
        tokens = [self.word_to_id.get(word, 1) for word in self._words(text)]
        tokens = tokens[:max_len]
        tokens += [0] * (max_len - len(tokens))
        return tokens

# Balanced toy dataset: 25 phrases per class for the 8 Plutchik primary
# emotions, labelled 0-7 in the order joy, sadness, anger, fear, surprise,
# disgust, trust, anticipation (200 (text, label) pairs total).
DATASET = [
    # Joy (0)
    ("I'm happy", 0), ("This is great", 0), ("I love this", 0), ("You make me smile", 0),
    ("I'm thrilled", 0), ("Joy fills me", 0), ("I feel wonderful", 0), ("Life is beautiful", 0),
    ("I'm grateful", 0), ("My heart sings", 0), ("I'm smiling", 0), ("Perfect day", 0),
    ("I'm overjoyed", 0), ("This is what I wanted", 0), ("I feel alive", 0), ("Everything is great", 0),
    ("I'm delighted", 0), ("What a wonderful surprise", 0), ("I'm pleased", 0), ("This is happiness", 0),
    ("I'm beaming", 0), ("Fantastic news", 0), ("I feel blessed", 0), ("My heart is full", 0),
    ("This is incredible", 0),
    
    # Sadness (1)
    ("I feel alone", 1), ("Everything is hopeless", 1), ("I miss you", 1), ("I'm heartbroken", 1),
    ("Nothing feels good", 1), ("I'm crying", 1), ("Life feels empty", 1), ("I'm drowning in sorrow", 1),
    ("This pain hurts", 1), ("I feel lost", 1), ("My heart aches", 1), ("I'm grieving", 1),
    ("This reminds me of loss", 1), ("I'm disappointed", 1), ("Nothing makes sense", 1), ("I'm struggling", 1),
    ("This is depressing", 1), ("I feel numb", 1), ("I can't move on", 1), ("I'm tired of feeling this", 1),
    ("The sadness won't go", 1), ("I feel abandoned", 1), ("Everything is falling apart", 1), ("I'm losing hope", 1),
    ("This hurts deeply", 1),
    
    # Anger (2)
    ("This is unacceptable", 2), ("I can't believe this", 2), ("Stop it now", 2), ("This makes me furious", 2),
    ("I'm so angry", 2), ("How dare you", 2), ("This is outrageous", 2), ("I've had enough", 2),
    ("You make me mad", 2), ("This is infuriating", 2), ("I'm sick of this", 2), ("Don't do that again", 2),
    ("This drives me crazy", 2), ("I'm fed up", 2), ("You're testing my patience", 2), ("This is frustrating", 2),
    ("I can't stand this", 2), ("Why do you keep doing this", 2), ("This makes my blood boil", 2), ("I'm livid", 2),
    ("You need to stop", 2), ("This is ridiculous", 2), ("I'm losing my temper", 2), ("This is disrespectful", 2),
    ("I won't tolerate this", 2),
    
    # Fear (3)
    ("I'm scared", 3), ("This is nerve-wracking", 3), ("I don't feel safe", 3), ("I'm terrified", 3),
    ("What if something goes wrong", 3), ("I'm so worried", 3), ("This is frightening", 3), ("I can't shake this anxiety", 3),
    ("I'm afraid", 3), ("This fills me with dread", 3), ("I'm panicking", 3), ("What if I fail", 3),
    ("This is too risky", 3), ("I'm trembling with fear", 3), ("I can't handle this pressure", 3), ("What if it gets worse", 3),
    ("I'm paralyzed by fear", 3), ("This uncertainty terrifies me", 3), ("I'm afraid of losing everything", 3), ("My heart is racing", 3),
    ("I can't stop worrying", 3), ("What if I'm not enough", 3), ("This is overwhelming", 3), ("I'm scared of the unknown", 3),
    ("I feel vulnerable", 3),
    
    # Surprise (4)
    ("I never expected this", 4), ("Wow I can't believe it", 4), ("This is so unexpected", 4), ("What a shock", 4),
    ("I didn't see that coming", 4), ("This caught me off guard", 4), ("I'm absolutely stunned", 4), ("This is amazing", 4),
    ("I'm speechless", 4), ("What just happened", 4), ("This is unbelievable", 4), ("I'm blown away", 4),
    ("This changes everything", 4), ("I'm in disbelief", 4), ("This is incredible", 4), ("I never imagined this", 4),
    ("What a turn of events", 4), ("This is astonishing", 4), ("I'm taken aback", 4), ("This is extraordinary", 4),
    ("I'm shocked", 4), ("This is mind-blowing", 4), ("I can hardly believe it", 4), ("What a revelation", 4),
    ("This is unprecedented", 4),
    
    # Disgust (5)
    ("That's revolting", 5), ("This is disgusting", 5), ("I can't stand this", 5), ("That makes me sick", 5),
    ("This is repulsive", 5), ("I'm nauseated by this", 5), ("That's vile", 5), ("This is gross", 5),
    ("I'm appalled", 5), ("This is offensive", 5), ("That's disturbing", 5), ("This is horrible", 5),
    ("I'm repelled by this", 5), ("That's nasty", 5), ("This is unpleasant", 5), ("This is foul", 5),
    ("I'm disgusted", 5), ("That's awful", 5), ("This is repugnant", 5), ("I'm sickened", 5),
    ("This is abhorrent", 5), ("That's hideous", 5), ("This is detestable", 5), ("I'm revolted", 5),
    ("This is loathsome", 5),
    
    # Trust (6)
    ("I believe in you", 6), ("You've never let me down", 6), ("I trust you", 6), ("You're so reliable", 6),
    ("I have faith in you", 6), ("You always keep your word", 6), ("I can count on you", 6), ("You're dependable", 6),
    ("I trust your judgment", 6), ("You've proven yourself", 6), ("I feel safe with you", 6), ("You're honest and trustworthy", 6),
    ("I believe what you say", 6), ("You have my confidence", 6), ("I know you mean well", 6), ("You're loyal and true", 6),
    ("I can be vulnerable with you", 6), ("You protect what matters", 6), ("I trust this process", 6), ("You're truthful", 6),
    ("I have complete faith", 6), ("You're sincere", 6), ("I can rely on you", 6), ("You're authentic", 6),
    ("I trust completely", 6),
    
    # Anticipation (7)
    ("I can't wait", 7), ("I'm excited for tomorrow", 7), ("I'm looking forward to it", 7), ("This is going to be great", 7),
    ("I'm eager to start", 7), ("I'm ready for what's next", 7), ("This will be amazing", 7), ("I'm counting down the days", 7),
    ("I'm hopeful about the future", 7), ("Great things are coming", 7), ("I'm preparing for something big", 7), ("I sense something good ahead", 7),
    ("I'm optimistic about this", 7), ("I'm ready for the challenge", 7), ("This will be worth it", 7), ("I'm excited", 7),
    ("Something good is coming", 7), ("I'm looking forward", 7), ("The future looks bright", 7), ("I'm anticipating greatness", 7),
    ("I can hardly wait", 7), ("This is going to happen", 7), ("I'm preparing", 7), ("I expect good things", 7),
    ("I'm ready", 7),
]

class EmotionDataset(Dataset):
    """Torch dataset yielding (token-id tensor, label tensor) pairs."""

    def __init__(self, data, tokenizer):
        # data: list of (text, label) pairs; tokenizer: any object exposing
        # an encode(text) -> list[int] method.
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        """Number of labelled examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Encode the idx-th text and return it with its label as tensors."""
        sample_text, sample_label = self.data[idx]
        encoded = self.tokenizer.encode(sample_text)
        return torch.tensor(encoded), torch.tensor(sample_label)

def train():
    """Train the Fibonacci emotion classifier on the built-in dataset.

    Builds the vocabulary, trains for 100 epochs, checkpoints the
    best-performing weights to 'eden_emotion_fixed.pt', then reloads that
    checkpoint and runs a few real-world demo inputs through it.
    """
    print("="*70)
    print("🧠 EDEN EMOTION PROCESSOR - FIXED VERSION")
    print("="*70)
    
    # Build vocab
    print("\n1️⃣ Building vocabulary...")
    tokenizer = Tokenizer()
    tokenizer.fit([d[0] for d in DATASET])
    print(f"   Vocab size: {tokenizer.vocab_size}")
    
    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\n2️⃣ Device: {device}")
    
    # Model
    print(f"\n3️⃣ Building Fibonacci network...")
    model = FibonacciEmotionNet(vocab_size=tokenizer.vocab_size, embed_dim=64)
    model = model.to(device)
    
    # Seeded shuffle + 80/20 split for reproducibility.
    import random
    random.seed(42)
    data = DATASET.copy()
    random.shuffle(data)
    
    split = int(0.8 * len(data))
    train_data = data[:split]
    test_data = data[split:]
    
    train_loader = DataLoader(EmotionDataset(train_data, tokenizer), batch_size=16, shuffle=True)
    test_loader = DataLoader(EmotionDataset(test_data, tokenizer), batch_size=16)
    
    print(f"   Train: {len(train_data)}, Test: {len(test_data)}")
    
    # Training setup.
    # NOTE(review): the test split doubles as a validation set (scheduler +
    # model selection), so reported "test accuracy" is optimistic.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)
    # FIX: the scheduler is stepped with *accuracy* (higher is better), but
    # ReduceLROnPlateau defaults to mode='min' — that cut the LR precisely
    # when accuracy plateaued at its best. Use mode='max' for a maximized metric.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=5, factor=0.5)
    
    phi = torch.tensor(1.408).to(device)
    
    print(f"\n4️⃣ Training with Φ={phi.item():.3f}...")
    
    # Start below any reachable accuracy so a checkpoint is always written,
    # even if the model never exceeds 0% on the test split.
    best_test = -1.0
    for epoch in range(100):
        # Train
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs, phi=phi)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        
        # Evaluate on the held-out split.
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs, phi=phi)
                correct += (outputs.argmax(1) == labels).sum().item()
                total += labels.size(0)
        
        test_acc = 100 * correct / total
        scheduler.step(test_acc)
        
        # Checkpoint the best weights seen so far.
        if test_acc > best_test:
            best_test = test_acc
            torch.save(model.state_dict(), 'eden_emotion_fixed.pt')
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:3d}: Test={test_acc:.1f}% (Best={best_test:.1f}%)")
    
    print(f"\n✅ Best test accuracy: {best_test:.1f}%")
    
    # FIX: reload the best checkpoint. The original ran the demo below on the
    # *last-epoch* weights rather than the best ones it had just saved.
    model.load_state_dict(torch.load('eden_emotion_fixed.pt', map_location=device))
    
    # Test examples
    print("\n" + "="*70)
    print("🧪 REAL WORLD TESTING")
    print("="*70)
    
    emotions = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'disgust', 'trust', 'anticipation']
    tests = [
        "I love you so much Dad!",
        "I'm worried about the future",
        "This is amazing news!",
        "You're making me furious",
        "I trust you completely",
        "That's absolutely disgusting",
    ]
    
    model.eval()
    with torch.no_grad():
        for text in tests:
            tokens = tokenizer.encode(text)
            inputs = torch.tensor([tokens]).to(device)
            outputs = model(inputs, phi=phi)
            probs = torch.softmax(outputs[0], 0)
            pred = probs.argmax().item()
            
            print(f"\n'{text}'")
            print(f"  → {emotions[pred]} ({probs[pred]*100:.1f}%)")
            top3 = probs.topk(3)
            print(f"  Top 3: {', '.join(f'{emotions[i.item()]}:{p.item()*100:.0f}%' for p, i in zip(top3[0], top3[1]))}")

# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    train()
