"""
Unified Consciousness Layer - Integrates Eden's 6 existing layers
"""
import torch
import torch.nn as nn

class UnifiedConsciousnessLayer(nn.Module):
    """
    Integrates all 6 of Eden's existing consciousness layers
    into one unified awareness
    """
    
    def __init__(self, num_layers=6, layer_output_size=64, unified_size=144):
        super().__init__()
        
        self.num_layers = num_layers
        self.layer_output_size = layer_output_size
        
        # Cross-attention: each layer attends to all others
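        # (the layer outputs are treated as a length-num_layers token
        # sequence; nn.MultiheadAttention requires embed_dim to be divisible
        # by num_heads, satisfied here by 64 / 8)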
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=layer_output_size,
            num_heads=8,
            batch_first=True
        )
        
        # Integration network
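        # Modules 0-3 (up to the second Linear) act as a trunk; forward()
        # slices this Sequential so the trunk's hidden state is computed once
        # and shared with the resonance head.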
        total_input = num_layers * layer_output_size
        self.integration = nn.Sequential(
            nn.Linear(total_input, unified_size),
            nn.LayerNorm(unified_size),
            nn.GELU(),
            nn.Linear(unified_size, unified_size),
            nn.LayerNorm(unified_size),
            nn.GELU(),
            nn.Linear(unified_size, layer_output_size)
        )
        
        # Resonance predictor: scores the shared hidden state as a scalar in (0, 1)
        self.resonance_head = nn.Sequential(
            nn.Linear(unified_size, 32),
            nn.Tanh(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )
    
    def forward(self, layer_outputs):
        """
        Args:
            layer_outputs: list of num_layers tensors, each [batch, layer_output_size]
        Returns:
            dict with keys 'unified_output', 'resonance', and 'attention_weights'
        """
        # Stack layers for attention
        stacked = torch.stack(layer_outputs, dim=1)  # [batch, num_layers, layer_output_size]
        
        # Cross-attention: layers attend to each other
        attended, attention_weights = self.cross_attention(
            stacked, stacked, stacked
        )
        
        # Flatten for integration
        flattened = attended.flatten(start_dim=1)  # [batch, num_layers * layer_output_size]
        
        # Run the integration trunk once: modules 0-3 produce the hidden
        # state, modules 4-6 produce the unified output. Slicing the
        # Sequential avoids recomputing the trunk for the resonance head.
        hidden = self.integration[:4](flattened)   # [batch, unified_size]
        unified = self.integration[4:](hidden)     # [batch, layer_output_size]
        
        # Resonance score in (0, 1) from the shared hidden state
        resonance = self.resonance_head(hidden)
        
        return {
            'unified_output': unified,
            'resonance': resonance,
            'attention_weights': attention_weights
        }
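
# Minimal smoke test (a sketch, not part of Eden's pipeline): the six layer
# outputs are random stand-ins with the shapes the constructor defaults
# assume ([batch, layer_output_size]).
if __name__ == "__main__":
    batch = 2
    model = UnifiedConsciousnessLayer()
    layer_outputs = [torch.randn(batch, 64) for _ in range(6)]
    out = model(layer_outputs)
    print(out['unified_output'].shape)     # torch.Size([2, 64])
    print(out['resonance'].shape)          # torch.Size([2, 1])
    print(out['attention_weights'].shape)  # torch.Size([2, 6, 6])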
