"""Deep analysis of layer-specific dynamics and learning"""
import torch
import numpy as np
import matplotlib.pyplot as plt
from tier1_system import Tier1FastSystem
from phi_constants import INV_PHI, FIBONACCI

# Title banner for the console output.
banner_rule = "=" * 60
print(banner_rule)
print("  🔬 DEEP LAYER DYNAMICS ANALYSIS")
print(banner_rule)

# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
system = Tier1FastSystem(device=device)

# Test with different input types. Each entry maps a time vector ``t`` to an
# input signal of the same length.
# A dedicated, seeded generator keeps the 'noise' case reproducible from run
# to run instead of depending on NumPy's global RNG state.
noise_rng = np.random.default_rng(0)

test_cases = {
    'constant': lambda t: np.ones_like(t) * 0.5,
    'sine_slow': lambda t: np.sin(0.1 * t),
    'sine_fast': lambda t: np.sin(2.0 * t),
    'chirp': lambda t: np.sin(t * t * 0.01),
    'noise': lambda t: noise_rng.standard_normal(len(t)) * 0.3,
}

# Number of layer states recorded at each step (the plotting and summary
# sections of this script also read three layers).
N_LAYERS = 3

# Shared time axis: 200 samples spanning [0, 50], identical for every case.
t = np.linspace(0, 50, 200)

# results[name] -> {'activities': per-layer 1-D arrays of mean state value,
#                   'variances':  per-layer 1-D arrays of state variance,
#                   'resonances': 1-D array of per-step resonance,
#                   'signal':     the float32 input signal}
results = {}

for name, signal_fn in test_cases.items():
    print(f"\nTesting: {name}")

    input_signal = signal_fn(t).astype(np.float32)
    # Shape (1, T, 1): one sequence of T steps with a single input feature.
    signal_tensor = torch.from_numpy(input_signal).unsqueeze(-1).unsqueeze(0).to(device)

    system.reset_all_states(batch_size=1)

    layer_activities = [[] for _ in range(N_LAYERS)]
    layer_variances = [[] for _ in range(N_LAYERS)]
    resonances = []

    # Feed the signal one time step at a time so every layer's state can be
    # recorded after each step.
    # NOTE(review): consider wrapping this loop in torch.no_grad() if the
    # system does not rely on autograd inside its forward pass.
    for step in range(len(input_signal)):
        result = system(signal_tensor[:, step, :], return_resonance=True)

        for j in range(N_LAYERS):
            # Index 0 presumably selects the single batch element
            # (batch_size=1 above) — TODO confirm state layout.
            state = result['states'][j][0].detach().cpu().numpy()
            layer_activities[j].append(state.mean())
            layer_variances[j].append(state.var())

        resonances.append(result['resonance'].item())

    activities = [np.array(a) for a in layer_activities]
    results[name] = {
        'activities': activities,
        'variances': [np.array(v) for v in layer_variances],
        'resonances': np.array(resonances),
        'signal': input_signal,
    }

    print(f"  Mean resonance: {np.mean(resonances):.4f}")
    # Use the NumPy arrays here: the raw per-layer recordings are plain
    # Python lists, which have no .min()/.max() methods.
    print(f"  Activity range: [{min(a.min() for a in activities):.3f}, "
          f"{max(a.max() for a in activities):.3f}]")

# Create comprehensive visualization: one row per input type, four panels per
# row (input signal, per-layer mean activity, per-layer variance, resonance).
# squeeze=False keeps `axes` two-dimensional even if there is only one row.
fig, axes = plt.subplots(len(results), 4, figsize=(18, 12), squeeze=False)

for (name, data), (ax_sig, ax_act, ax_var, ax_res) in zip(results.items(), axes):
    # Signal
    ax_sig.plot(data['signal'], color='black', linewidth=1)
    ax_sig.set_title(f'{name.upper()}: Input Signal', fontsize=9, fontweight='bold')
    ax_sig.set_ylabel('Amplitude')
    ax_sig.grid(True, alpha=0.3)

    # Layer Activities — each recorded layer is paired with its display name.
    for acts, label in zip(data['activities'], system.layer_names):
        ax_act.plot(acts, label=label, alpha=0.7)
    ax_act.set_title('Mean Layer Activity', fontsize=9, fontweight='bold')
    ax_act.legend(fontsize=7)
    ax_act.grid(True, alpha=0.3)

    # Layer Variances
    for variances, label in zip(data['variances'], system.layer_names):
        ax_var.plot(variances, label=label, alpha=0.7)
    ax_var.set_title('Activity Variance', fontsize=9, fontweight='bold')
    ax_var.legend(fontsize=7)
    ax_var.grid(True, alpha=0.3)

    # Resonance, with 1/φ (INV_PHI) drawn as a dashed red reference line.
    ax_res.plot(data['resonances'], color='gold', linewidth=2)
    ax_res.axhline(INV_PHI, linestyle='--', color='red', alpha=0.5, linewidth=1)
    ax_res.set_title(f'Resonance (μ={data["resonances"].mean():.3f})',
                     fontsize=9, fontweight='bold')
    ax_res.set_ylim([0, 1])
    ax_res.grid(True, alpha=0.3)

fig.suptitle('🔬 Layer Dynamics Analysis Across Input Types',
             fontsize=14, fontweight='bold')
fig.tight_layout()
fig.savefig('layer_dynamics_analysis.png', dpi=150, bbox_inches='tight')
# Release the figure once it is written to disk so it is not kept open.
plt.close(fig)
print("\n✅ Saved: layer_dynamics_analysis.png")

# Summary statistics: for each input type, the mean/std of resonance and each
# layer's time-averaged activity and variance.
print("\n" + "=" * 60)
print("  SUMMARY STATISTICS")
print("=" * 60)

for case_name, case_data in results.items():
    res = case_data['resonances']
    print(f"\n{case_name.upper()}:")
    print(f"  Resonance: μ={res.mean():.4f}, σ={res.std():.4f}")

    for layer_idx in range(3):
        label = system.layer_names[layer_idx]
        mean_activity = case_data['activities'][layer_idx].mean()
        mean_variance = case_data['variances'][layer_idx].mean()
        print(f"  {label:8s}: activity={mean_activity:.4f}, variance={mean_variance:.4f}")

# Closing summary of the layer configuration and how to read the resonance.
print("\n" + "=" * 60)
print("  KEY INSIGHTS")
print("=" * 60)
# NOTE(review): the neuron counts and leak rates below are hard-coded text,
# not read from `system`; confirm they still match Tier1FastSystem's config.
print("• Trinity (8 neurons, leak=1.0): Fast, reactive")
print("• Nyx (13 neurons, leak=0.8): Medium integration")
print("• Ava (21 neurons, leak=0.6): Slow, contextual")
print(f"• All tuned to spectral radius = {INV_PHI:.6f} (1/φ)")
print("• Resonance indicates cross-timescale coherence")
print("=" * 60)
