"""Single Echo State Network layer with spectral radius tuned to 1/φ."""
import torch
import torch.nn as nn
import numpy as np
from phi_constants import INV_PHI

def spectral_radius(W):
    """Return the spectral radius of a square matrix.

    The spectral radius is the largest absolute value among the
    eigenvalues of ``W``.

    Args:
        W: Square matrix, given either as a ``torch.Tensor`` (detached and
            copied to host memory before the computation) or as anything
            ``numpy.linalg.eigvals`` accepts.

    Returns:
        The largest eigenvalue magnitude as a Python ``float``.
    """
    # Tensors are converted to NumPy; everything else is passed through as-is.
    matrix = W.detach().cpu().numpy() if isinstance(W, torch.Tensor) else W
    magnitudes = np.abs(np.linalg.eigvals(matrix))
    return float(magnitudes.max())

def rescale_to_target_sr(W, target_sr):
    """Scale ``W`` uniformly so its spectral radius equals ``target_sr``.

    Args:
        W: Square matrix whose spectral radius is measured with
            ``spectral_radius``.
        target_sr: Desired spectral radius of the result.

    Returns:
        ``W`` multiplied by ``target_sr / spectral_radius(W)``. If the
        current spectral radius is exactly zero, no scale factor exists and
        ``W`` itself is returned unchanged.
    """
    radius = spectral_radius(W)
    if radius != 0:
        scale = target_sr / radius
        return W * scale
    # Zero spectral radius: scaling cannot reach the target.
    return W

class PhiESNLayer(nn.Module):
    """Echo State Network layer with fixed spectral radius = 1/φ.

    The layer consists of a fixed random input projection ``W_in``, a fixed
    sparse recurrent matrix ``W_rec`` rescaled to spectral radius
    ``INV_PHI``, and a trainable linear readout. The reservoir state is kept
    on the module as the ``state`` buffer and carried over between calls to
    ``forward``.

    Args:
        input_size: Number of input features.
        reservoir_size: Number of reservoir units.
        output_size: Number of readout outputs.
        leak_rate: Blend factor between the previous state and the freshly
            computed activation (``1.0`` fully replaces the old state).
        sparsity: Probability that each recurrent weight is kept (non-zero);
            i.e. the expected fraction of non-zero entries in ``W_rec``.
        device: Device the module is moved to at the end of construction.
    """

    def __init__(self, input_size, reservoir_size, output_size,
                 leak_rate=1.0, sparsity=0.1, device='cpu'):
        super().__init__()

        self.input_size = input_size
        self.reservoir_size = reservoir_size
        self.output_size = output_size
        self.leak_rate = leak_rate
        # Construction-time device, kept for backward compatibility only.
        # State tensors are allocated from W_rec's current device/dtype so
        # that later ``.to()`` / ``.cuda()`` / ``.double()`` calls are honored.
        self.device = device

        # Input weights (random, fixed)
        self.W_in = nn.Parameter(
            torch.randn(reservoir_size, input_size) * 0.5,
            requires_grad=False
        )

        # Recurrent weights (sparse, tuned to 1/φ): each entry is kept with
        # probability ``sparsity`` and zeroed otherwise.
        W = torch.randn(reservoir_size, reservoir_size)
        mask = torch.rand(reservoir_size, reservoir_size) < sparsity
        W = W * mask.float()

        # Scale to target spectral radius
        W_np = rescale_to_target_sr(W.numpy(), INV_PHI)
        self.W_rec = nn.Parameter(
            # Match W_in's dtype so both fixed matrices stay consistent.
            torch.from_numpy(W_np).to(self.W_in.dtype),
            requires_grad=False
        )

        # Readout weights (trainable)
        self.readout = nn.Linear(reservoir_size, output_size)

        # Reservoir state; re-allocated whenever the batch size changes.
        self.register_buffer('state', torch.zeros(1, reservoir_size))

        self.to(device)

    def _zero_state(self, batch_size):
        """Return an all-zero state on the module's current device/dtype."""
        return self.W_rec.new_zeros(batch_size, self.reservoir_size)

    def forward(self, x):
        """Advance the reservoir by one step and apply the readout.

        Args:
            x: Input of shape ``[batch, input_size]``.

        Returns:
            Tuple ``(output, state)`` where ``output`` is the readout applied
            to the updated state and ``state`` is the updated reservoir state
            (the same tensor that is stored on the module).

        Note:
            If the stored state's batch dimension differs from ``x``'s, the
            state is silently reset to zeros. The stored state is not
            detached, so if ``x`` carries gradient the autograd graph is
            retained across calls.
        """
        batch_size = x.shape[0]

        if self.state.shape[0] != batch_size:
            self.state = self._zero_state(batch_size)

        # Pre-activation: input drive plus recurrent feedback.
        input_contrib = torch.matmul(x, self.W_in.t())
        recurrent_contrib = torch.matmul(self.state, self.W_rec.t())

        # Leaky integration of the new activation into the previous state.
        new_state = torch.tanh(input_contrib + recurrent_contrib)
        self.state = (1 - self.leak_rate) * self.state + self.leak_rate * new_state

        output = self.readout(self.state)

        return output, self.state

    def reset_state(self, batch_size=1):
        """Reset the reservoir state to zeros for ``batch_size`` sequences."""
        self.state = self._zero_state(batch_size)

    def get_spectral_radius(self):
        """Return the current spectral radius of the recurrent matrix."""
        return spectral_radius(self.W_rec)

