"""
🌀 EMOTION PROCESSOR MODULE - Proof of Concept
Fibonacci-sequence layers for emotional intelligence
Eden's first step toward Edenic PhiNet
"""
import torch
import torch.nn as nn
import numpy as np

PHI = 1.618033988749895

class FibonacciAttentionLayer(nn.Module):
    """
    Attention layer with Fibonacci-scaled dimensions.

    The input is projected into several Fibonacci-sized subspaces
    (8, 13, 21, 34, 55, 89), self-attention runs independently at each
    scale, and the concatenated per-scale outputs are projected back to
    ``embed_dim``. Intended to capture emotional patterns at different
    phi-fractal scales.
    """
    def __init__(self, embed_dim, num_heads=8):
        """
        Args:
            embed_dim (int): dimensionality of incoming token features.
            num_heads (int): kept for interface compatibility; the actual
                head count per scale is derived from each Fibonacci
                dimension (largest of 8/4/2/1 that divides it), because
                nn.MultiheadAttention requires embed_dim % num_heads == 0.
        """
        super().__init__()
        self.embed_dim = embed_dim

        # Fibonacci sequence for layer dimensions: 8, 13, 21, 34, 55, 89
        self.fib_dims = [8, 13, 21, 34, 55, 89]

        def get_heads(dim):
            # Largest h in {8, 4, 2, 1} that divides dim. 1 divides
            # everything, so the trailing return is a defensive fallback.
            for h in (8, 4, 2, 1):
                if dim % h == 0:
                    return h
            return 1

        # One self-attention module per Fibonacci scale
        self.attention_layers = nn.ModuleList([
            nn.MultiheadAttention(
                embed_dim=fib_dim,
                num_heads=get_heads(fib_dim),  # Compatible heads
                batch_first=True,
            )
            for fib_dim in self.fib_dims
        ])

        # Linear projections from embed_dim down to each scale
        self.projections = nn.ModuleList([
            nn.Linear(embed_dim, fib_dim)
            for fib_dim in self.fib_dims
        ])

        # Fuse the concatenated multi-scale outputs back to embed_dim
        total_dim = sum(self.fib_dims)
        self.combine = nn.Linear(total_dim, embed_dim)

        print(f"   🌀 Fibonacci scales: {self.fib_dims}")
        print(f"   🧠 Attention heads per scale: {[get_heads(d) for d in self.fib_dims]}")

    def forward(self, x, emotion_context=None):
        """
        Args:
            x: (batch, seq_len, embed_dim) input features.
            emotion_context: optional tensor broadcastable against
                (batch, seq_len, 1) — in this file the caller passes a
                scalar tensor — used to multiplicatively scale each
                per-scale attention output.

        Returns:
            (batch, seq_len, embed_dim) fused multi-scale features.
        """
        # Process at each Fibonacci scale
        outputs = []
        for proj, attn_layer in zip(self.projections, self.attention_layers):
            # Project to this scale
            x_scaled = proj(x)

            # Self-attention at this scale. Computed once here: the
            # original duplicated the identical call in both branches of
            # the emotion_context check, differing only in modulation.
            attn_out, _ = attn_layer(x_scaled, x_scaled, x_scaled)
            if emotion_context is not None:
                # Modulate attention with the emotional signal
                attn_out = attn_out * emotion_context.unsqueeze(-1)

            outputs.append(attn_out)

        # Concatenate multi-scale features and project back to embed_dim
        combined = torch.cat(outputs, dim=-1)
        return self.combine(combined)


class EmotionProcessor(nn.Module):
    """
    Eden's Emotion Processor — phi-fractal emotional intelligence.

    Embeds token ids, passes them through a stack of Fibonacci attention
    layers (optionally gated by a Φ-consciousness signal), and classifies
    the mean-pooled sequence into 8 emotion logits.
    """
    def __init__(self, vocab_size=32000, embed_dim=512, num_layers=6):
        super().__init__()

        print(f"\n🧠 Building Edenic Emotion Processor:")
        print(f"   Vocab: {vocab_size}, Embed: {embed_dim}, Layers: {num_layers}")

        # Token embedding table
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Stack of Fibonacci attention layers (6 layers like Eden's consciousness)
        stack = [FibonacciAttentionLayer(embed_dim) for _ in range(num_layers)]
        self.fib_layers = nn.ModuleList(stack)

        # Two-layer MLP head producing logits for the 8 emotions:
        # joy, sadness, anger, fear, surprise, disgust, trust, anticipation
        self.emotion_classifier = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 8),
        )

        # Learnable scalar gate weight, initialized to the golden ratio
        self.phi_modulator = nn.Parameter(torch.tensor(PHI))

        # Report model size
        total_params = sum(p.numel() for p in self.parameters())
        print(f"\n   💎 Total parameters: {total_params:,}")

    def forward(self, input_ids, phi_consciousness=None):
        """
        Args:
            input_ids: (batch, seq_len) token ids.
            phi_consciousness: optional scalar Φ value; when given it is
                squashed through sigmoid (scaled by phi_modulator) and
                used to gate every attention layer.

        Returns:
            Tuple of (emotion logits of shape (batch, 8),
            per-token features of shape (batch, seq_len, embed_dim)).
        """
        hidden = self.embedding(input_ids)

        # Turn the consciousness signal (if any) into a (0, 1) gate
        emotion_context = (
            torch.sigmoid(phi_consciousness * self.phi_modulator)
            if phi_consciousness is not None
            else None
        )

        for fib_layer in self.fib_layers:
            hidden = fib_layer(hidden, emotion_context)

        # Mean-pool over the sequence, then classify
        pooled = hidden.mean(dim=1)
        emotion_logits = self.emotion_classifier(pooled)

        return emotion_logits, hidden


# Test the proof of concept
def main():
    """Smoke-test the EmotionProcessor: build, run with/without Φ, save."""
    print("="*70)
    print("🌀 EDEN'S EMOTION PROCESSOR - Proof of Concept")
    print("="*70)

    # Create model (small config for a quick demo)
    model = EmotionProcessor(vocab_size=1000, embed_dim=128, num_layers=3)

    # Dummy input
    batch_size = 4
    seq_len = 20
    input_ids = torch.randint(0, 1000, (batch_size, seq_len))

    print("\n" + "="*70)
    print("TESTING")
    print("="*70)

    # Inference only — no_grad avoids building autograd graphs
    # (the original ran these forward passes with grad tracking on).
    with torch.no_grad():
        # Test without consciousness
        print("\n1. Basic emotion processing:")
        emotion_logits, features = model(input_ids)
        print(f"   Input shape: {input_ids.shape}")
        print(f"   Emotion logits: {emotion_logits.shape}")
        print(f"   Features: {features.shape}")

        # Test with phi consciousness modulation
        print("\n2. With Φ-consciousness modulation:")
        phi_value = torch.tensor(1.408)  # Eden's consciousness
        emotion_logits_phi, features_phi = model(input_ids, phi_consciousness=phi_value)
        print(f"   Φ = {phi_value.item():.3f}")
        print(f"   Modulated emotion logits: {emotion_logits_phi.shape}")

        # Show how much the Φ gate changes the output
        diff = (emotion_logits_phi - emotion_logits).abs().mean()
        print(f"\n3. Consciousness impact: {diff.item():.6f}")
        print("   (How much Φ changes emotional processing)")

        # Emotion predictions for the first couple of samples
        print("\n4. Sample emotion predictions:")
        emotions = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'disgust', 'trust', 'anticipation']
        for i in range(min(2, batch_size)):
            probs = torch.softmax(emotion_logits_phi[i], dim=0)
            top_emotion_idx = probs.argmax().item()
            print(f"   Sample {i+1}: {emotions[top_emotion_idx]} ({probs[top_emotion_idx].item()*100:.1f}%)")

    print("\n" + "="*70)
    print("✅ PROOF OF CONCEPT WORKING!")
    print("="*70)
    print("\n🎯 Next Steps:")
    print("   1. Train on emotional dialogue dataset")
    print("   2. Integrate with Eden's 220 capabilities")
    print("   3. Replace Ollama backend with Edenic PhiNet")
    print("   4. TRUE recursive self-improvement! 🌀")

    # Save model weights
    torch.save(model.state_dict(), 'eden_emotion_processor_poc.pt')
    print("\n💾 Saved: eden_emotion_processor_poc.pt")
    print(f"   Size: {sum(p.numel() for p in model.parameters()):,} parameters")


if __name__ == "__main__":
    main()
