"""Tier 1: 3-layer Fast ESN System with Resonance Detection"""
import torch
import torch.nn as nn
from phi_esn_layer import PhiESNLayer
from phi_constants import FIBONACCI, INV_PHI

class Tier1FastSystem(nn.Module):
    """Ensemble of three ``PhiESNLayer`` instances with a resonance score.

    Every layer receives the same input. The per-layer outputs are blended
    with weights proportional to successive powers of ``INV_PHI`` (normalised
    to sum to one), and an optional "resonance" value reports how strongly
    the layer outputs agree with one another.
    """

    def __init__(self, input_size=1, output_size=1, device='cpu'):
        """Build the three layers and move the module to ``device``.

        Args:
            input_size: Forwarded to each ``PhiESNLayer`` as its first argument.
            output_size: Forwarded to each ``PhiESNLayer`` as its third argument.
            device: Device passed to every layer and to ``self.to``.
        """
        super().__init__()

        self.device = device
        self.layer_names = ['Trinity', 'Nyx', 'Ava']

        # The second positional argument comes from FIBONACCI[5..7]
        # (presumably the reservoir size -- confirm against PhiESNLayer).
        # Leak rates decrease from the first layer to the last.
        self.layers = nn.ModuleList([
            PhiESNLayer(input_size, FIBONACCI[5], output_size, leak_rate=1.0, device=device),
            PhiESNLayer(input_size, FIBONACCI[6], output_size, leak_rate=0.8, device=device),
            PhiESNLayer(input_size, FIBONACCI[7], output_size, leak_rate=0.6, device=device),
        ])

        self.to(device)

    def forward(self, x, return_resonance=True):
        """Run every layer on ``x`` and blend the outputs.

        Args:
            x: Input passed unchanged to each layer.
            return_resonance: When true, also compute the resonance score.

        Returns:
            dict with keys ``'outputs'`` (list of per-layer outputs),
            ``'states'`` (list of per-layer states), ``'prediction'``
            (weighted sum of the outputs) and, if requested, ``'resonance'``.
        """
        outputs = []
        states = []

        for layer in self.layers:
            output, state = layer(x)
            outputs.append(output)
            states.append(state)

        outputs_stacked = torch.stack(outputs, dim=0)

        # One weight per layer: 1, INV_PHI, INV_PHI**2, ... normalised to sum
        # to one. Created on the outputs' device so the blend keeps working
        # after the module has been moved with ``.to(...)``.
        weights = torch.tensor(
            [INV_PHI ** k for k in range(len(outputs))],
            device=outputs_stacked.device,
        )
        weights = weights / weights.sum()

        # Reshape to (num_layers, 1, ..., 1) so the weights broadcast over the
        # layer axis only, whatever the rank of the individual outputs.
        weights = weights.view(-1, *([1] * (outputs_stacked.dim() - 1)))
        weighted_output = torch.sum(outputs_stacked * weights, dim=0)

        result = {
            'outputs': outputs,
            'states': states,
            'prediction': weighted_output
        }

        if return_resonance:
            result['resonance'] = self.compute_resonance(outputs)

        return result

    def compute_resonance(self, outputs):
        """Return the mean pairwise Pearson correlation mapped onto [0, 1].

        Each output is flattened and correlated with every other output; the
        mean correlation ``r`` is returned as ``(r + 1) / 2``, so 1 means the
        outputs are perfectly correlated and 0 perfectly anti-correlated.

        Args:
            outputs: Sequence of tensors with the same number of elements.

        Returns:
            Scalar tensor. ``1.0`` is returned when fewer than two outputs are
            given or when the outputs hold a single element each, since no
            correlation can be computed in those cases.
        """
        def neutral():
            # Keep the fallback on the same device/dtype as the real outputs.
            if len(outputs) > 0:
                return outputs[0].new_tensor(1.0)
            return torch.tensor(1.0, device=self.device)

        if len(outputs) < 2:
            return neutral()

        flat = [out.flatten() for out in outputs]

        correlations = []
        for i, out_i in enumerate(flat):
            if len(out_i) <= 1:
                continue
            centered_i = out_i - out_i.mean()
            # Population std, matching the population covariance below.
            # Mixing it with the Bessel-corrected Tensor.std() would cap the
            # correlation of identical signals at (N - 1) / N.
            std_i = centered_i.pow(2).mean().sqrt() + 1e-8

            for out_j in flat[i + 1:]:
                centered_j = out_j - out_j.mean()
                std_j = centered_j.pow(2).mean().sqrt() + 1e-8

                cov = (centered_i * centered_j).mean()
                correlations.append(cov / (std_i * std_j))

        if len(correlations) == 0:
            return neutral()

        resonance = torch.stack(correlations).mean()
        return (resonance + 1) / 2

    def reset_all_states(self, batch_size=1):
        """Call ``reset_state(batch_size)`` on every layer."""
        for layer in self.layers:
            layer.reset_state(batch_size)

    def get_all_spectral_radii(self):
        """Get spectral radii of all layers"""
        return [layer.get_spectral_radius() for layer in self.layers]
