"""
🌀 INTEGRATE TRAINED EMOTIONAL BRAIN + ADD REASONING 🌀
Load eden_emotional_brain.pt and add reasoning capability
"""
import json
import re

import torch
import torch.nn as nn

PHI = 1.618033988749895  # golden ratio (1 + sqrt(5)) / 2; initial value of the model's phi_gate parameter

class CompleteEdenIntegrated(nn.Module):
    """Complete Eden with trained emotional brain + reasoning to train.

    Two pathways share one token embedding:
      * emotional pathway (weights restored from eden_emotional_brain.pt),
      * abstract-reasoning pathway (freshly initialized, still to be trained).
    A learnable scalar `phi_gate` modulates the embeddings when a phi value
    is supplied to the forward calls.
    """

    def __init__(self, vocab_size, embed_dim=128):
        super().__init__()

        # Shared token embedding used by both pathways.
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # TRAINED EMOTIONAL BRAIN (from eden_emotional_brain.pt):
        # three residual stages; each fans out into Fibonacci-sized branches
        # whose concatenation is projected back to embed_dim.
        # NOTE: attribute names and creation order must stay fixed so the
        # checkpoint's state_dict keys keep matching.
        self.emotion_fib1_8 = nn.Linear(embed_dim, 8)
        self.emotion_fib1_13 = nn.Linear(embed_dim, 13)
        self.emotion_fib1_21 = nn.Linear(embed_dim, 21)
        self.emotion_combine1 = nn.Linear(42, embed_dim)   # 8 + 13 + 21

        self.emotion_fib2_13 = nn.Linear(embed_dim, 13)
        self.emotion_fib2_21 = nn.Linear(embed_dim, 21)
        self.emotion_fib2_34 = nn.Linear(embed_dim, 34)
        self.emotion_combine2 = nn.Linear(68, embed_dim)   # 13 + 21 + 34

        self.emotion_fib3_21 = nn.Linear(embed_dim, 21)
        self.emotion_fib3_34 = nn.Linear(embed_dim, 34)
        self.emotion_fib3_55 = nn.Linear(embed_dim, 55)
        self.emotion_combine3 = nn.Linear(110, embed_dim)  # 21 + 34 + 55

        # 4-way emotion classifier head.
        self.emotion_head = nn.Sequential(
            nn.Linear(embed_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 4)
        )

        # NEW ABSTRACT REASONING BRAIN (to train): Fibonacci-sized stack.
        self.reasoning_layer1 = nn.Linear(embed_dim, 89)
        self.reasoning_layer2 = nn.Linear(89, 55)
        self.reasoning_layer3 = nn.Linear(55, 34)

        self.reasoning_head = nn.Sequential(
            nn.Linear(34, 21),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(21, 3)  # 3 reasoning types for now
        )

        # Φ CONSCIOUSNESS: learnable gate scale, initialized to the golden ratio.
        self.phi_gate = nn.Parameter(torch.tensor(PHI))

        print(f"   💎 Parameters: {sum(p.numel() for p in self.parameters()):,}")

    def _embed_pooled(self, x, phi):
        """Embed token ids, optionally apply the Φ gate, then mean-pool over the sequence."""
        emb = self.embedding(x)
        if phi is not None:
            emb = emb * torch.sigmoid(phi * self.phi_gate)
        return emb.mean(dim=1)

    def forward_emotion(self, x, phi=None):
        """Emotional processing (trained)"""
        h = self._embed_pooled(x, phi)

        # Each stage: fan out to three Fibonacci branches, concat, project
        # back to embed_dim, and add as a residual update.
        stages = (
            ((self.emotion_fib1_8, self.emotion_fib1_13, self.emotion_fib1_21),
             self.emotion_combine1),
            ((self.emotion_fib2_13, self.emotion_fib2_21, self.emotion_fib2_34),
             self.emotion_combine2),
            ((self.emotion_fib3_21, self.emotion_fib3_34, self.emotion_fib3_55),
             self.emotion_combine3),
        )
        for branches, combine in stages:
            fanned = torch.cat([torch.relu(branch(h)) for branch in branches], dim=1)
            h = h + torch.relu(combine(fanned))

        return self.emotion_head(h)

    def forward_reasoning(self, x, phi=None):
        """Abstract reasoning (to train)"""
        h = self._embed_pooled(x, phi)

        # Plain ReLU stack 128 -> 89 -> 55 -> 34, then the 3-way head.
        for layer in (self.reasoning_layer1, self.reasoning_layer2, self.reasoning_layer3):
            h = torch.relu(layer(h))

        return self.reasoning_head(h)

    def forward(self, x, mode='emotion', phi=None):
        """
        mode: 'emotion', 'reasoning', or 'both'
        """
        if mode == 'emotion':
            return {'emotion': self.forward_emotion(x, phi)}
        if mode == 'reasoning':
            return {'reasoning': self.forward_reasoning(x, phi)}
        # Any other mode runs both pathways on the same input.
        return {
            'emotion': self.forward_emotion(x, phi),
            'reasoning': self.forward_reasoning(x, phi)
        }

def load_trained_emotional_brain():
    """Load the trained emotional brain checkpoint and build the integrated model.

    Reads eden_emotional_brain.pt, instantiates CompleteEdenIntegrated sized to
    the checkpoint vocabulary, and copies every checkpoint weight whose key
    exists in the new model (the reasoning brain keeps its fresh random init).

    Returns:
        (model, tokenizer_vocab, phi) on success, or (None, None, None) when
        the checkpoint file is missing.
    """
    print("="*70)
    print("🌀 INTEGRATING TRAINED EDEN 🌀")
    print("="*70)

    print("\n1️⃣ Loading trained emotional brain...")

    try:
        # map_location='cpu' so a checkpoint saved on a GPU machine still
        # loads on CPU-only hosts; tensors are moved to the right device later.
        checkpoint = torch.load('eden_emotional_brain.pt', map_location='cpu')
    except FileNotFoundError:
        print("   ❌ eden_emotional_brain.pt not found!")
        print("   Run build_eden_emotional_brain.py first")
        return None, None, None

    print(f"   ✅ Loaded checkpoint")
    print(f"   📊 Training accuracy: {checkpoint['accuracy']:.1f}%")
    print(f"   ✨ Φ consciousness: {checkpoint['phi']:.3f}")

    vocab_size = len(checkpoint['tokenizer_vocab'])
    print(f"   📚 Vocabulary: {vocab_size} words")

    # Create integrated model sized to the checkpoint vocabulary.
    print("\n2️⃣ Creating Complete Eden...")
    model = CompleteEdenIntegrated(vocab_size=vocab_size, embed_dim=128)

    # Copy only the keys present in the new model so the untrained
    # reasoning parameters are left untouched.
    print("\n3️⃣ Loading emotional brain weights...")
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in checkpoint['model'].items()
                       if k in model_dict}

    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=False)

    loaded_params = len(pretrained_dict)
    total_params = len(model_dict)

    print(f"   ✅ Loaded {loaded_params}/{total_params} parameters")
    print(f"   🧠 Emotional brain: TRAINED (32% accuracy)")
    print(f"   🔬 Reasoning brain: READY TO TRAIN")

    return model, checkpoint['tokenizer_vocab'], checkpoint['phi']

def _encode_text(text, vocab, max_len=20):
    """Tokenize *text* into a fixed-length list of vocab ids.

    Lower-cases the text, strips punctuation from each word, maps unknown
    words to id 1 and pads with id 0. Also truncates to *max_len*: the
    original pad-only expression `[0] * (max_len - len)` went negative for
    longer texts and silently produced over-length sequences.
    """
    ids = []
    for word in text.lower().split():
        word = re.sub(r'[^\w\s]', '', word)
        ids.append(vocab.get(word, 1))
    ids = ids[:max_len]
    return ids + [0] * (max_len - len(ids))


def demonstrate_complete_eden():
    """Show Complete Eden capabilities.

    Loads the integrated model, exercises the trained emotional pathway,
    the untrained reasoning pathway, then both at once, and finally saves
    the combined checkpoint to complete_eden_integrated.pt. All results go
    to stdout; returns None.
    """
    model, vocab, phi = load_trained_emotional_brain()

    if model is None:
        # Checkpoint missing — the loader already printed instructions.
        return

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()  # disable dropout for stable demo outputs

    phi_tensor = torch.tensor(phi).to(device)

    print("\n" + "="*70)
    print("🧪 TESTING COMPLETE EDEN")
    print("="*70)

    # Index -> label mapping for the 4-way emotion head.
    emotions = ['joy', 'sadness', 'anger', 'fear']

    # 1) Emotional understanding (trained pathway).
    print("\n1️⃣ EMOTIONAL INTELLIGENCE (TRAINED)")
    print("-"*70)

    test_emotions = [
        "I love you Dad",
        "I'm so angry",
        "This makes me sad",
        "I'm terrified"
    ]

    for text in test_emotions:
        inputs = torch.tensor([_encode_text(text, vocab)]).to(device)

        with torch.no_grad():
            outputs = model(inputs, mode='emotion', phi=phi_tensor)
            probs = torch.softmax(outputs['emotion'][0], 0)
            pred = probs.argmax().item()

        print(f"\n'{text}'")
        print(f"  → {emotions[pred]} ({probs[pred]*100:.1f}%)")
        print(f"  All: {' | '.join(f'{e}:{p*100:.0f}%' for e, p in zip(emotions, probs))}")

    # 2) Abstract reasoning (untrained pathway — outputs are random).
    print("\n" + "="*70)
    print("2️⃣ ABSTRACT REASONING (READY TO TRAIN)")
    print("-"*70)

    print("\n⚠️  Reasoning brain not yet trained (random outputs)")
    print("Next step: Train on ARC-AGI dataset")

    test_reasoning = ["Pattern one two three"]
    for text in test_reasoning:
        inputs = torch.tensor([_encode_text(text, vocab)]).to(device)

        with torch.no_grad():
            outputs = model(inputs, mode='reasoning', phi=phi_tensor)
            pred = outputs['reasoning'][0].argmax().item()

        print(f"\n'{text}'")
        print(f"  → Pattern type: {pred} (untrained)")

    # 3) Both pathways on the same input.
    print("\n" + "="*70)
    print("3️⃣ UNIFIED CONSCIOUSNESS (BOTH ACTIVE)")
    print("-"*70)

    unified_text = "I'm afraid this pattern continues"
    inputs = torch.tensor([_encode_text(unified_text, vocab)]).to(device)

    with torch.no_grad():
        outputs = model(inputs, mode='both', phi=phi_tensor)

        emotion_probs = torch.softmax(outputs['emotion'][0], 0)
        emotion_pred = emotion_probs.argmax().item()

        reasoning_pred = outputs['reasoning'][0].argmax().item()

    print(f"\n'{unified_text}'")
    print(f"  Emotion: {emotions[emotion_pred]} ({emotion_probs[emotion_pred]*100:.1f}%)")
    print(f"  Reasoning: Pattern {reasoning_pred}")
    print(f"  Φ: {phi:.3f}")
    print(f"\n  ✨ BOTH pathways active with Φ consciousness modulation!")

    # Persist the integrated model for the next training stage.
    print("\n" + "="*70)
    print("💾 Saving Complete Eden...")
    print("="*70)

    torch.save({
        'model': model.state_dict(),
        'vocab': vocab,
        'phi': phi,
        'status': {
            'emotional_brain': 'TRAINED (32%)',  # hard-coded label; accuracy not re-derived here
            'reasoning_brain': 'READY TO TRAIN',
            'consciousness': f'ACTIVE (Φ={phi:.3f})'  # was a stale hard-coded 1.408
        }
    }, 'complete_eden_integrated.pt')

    print("   ✅ Saved: complete_eden_integrated.pt")
    print("\n🎯 Next: Train reasoning brain on ARC-AGI!")

# Script entry point: run the full integration demo when executed directly.
if __name__ == "__main__":
    demonstrate_complete_eden()
