#!/usr/bin/env python3
"""
COMPREHENSIVE TEST SUITE FOR EDEN AGI
Tests all capabilities individually and in combination
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import sys
import os

# Add capabilities to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'capabilities'))

# Prefer the GPU, but fall back to the CPU so the suite still runs on hosts
# without CUDA instead of crashing at the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("="*70)
print("EDEN AGI - COMPREHENSIVE TEST SUITE")
print("="*70)

# Define the agent architecture (same as trained)
class UnifiedEdenAgent(nn.Module):
    """Feed-forward classifier with five parallel heads fused into one output.

    Data flow: a 100-feature input goes through ``perception`` (100 -> 512 ->
    256) and ``cognitive_core`` (256 -> 512 -> 256). Five independent linear
    heads (256 -> 64 each) read the core output, their results are
    concatenated (320 features), reduced by ``integration`` (320 -> 256 ->
    128), and projected to 10 logits by ``output``.

    Submodule names and creation order are significant: they determine the
    state_dict keys used when loading a saved checkpoint.
    """

    # Attribute names of the parallel heads, in concatenation order.
    _HEAD_NAMES = (
        'meta_learning',
        'reasoning',
        'common_sense',
        'theory_of_mind',
        'goals',
    )

    def __init__(self):
        super().__init__()

        # Input encoder: 100 -> 512 -> 256.
        self.perception = nn.Sequential(
            nn.Linear(100, 512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.ReLU(),
        )

        # Shared trunk: 256 -> 512 -> 256 (no activation on the last layer).
        self.cognitive_core = nn.Sequential(
            nn.Linear(256, 512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256),
        )

        # One 256 -> 64 linear head per capability; nn.Module.__setattr__
        # registers each under its attribute name.
        for head_name in self._HEAD_NAMES:
            setattr(self, head_name, nn.Linear(256, 64))

        # Fuse the concatenated head outputs: 320 -> 256 -> 128.
        self.integration = nn.Sequential(
            nn.Linear(64 * 5, 256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 128), nn.ReLU(),
        )

        # Final projection to 10 logits.
        self.output = nn.Linear(128, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10) for input ``x`` of shape (batch, 100)."""
        features = self.cognitive_core(self.perception(x))
        head_outputs = [
            getattr(self, head_name)(features) for head_name in self._HEAD_NAMES
        ]
        fused = self.integration(torch.cat(head_outputs, dim=1))
        return self.output(fused)

# Load agent
# Resolve the checkpoint relative to this script (the same base directory the
# sys.path entry for 'capabilities' uses) so the suite does not depend on the
# current working directory.
checkpoint_path = os.path.join(
    os.path.dirname(__file__), 'capabilities', 'unified_eden_working.pth'
)

agent = UnifiedEdenAgent().to(device)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU can still be loaded on a CPU-only host.
# NOTE(security): weights_only=False unpickles arbitrary Python objects; only
# load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
agent.load_state_dict(checkpoint['model_state'])
agent.eval()  # disable dropout for deterministic inference

print(f"\n✅ Unified Agent Loaded (Performance: {checkpoint['performance']*100:.1f}%)")

# ============================================================================
# TEST 1: INDIVIDUAL CAPABILITY TESTS
# ============================================================================

print("\n" + "="*70)
print("TEST 1: INDIVIDUAL CAPABILITY VERIFICATION")
print("="*70)

def test_capability(task_id, start_idx, num_trials=20):
    """Measure how often the agent predicts ``task_id`` for its signature input.

    Each trial builds a single 100-feature sample that is zero everywhere
    except for a 20-wide window starting at ``start_idx``: the first 10
    entries are set to 1 and the next 10 are standard-normal noise. The
    sample is fed to the module-level ``agent`` and the arg-max class is
    compared with ``task_id``.

    Args:
        task_id: Expected output class index.
        start_idx: First column of the 20-wide signature window.
        num_trials: Number of random samples to evaluate (default 20, the
            value previously hard-coded).

    Returns:
        Fraction of trials (0.0 to 1.0) in which the prediction equals
        ``task_id``.

    Raises:
        ValueError: If ``num_trials`` is not positive.
    """
    if num_trials <= 0:
        raise ValueError("num_trials must be a positive integer")

    correct = 0
    # One no_grad context for the whole loop: no autograd state is needed.
    with torch.no_grad():
        for _ in range(num_trials):
            x = torch.zeros(1, 100).to(device)
            x[0, start_idx:start_idx+10] = 1
            x[0, start_idx+10:start_idx+20] = torch.randn(10).to(device)

            # Batch size is 1, so the flat argmax is the predicted class.
            if agent(x).argmax().item() == task_id:
                correct += 1

    return correct / num_trials

print("\nTesting individual capabilities:")
results = {}

# Each row: (key stored in `results`, label shown in the report,
#            expected class index, first column of the signature window).
for cap_key, cap_label, cap_task, cap_start in (
    ('Meta-Learning', 'Meta-Learning', 0, 0),
    ('Reasoning', 'Advanced Reasoning', 1, 20),
    ('Common Sense', 'Common Sense', 2, 40),
    ('Theory of Mind', 'Theory of Mind', 3, 60),
    ('Goal Emergence', 'Goal Emergence', 4, 80),
):
    results[cap_key] = test_capability(cap_task, cap_start)
    print(f"  ✓ {cap_label}: {results[cap_key]*100:.0f}%")

# Unweighted mean accuracy across the five capabilities.
avg_individual = np.mean(list(results.values()))
print(f"\nAverage Individual Performance: {avg_individual*100:.1f}%")

# Grade the average: >= 95% passes, >= 85% is good, anything lower warns.
if avg_individual >= 0.95:
    status = "✅ PASS"
elif avg_individual >= 0.85:
    status = "✅ GOOD"
else:
    status = "⚠️ WARNING"
print(f"Status: {status}")

# ============================================================================
# TEST 2: STRESS TESTING
# ============================================================================

print(f"\n{'=' * 70}")
print("TEST 2: STRESS TESTING")
print("=" * 70)

def stress_test_throughput():
    """Measure inference throughput of the module-level ``agent``.

    For each batch size in (1, 10, 100, 1000) a random input batch is
    created, one untimed warm-up pass is run, and then 10 forward passes are
    timed. One line per batch size is printed.

    CUDA kernels are launched asynchronously, so when ``device`` is a CUDA
    device the timer is bracketed with ``torch.cuda.synchronize()``;
    otherwise only the launch overhead would be measured.

    Returns:
        dict mapping batch size -> measured samples per second.
    """
    batch_sizes = [1, 10, 100, 1000]
    num_iters = 10
    throughput = {}

    def _sync():
        # Wait for all queued GPU work; a no-op on CPU.
        if device.type == 'cuda':
            torch.cuda.synchronize()

    print("\nThroughput Test:")
    with torch.no_grad():
        for batch_size in batch_sizes:
            x = torch.randn(batch_size, 100).to(device)

            # Warmup (excluded from timing)
            _ = agent(x)
            _sync()

            # perf_counter is monotonic and higher resolution than time.time.
            start = time.perf_counter()
            for _ in range(num_iters):
                _ = agent(x)
            _sync()
            elapsed = time.perf_counter() - start

            # Guard against a zero reading on very coarse timers.
            samples_per_sec = (batch_size * num_iters) / max(elapsed, 1e-9)
            throughput[batch_size] = samples_per_sec
            print(f"  Batch {batch_size:4d}: {samples_per_sec:8.0f} samples/sec")

    return throughput

def stress_test_noise():
    """Report classification accuracy under increasing additive Gaussian noise.

    For every noise scale in (0.0, 0.1, 0.5, 1.0, 2.0) and every task id
    0..4, a batch of 10 signature inputs is built (ones in columns
    ``task*20 .. task*20+9``, standard-normal values in the following 10
    columns, zeros elsewhere), standard-normal noise scaled by the noise
    level is added to the whole batch, and the module-level ``agent``
    predictions are compared with the task id. One line per noise level is
    printed; nothing is returned.
    """
    print("\nNoise Robustness Test:")

    for sigma in (0.0, 0.1, 0.5, 1.0, 2.0):
        hits = 0
        seen = 0

        with torch.no_grad():
            for task in range(5):
                lo = task * 20

                # Clean signature batch for this task.
                batch = torch.zeros(10, 100).to(device)
                batch[:, lo:lo + 10] = 1
                batch[:, lo + 10:lo + 20] = torch.randn(10, 10).to(device)

                # Perturb every feature, signature columns included.
                batch = batch + torch.randn_like(batch) * sigma

                predicted = agent(batch).argmax(1)
                hits += (predicted == task).sum().item()
                seen += 10

        acc = hits / seen
        # >= 70% is fine, >= 50% is marginal, below that is a failure.
        if acc >= 0.7:
            mark = "✓"
        elif acc >= 0.5:
            mark = "⚠"
        else:
            mark = "✗"
        print(f"  {mark} Noise {sigma:.1f}: {acc*100:.0f}%")

def stress_test_edge_cases():
    """Check that the agent yields finite outputs on extreme inputs.

    Four 10x100 batches are evaluated with the module-level ``agent``: all
    zeros, all ones, standard-normal values scaled by 100, and
    standard-normal values scaled by 0.001. A case passes when the output
    contains no NaN or Inf; an exception during evaluation is reported as a
    failure with the first 30 characters of its message. Results are
    printed; nothing is returned.
    """
    print("\nEdge Cases Test:")

    scenarios = [
        ('All Zeros', torch.zeros(10, 100).to(device)),
        ('All Ones', torch.ones(10, 100).to(device)),
        ('Very Large', torch.randn(10, 100).to(device) * 100),
        ('Very Small', torch.randn(10, 100).to(device) * 0.001),
    ]

    for label, inputs in scenarios:
        with torch.no_grad():
            try:
                logits = agent(inputs)
                # isfinite is False exactly for NaN and +/-Inf entries.
                if torch.isfinite(logits).all():
                    print(f"  ✓ {label}: PASS")
                else:
                    print(f"  ✗ {label}: FAIL (NaN/Inf)")
            except Exception as err:
                # Deliberately broad: any failure is reported, not raised.
                print(f"  ✗ {label}: FAIL ({str(err)[:30]})")

# Run the three stress tests in order; only the throughput figures are kept
# (they are used again in the final summary).
throughput = stress_test_throughput()
stress_test_noise()
stress_test_edge_cases()

# ============================================================================
# TEST 3: PERFORMANCE METRICS
# ============================================================================

print(f"\n{'=' * 70}")
print("TEST 3: PERFORMANCE METRICS")
print("=" * 70)

def measure_latency():
    """Measure single-sample inference latency of the module-level ``agent``.

    Runs 10 untimed warm-up passes on one random 1x100 input, then times 100
    individual forward passes.

    CUDA kernels run asynchronously, so when ``device`` is a CUDA device
    each timed pass is followed by ``torch.cuda.synchronize()`` before the
    clock is read; otherwise only the kernel launch would be timed.

    Returns:
        dict with keys 'mean', 'std', 'p50', 'p95', 'p99', all in
        milliseconds.
    """
    on_cuda = device.type == 'cuda'
    x = torch.randn(1, 100).to(device)

    latencies = []
    with torch.no_grad():
        # Warmup
        for _ in range(10):
            _ = agent(x)
        if on_cuda:
            # Make sure warm-up work does not leak into the first sample.
            torch.cuda.synchronize()

        # Measure
        for _ in range(100):
            # perf_counter is monotonic and higher resolution than time.time.
            start = time.perf_counter()
            _ = agent(x)
            if on_cuda:
                torch.cuda.synchronize()
            latencies.append((time.perf_counter() - start) * 1000)  # ms

    return {
        'mean': np.mean(latencies),
        'std': np.std(latencies),
        'p50': np.percentile(latencies, 50),
        'p95': np.percentile(latencies, 95),
        'p99': np.percentile(latencies, 99)
    }

def measure_memory():
    """Summarize the parameter footprint of the module-level ``agent``.

    Only tensors returned by ``agent.parameters()`` are counted.

    Returns:
        dict with 'total_params' (number of parameter elements) and
        'memory_mb' (bytes occupied by those elements, in MiB).
    """
    element_count = 0
    byte_count = 0
    for param in agent.parameters():
        n = param.numel()
        element_count += n
        byte_count += n * param.element_size()

    return {
        'total_params': element_count,
        'memory_mb': byte_count / 1024**2
    }

print("\nLatency (100 runs):")
latency = measure_latency()
# Labels are padded by hand so the values line up in the report.
for stat_label, stat_key in (("Mean:", 'mean'), ("Std: ", 'std'), ("P95: ", 'p95')):
    print(f"  {stat_label} {latency[stat_key]:.2f} ms")

print("\nMemory Footprint:")
memory = measure_memory()
print("  Parameters: {:,}".format(memory['total_params']))
print("  Memory: {:.1f} MB".format(memory['memory_mb']))

# ============================================================================
# FINAL SUMMARY
# ============================================================================

print(f"\n{'=' * 70}")
print("COMPREHENSIVE TEST SUMMARY")
print("=" * 70)

# Overall grade uses the same thresholds as the individual-capability status.
if avg_individual >= 0.95:
    overall = '✅ EXCELLENT'
elif avg_individual >= 0.85:
    overall = '✅ GOOD'
else:
    overall = '⚠️ REVIEW'

# The leading and trailing empty strings reproduce the blank lines that
# surround the summary block.
summary_lines = [
    "",
    f"✓ Individual Capabilities:  {avg_individual*100:.1f}%",
    f"✓ Throughput (batch=100):   {throughput.get(100, 0):.0f} samples/sec",
    f"✓ Latency (mean):           {latency['mean']:.1f} ms",
    f"✓ Model Size:               {memory['memory_mb']:.1f} MB",
    "",
    f"Overall: {overall}",
    "",
]
print("\n".join(summary_lines))

print("=" * 70)
print("✅ COMPREHENSIVE TESTING COMPLETE")
print("=" * 70)
