#!/usr/bin/env python3
"""
EDEN AGI - INTERACTIVE VISUALIZATION DASHBOARD
Real-time capability activation visualization
"""

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.patches import Rectangle
import time

# Prefer the GPU when one is present, but fall back to CPU so the dashboard
# also runs on machines without CUDA (hard-coding 'cuda' crashes there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load agent
class UnifiedEdenAgent(nn.Module):
    """Unified agent with five parallel capability heads.

    Pipeline: perception (input_dim -> core_dim) -> cognitive core ->
    five head_dim-wide linear heads (meta-learning, reasoning, common
    sense, theory of mind, goals) -> integration -> num_outputs logits.

    The layer sizes are parameterized but default to the original fixed
    architecture (100 -> 512 -> 256 -> 5x64 -> 128 -> 10), so existing
    checkpoints still load via the no-argument constructor.
    """

    def __init__(self, input_dim=100, hidden_dim=512, core_dim=256,
                 head_dim=64, num_outputs=10, dropout=0.3):
        super().__init__()

        self.perception = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, core_dim),
            nn.ReLU()
        )

        self.cognitive_core = nn.Sequential(
            nn.Linear(core_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, core_dim)
        )

        # Five parallel capability heads reading the shared core state.
        self.meta_learning = nn.Linear(core_dim, head_dim)
        self.reasoning = nn.Linear(core_dim, head_dim)
        self.common_sense = nn.Linear(core_dim, head_dim)
        self.theory_of_mind = nn.Linear(core_dim, head_dim)
        self.goals = nn.Linear(core_dim, head_dim)

        self.integration = nn.Sequential(
            nn.Linear(head_dim * 5, core_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(core_dim, 128),
            nn.ReLU()
        )

        self.output = nn.Linear(128, num_outputs)

    def forward(self, x):
        """Run the agent on a batch of inputs.

        Args:
            x: tensor of shape (batch, input_dim).

        Returns:
            (logits, activations): logits of shape (batch, num_outputs),
            and a dict mapping capability name to the mean absolute
            activation of its head (Python floats, for visualization).
        """
        perceived = self.perception(x)
        cognitive = self.cognitive_core(perceived)

        meta = self.meta_learning(cognitive)
        reason = self.reasoning(cognitive)
        cs = self.common_sense(cognitive)
        tom = self.theory_of_mind(cognitive)
        goal = self.goals(cognitive)

        integrated = self.integration(torch.cat([meta, reason, cs, tom, goal], dim=1))
        output = self.output(integrated)

        # Scalar summaries (mean |activation| per head) for the dashboard.
        return output, {
            'meta_learning': meta.abs().mean().item(),
            'reasoning': reason.abs().mean().item(),
            'common_sense': cs.abs().mean().item(),
            'theory_of_mind': tom.abs().mean().item(),
            'goal_emergence': goal.abs().mean().item()
        }

print("="*70)
print("EDEN AGI - VISUALIZATION DASHBOARD")
print("="*70)

agent = UnifiedEdenAgent().to(device)
# map_location lets a checkpoint saved on one device (e.g. CUDA) load on
# another (e.g. a CPU-only host); without it torch.load raises.
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects —
# only ever load checkpoints from a trusted source.
checkpoint = torch.load('capabilities/unified_eden_working.pth',
                        map_location=device, weights_only=False)
agent.load_state_dict(checkpoint['model_state'])
agent.eval()  # disable dropout for deterministic inference

print("✅ Agent loaded\n")

# Create visualization
# 2x2 dashboard: bar chart, history line plot, pie chart, text metrics.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Eden AGI - Real-Time Capability Activation', fontsize=16, fontweight='bold')

# Capability names and colors
# Display order must match the activation-dict key order iterated in
# update_visualization (meta_learning, reasoning, common_sense, ...).
capabilities = ['Meta-Learning', 'Reasoning', 'Common Sense', 'Theory of Mind', 'Goal Emergence']
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#98D8C8']

# Initialize plots
ax_activation = axes[0, 0]    # current-activation bar chart
ax_history = axes[0, 1]       # rolling activation line plot
ax_distribution = axes[1, 0]  # task-type pie chart
ax_performance = axes[1, 1]   # text metrics panel

# Data storage
# Rolling windows (most recent max_history entries) keyed by display name.
activation_history = {cap: [] for cap in capabilities}
task_history = []  # task-type ids (0-4) in arrival order
max_history = 50   # window length shared by all history plots

def create_test_input(task_type, target_device=None):
    """Build a (1, 100) probe input for the given task type.

    Each task type "owns" a contiguous 20-feature slice of the input:
    the first 10 features of the slice are set to 1 (a task marker) and
    the next 10 are filled with Gaussian noise.

    Args:
        task_type: 0=meta-learning, 1=reasoning, 2=common sense,
            3=theory of mind; any other value selects goal emergence
            (matching the original if/elif/else chain).
        target_device: device for the returned tensor; defaults to the
            module-level ``device`` (backward-compatible).

    Returns:
        torch.Tensor of shape (1, 100) on ``target_device``.
    """
    if target_device is None:
        target_device = device  # module-level global, as before

    # Tasks 0-3 map to slices [0:20], [20:40], [40:60], [60:80];
    # everything else (the original `else`) maps to [80:100].
    base = 20 * (task_type if task_type in (0, 1, 2, 3) else 4)

    x = torch.zeros(1, 100).to(target_device)
    x[0, base:base + 10] = 1
    x[0, base + 10:base + 20] = torch.randn(10).to(target_device)
    return x

def update_visualization(frame):
    """Animation callback: run one random task and redraw all four panels.

    Args:
        frame: frame index supplied by FuncAnimation (unused).

    Side effects: appends to the module-level activation_history /
    task_history rolling windows and redraws the global axes.
    """
    # Pick a random task type (0-4) and build its probe input.
    task_type = np.random.randint(0, 5)
    x = create_test_input(task_type)

    # Inference only — no gradients needed for visualization.
    with torch.no_grad():
        pred, activations = agent(x)

    # Append the latest activations, trimming each series to max_history.
    # This key order must match the display order in `capabilities`.
    for i, cap in enumerate(['meta_learning', 'reasoning', 'common_sense', 
                             'theory_of_mind', 'goal_emergence']):
        activation_history[capabilities[i]].append(activations[cap])
        if len(activation_history[capabilities[i]]) > max_history:
            activation_history[capabilities[i]].pop(0)
    
    task_history.append(task_type)
    if len(task_history) > max_history:
        task_history.pop(0)
    
    # Clear all axes (titles/labels are re-applied below every frame).
    for ax in axes.flat:
        ax.clear()
    
    # 1. Current Activation (Bar Chart)
    ax_activation.set_title('Current Capability Activation', fontweight='bold')
    current_vals = [activation_history[cap][-1] if activation_history[cap] else 0 
                    for cap in capabilities]
    bars = ax_activation.bar(range(5), current_vals, color=colors, alpha=0.7)
    ax_activation.set_xticks(range(5))
    ax_activation.set_xticklabels(capabilities, rotation=45, ha='right')
    ax_activation.set_ylabel('Activation Level')
    # BUGFIX: the old guard `... if current_vals else 1` never fired because
    # current_vals always has 5 entries; with all-zero activations the
    # y-limits collapsed to (0, 0). `or 1` falls back when the max is 0.
    ax_activation.set_ylim(0, max(current_vals) * 1.2 or 1)
    ax_activation.grid(axis='y', alpha=0.3)
    
    # Add value labels on bars
    for bar, val in zip(bars, current_vals):
        height = bar.get_height()
        ax_activation.text(bar.get_x() + bar.get_width()/2., height,
                          f'{val:.3f}', ha='center', va='bottom', fontsize=9)
    
    # 2. Activation History (Line Plot)
    ax_history.set_title('Activation History (Last 50 Steps)', fontweight='bold')
    for i, cap in enumerate(capabilities):
        if activation_history[cap]:
            ax_history.plot(activation_history[cap], label=cap, 
                          color=colors[i], linewidth=2, alpha=0.8)
    ax_history.set_xlabel('Time Steps')
    ax_history.set_ylabel('Activation Level')
    ax_history.legend(loc='upper right', fontsize=8)
    ax_history.grid(alpha=0.3)
    
    # 3. Task Distribution (Pie Chart)
    ax_distribution.set_title('Task Distribution', fontweight='bold')
    if task_history:
        task_counts = [task_history.count(i) for i in range(5)]
        wedges, texts, autotexts = ax_distribution.pie(
            task_counts, labels=capabilities, colors=colors,
            autopct='%1.1f%%', startangle=90
        )
        for autotext in autotexts:
            autotext.set_color('white')
            autotext.set_fontweight('bold')
    
    # 4. Performance Metrics (text-only panel)
    ax_performance.set_title('Real-Time Metrics', fontweight='bold')
    ax_performance.axis('off')
    
    # Window averages per capability, plus an overall mean of means.
    avg_activations = [np.mean(activation_history[cap]) if activation_history[cap] else 0 
                      for cap in capabilities]
    overall_avg = np.mean(avg_activations) if avg_activations else 0
    
    metrics_text = f"""
    Overall Activation: {overall_avg:.3f}
    
    Average by Capability:
    • Meta-Learning:    {avg_activations[0]:.3f}
    • Reasoning:        {avg_activations[1]:.3f}
    • Common Sense:     {avg_activations[2]:.3f}
    • Theory of Mind:   {avg_activations[3]:.3f}
    • Goal Emergence:   {avg_activations[4]:.3f}
    
    Tasks Processed: {len(task_history)}
    Status: ✅ OPERATIONAL
    """
    
    ax_performance.text(0.1, 0.5, metrics_text, fontsize=11, 
                       verticalalignment='center', family='monospace',
                       bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

print("Starting real-time visualization...")
print("Close the window to stop.\n")

# Animate
ani = animation.FuncAnimation(fig, update_visualization, interval=500, cache_frame_data=False)

plt.tight_layout()
plt.show()

print("\n✅ Visualization complete!")
