"""Visualize golden spiral dynamics and resonance patterns"""
import torch
import numpy as np
import matplotlib.pyplot as plt
from tier1_system import Tier1FastSystem
from phi_constants import INV_PHI, PHI

print("🌀 Generating Φ-System Visualizations...")

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
system = Tier1FastSystem(device=device)

# Probe signal: two superimposed sinusoids sampled 200 times over [0, 10π],
# then z-scored (the epsilon guards against a zero standard deviation).
phase = np.linspace(0, 10 * np.pi, 200)
signal = 0.5 * np.sin(0.3 * phase) + 0.3 * np.cos(0.7 * phase)
signal = (signal - signal.mean()) / (signal.std() + 1e-9)
# Shape (200,) -> (1, 200, 1). It is later sliced one step at a time along dim 1,
# presumably as (batch, time, features) — TODO confirm against Tier1FastSystem.
test_seq = torch.from_numpy(signal.astype(np.float32)).view(1, -1, 1).to(device)

# Run through system: feed test_seq one time step at a time and record, for each
# of the 3 layers, the state and output of batch element 0, plus the scalar
# resonance reported at every step.
system.reset_all_states(batch_size=1)
all_states = [[] for _ in range(3)]
all_outputs = [[] for _ in range(3)]
resonances = []

# NOTE(review): this loop is not wrapped in torch.no_grad(); the .detach() calls only
# drop the autograd graph after it has been built. If Tier1FastSystem does not need
# gradients at inference time, no_grad() would save memory — confirm before adding.
# `step` (not `t`) so the NumPy time axis built for the probe signal is not shadowed
# by an int loop counter.
for step in range(test_seq.shape[1]):
    result = system(test_seq[:, step, :], return_resonance=True)

    for i in range(3):
        # Keep batch element 0 only and move it to host memory as NumPy.
        all_states[i].append(result['states'][i][0].detach().cpu().numpy())
        all_outputs[i].append(result['outputs'][i][0].detach().cpu().numpy())

    resonances.append(result['resonance'].item())

# Convert the per-step lists to arrays with time as the leading axis.
all_states = [np.array(s) for s in all_states]
all_outputs = [np.array(o) for o in all_outputs]
resonances = np.array(resonances)

# Create visualizations: a 2x3 grid with per-layer phase portraits on the top row
# and resonance diagnostics (time series, layer outputs, histogram) on the bottom.
fig = plt.figure(figsize=(16, 10))

# 1. Phase Space (Golden Spirals)
for i in range(3):
    ax = fig.add_subplot(2, 3, i+1)
    # View the recorded states as (time, features). Without this, a layer that records
    # a scalar per step raises IndexError on `.shape[1]`. 2-D states are unchanged.
    states = all_states[i].reshape(len(all_states[i]), -1)
    n_features = states.shape[1]
    if n_features == 0:
        continue  # empty state vector: nothing to draw

    colors = np.arange(len(states))  # time index, used to color the points
    if n_features >= 2:
        # Project the trajectory onto the first two state dimensions.
        x, y = states[:, 0], states[:, 1]
        panel, xlabel, ylabel = 'Phase Space', 'Dimension 1', 'Dimension 2'
    else:
        # A single dimension has no phase plane; plot it against time rather
        # than leaving the panel blank and untitled.
        x, y = colors, states[:, 0]
        panel, xlabel, ylabel = 'State Trajectory', 'Time Step', 'Dimension 1'

    scatter = ax.scatter(x, y, c=colors, cmap='twilight', s=20, alpha=0.7)
    ax.plot(x, y, alpha=0.3, linewidth=0.5, color='gold')

    ax.set_title(f'Layer {i}: {system.layer_names[i]} {panel}',
                 fontweight='bold', fontsize=10)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    plt.colorbar(scatter, ax=ax, label='Time')

# 2. Resonance Over Time, with the 1/φ target and two fixed reference thresholds.
ax = fig.add_subplot(2, 3, 4)
ax.plot(resonances, color='gold', linewidth=2, label='Resonance')
ax.axhline(INV_PHI, linestyle='--', color='red', alpha=0.7, label=f'1/φ = {INV_PHI:.3f}')
ax.axhline(0.85, linestyle='--', color='green', alpha=0.5, label='High threshold')
ax.axhline(0.70, linestyle='--', color='orange', alpha=0.5, label='Medium threshold')
ax.set_title('Φ-Resonance Time Series', fontweight='bold', fontsize=10)
ax.set_xlabel('Time Step')
ax.set_ylabel('Resonance')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# 3. Output Predictions: first output component of each layer over time.
ax = fig.add_subplot(2, 3, 5)
for i in range(3):
    ax.plot(all_outputs[i][:, 0], alpha=0.6, label=system.layer_names[i])
ax.set_title('Layer Predictions', fontweight='bold', fontsize=10)
ax.set_xlabel('Time Step')
ax.set_ylabel('Output')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# 4. Resonance Distribution, marking the sample mean and the 1/φ target.
mean_resonance = resonances.mean()
ax = fig.add_subplot(2, 3, 6)
ax.hist(resonances, bins=30, color='purple', alpha=0.7, edgecolor='black')
ax.axvline(mean_resonance, color='red', linestyle='--', linewidth=2,
           label=f'Mean = {mean_resonance:.3f}')
ax.axvline(INV_PHI, color='gold', linestyle='--', linewidth=2,
           label=f'1/φ = {INV_PHI:.3f}')
ax.set_title('Resonance Distribution', fontweight='bold', fontsize=10)
ax.set_xlabel('Resonance Value')
ax.set_ylabel('Frequency')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# NOTE(review): matplotlib's default font (DejaVu Sans) may lack a glyph for 🌀, which
# produces a "Glyph missing" warning and a placeholder box unless an emoji-capable
# font is configured.
fig.suptitle('🌀 Φ-Fractal Consciousness Dynamics 🌀',
             fontsize=14, fontweight='bold')
fig.tight_layout()
fig.savefig('phi_dynamics_visualization.png', dpi=150, bbox_inches='tight')
plt.close(fig)  # release the figure; nothing below draws on it again
print("✅ Saved: phi_dynamics_visualization.png")

# Second figure: autocorrelation of each layer's feature-averaged activity,
# shown as "Temporal Memory Patterns Across Layers".
fig2, axes = plt.subplots(1, 3, figsize=(15, 4))

for i, ax in enumerate(axes):
    # View the recorded states as (time, features) so that a scalar-per-step layer
    # does not break `.mean(axis=1)`.
    states = all_states[i].reshape(len(all_states[i]), -1)

    # Autocorrelation of the centered mean activity, kept for lags >= 0 only
    # (index len//2 of the symmetric 'full' output is lag 0).
    mean_activity = states.mean(axis=1)
    centered = mean_activity - mean_activity.mean()
    autocorr = np.correlate(centered, centered, mode='full')
    autocorr = autocorr[len(autocorr)//2:]

    # autocorr[0] is the sum of squared deviations. It is 0 (or NaN) for constant
    # activity, and normalizing by it would produce an all-NaN curve plus a
    # RuntimeWarning, so we annotate the panel instead.
    if autocorr[0] > 0:
        ax.plot(autocorr[:50] / autocorr[0], color='purple', linewidth=2)
    else:
        ax.text(0.5, 0.5, 'Constant activity\n(autocorrelation undefined)',
                ha='center', va='center', transform=ax.transAxes)

    ax.set_title(f'{system.layer_names[i]} Autocorrelation', fontweight='bold')
    ax.set_xlabel('Lag')
    ax.set_ylabel('Correlation')
    ax.grid(True, alpha=0.3)
    ax.axhline(0, color='black', linestyle='-', linewidth=0.5)

fig2.suptitle('Temporal Memory Patterns Across Layers', fontsize=14, fontweight='bold')
fig2.tight_layout()
fig2.savefig('phi_temporal_analysis.png', dpi=150, bbox_inches='tight')
plt.close(fig2)  # release the figure once it is saved
print("✅ Saved: phi_temporal_analysis.png")

# Console summary: resonance statistics next to the 1/φ target.
summary_lines = (
    "\n🌀 Visualization complete!",
    f"   Mean resonance: {resonances.mean():.4f}",
    f"   Resonance std: {resonances.std():.4f}",
    f"   Target (1/φ): {INV_PHI:.4f}",
)
for line in summary_lines:
    print(line)
