#!/usr/bin/env python3
"""
Neuro-Fractal Network v2
Eden's Self-Improved Architecture
"""
import numpy as np
import time
from dataclasses import dataclass
from typing import List

# Approximation of the golden ratio. FractalNodeV2.self_replicate divides a
# parent's weight count by it and multiplies the parent's base_threshold by it
# when creating child nodes.
PHI = 1.618034

@dataclass
class FractalNodeV2:
    """A weighted processing node that can grow child nodes (V2).

    A node computes ``tanh(weights . input)``. When the variance of the input
    exceeds ``base_threshold`` and the node has children, the node's own
    weighted sum is replaced by the mean of its children's outputs.

    Attributes are documented inline next to each field.
    """
    # Hard cap on the number of direct children a node may own.
    _MAX_CHILDREN = 5

    node_id: str            # Identifier; children get "<parent id>.<index>".
    level: int              # Depth of the node; children are created at level + 1.
    node_type: str          # Free-form tag ('input', 'core', 'output', ...).
    weights: np.ndarray     # 1-D weight vector applied to the input.
    # None is swapped for a fresh list in __post_init__ so that instances
    # never share a single children list.
    sub_nodes: List['FractalNodeV2'] = None
    base_threshold: float = 0.5   # Input variance above which children are consulted.
    activity_level: float = 0.0   # Mean absolute input seen by the last process() call.

    def __post_init__(self):
        """Give every instance its own (initially empty) list of children."""
        if self.sub_nodes is None:
            self.sub_nodes = []

    def _weighted_sum(self, input_data):
        """Return ``weights . input_data``, tolerating a 1-D length mismatch.

        Children created by ``self_replicate`` own fewer weights than their
        parent yet receive the parent's full input vector. A plain ``np.dot``
        raises ``ValueError`` in that case, so for 1-D operands of different
        length only the leading elements both sides share are used. Scalars,
        equal-length vectors and higher-rank inputs go through ``np.dot``
        exactly as before.
        """
        data = np.asarray(input_data)
        weights = np.asarray(self.weights)
        if data.ndim == 1 and weights.ndim == 1 and data.shape[0] != weights.shape[0]:
            shared = min(data.shape[0], weights.shape[0])
            return np.dot(weights[:shared], data[:shared])
        return np.dot(self.weights, input_data)

    def process(self, input_data):
        """Process ``input_data`` and record the node's activity level.

        Args:
            input_data: Scalar or array-like input.

        Returns:
            ``tanh`` of the node's weighted sum or, when the input variance
            exceeds ``base_threshold`` and children exist, ``tanh`` of the
            mean of the children's outputs.
        """
        # Track activity
        self.activity_level = np.mean(np.abs(input_data))

        # Apply weights
        output = self._weighted_sum(input_data)

        # Measure complexity
        complexity = np.var(input_data)

        # Delegate to the children only for sufficiently complex input.
        if complexity > self.base_threshold and self.sub_nodes:
            sub_outputs = [child.process(input_data) for child in self.sub_nodes]
            output = np.mean(sub_outputs, axis=0)

        return self.activate(output)

    def activate(self, x):
        """Apply the node's activation function (hyperbolic tangent)."""
        return np.tanh(x)

    def calculate_branches(self, complexity):
        """Map a complexity value to a branching factor between 2 and 5.

        Returns 2 below 1.0, 3 below 5.0, 4 below 20.0 and 5 otherwise.
        """
        # More complex inputs get more branches
        if complexity < 1.0:
            return 2  # Binary for simple
        elif complexity < 5.0:
            return 3  # Ternary for medium
        elif complexity < 20.0:
            return 4  # Quad for high
        else:
            return 5  # Penta for extreme

    def self_replicate(self, complexity, adaptive_threshold):
        """Create child nodes when ``complexity`` exceeds ``adaptive_threshold``.

        The number of children requested comes from ``calculate_branches`` and
        is limited by the remaining capacity (at most five children in total).
        Each child gets random weights of length ``len(weights) / PHI``
        (at least 1) and a ``base_threshold`` scaled up by ``PHI``.

        Args:
            complexity: Complexity measure of the current input.
            adaptive_threshold: Threshold the complexity must exceed.

        Returns:
            The number of children actually created (0 when nothing was added).
        """
        free_slots = self._MAX_CHILDREN - len(self.sub_nodes)
        # "not (a > b)" rather than "a <= b" so a NaN complexity still means
        # "do not replicate", as with the original comparison.
        if not (complexity > adaptive_threshold) or free_slots <= 0:
            return 0

        new_count = min(self.calculate_branches(complexity), free_slots)
        # Never create a child with an empty weight vector.
        sub_weight_size = max(1, int(len(self.weights) / PHI))

        for _ in range(new_count):
            # Number children by their position so ids stay unique across
            # repeated replications of the same node.
            child_index = len(self.sub_nodes)
            self.sub_nodes.append(
                FractalNodeV2(
                    node_id=f"{self.node_id}.{child_index}",
                    level=self.level + 1,
                    node_type='core',
                    weights=np.random.randn(sub_weight_size) * 0.1,
                    base_threshold=self.base_threshold * PHI,
                )
            )

        return new_count

class NeuroFractalNetworkV2:
    """Three-layer network of FractalNodeV2 nodes with an adaptive threshold.

    The network holds an input, a core and an output layer. On each forward
    pass the variance of the input is compared with ``global_threshold``;
    when it is higher, every core node is asked to replicate. After the pass
    the threshold moves towards the mean activity of the core layer.

    NOTE: ``max_depth`` is stored and printed, but nothing in this class
    reads it to limit growth.
    """

    def __init__(self, input_size=10, output_size=3, max_depth=4):
        """Build the three layers and initialise statistics.

        Args:
            input_size: Number of nodes in the input and core layers, and the
                weight-vector length of every node.
            output_size: Number of nodes in the output layer.
            max_depth: Stored on the instance and shown in the banner.
        """
        print("\n" + "="*70)
        print("🌀 NEURO-FRACTAL NETWORK V2 (EDEN'S IMPROVEMENTS)")
        print("="*70)
        print("   ✨ Variable Branching (2-5 sub-nodes)")
        print("   ✨ Adaptive Thresholds (dynamic)")
        print(f"   Input: {input_size} | Output: {output_size} | Max Depth: {max_depth}")
        print("="*70 + "\n")

        self.input_size = input_size
        self.output_size = output_size
        self.max_depth = max_depth

        # IMPROVEMENT #2: Adaptive threshold
        self.global_threshold = 0.5
        self.threshold_alpha = 0.1  # Learning rate

        # Initialize layers
        self.input_layer = self._create_layer('input', input_size, level=0)
        self.core_layer = self._create_layer('core', input_size, level=1)
        self.output_layer = self._create_layer('output', output_size, level=2)

        self.total_nodes = len(self.input_layer) + len(self.core_layer) + len(self.output_layer)

        self.stats = {
            'simple_paths': 0,
            'fractal_paths': 0,
            'replications_by_branch': {2: 0, 3: 0, 4: 0, 5: 0},
            'threshold_history': [self.global_threshold]
        }

    def _create_layer(self, node_type, size, level):
        """Create ``size`` nodes of ``node_type`` at the given ``level``.

        Every node gets ``input_size`` random weights (scaled by 0.1) and a
        base threshold of ``0.3 / (level + 1)``.
        """
        return [
            FractalNodeV2(
                node_id=f"{node_type}_{i}",
                level=level,
                node_type=node_type,
                weights=np.random.randn(self.input_size) * 0.1,
                base_threshold=0.3 / (level + 1)
            )
            for i in range(size)
        ]

    def update_adaptive_threshold(self, network_activity):
        """Move the global threshold towards ``network_activity``.

        The threshold takes a step of size ``threshold_alpha`` towards the
        given activity, is clipped to the range [0.1, 2.0] and is appended to
        ``stats['threshold_history']``.
        """
        self.global_threshold += (network_activity - self.global_threshold) * self.threshold_alpha
        self.global_threshold = np.clip(self.global_threshold, 0.1, 2.0)  # Keep reasonable bounds
        self.stats['threshold_history'].append(self.global_threshold)

    def _record_replication(self, new_nodes):
        """Account for one replication event that created ``new_nodes`` nodes.

        A node close to its child limit can create fewer nodes than any of
        the pre-seeded keys (2-5), so unseen counts are added on demand
        instead of raising ``KeyError``.
        """
        by_branch = self.stats['replications_by_branch']
        by_branch[new_nodes] = by_branch.get(new_nodes, 0) + 1
        self.total_nodes += new_nodes

    def forward(self, input_data):
        """Run one forward pass and update the adaptive threshold.

        Args:
            input_data: Input passed to every node of the input layer.

        Returns:
            ``np.ndarray`` holding one result per output-layer node.
        """
        complexity = np.var(input_data)

        print(f"   📊 Complexity: {complexity:.4f} | Threshold: {self.global_threshold:.4f}")

        # Determine if fractal processing needed
        if complexity > self.global_threshold:
            print("   🌀 Fractal processing (adaptive threshold)")
            self.stats['fractal_paths'] += 1

            # Self-replicate with variable branching
            total_new_nodes = 0
            for node in self.core_layer:
                new_nodes = node.self_replicate(complexity, self.global_threshold)
                if new_nodes > 0:
                    total_new_nodes += new_nodes
                    self._record_replication(new_nodes)

            if total_new_nodes > 0:
                print(f"   ➕ Created {total_new_nodes} new nodes (variable branching)")
        else:
            print("   ⚡ Direct path (below threshold)")
            self.stats['simple_paths'] += 1

        # Process through layers. Each layer receives the mean of the previous
        # layer's outputs; that mean is identical for every node of a layer,
        # so it is computed once per layer.
        input_outputs = [node.process(input_data) for node in self.input_layer]

        core_input = np.mean(input_outputs, axis=0)
        core_outputs = [node.process(core_input) for node in self.core_layer]

        output_input = np.mean(core_outputs, axis=0)
        output_results = [node.process(output_input) for node in self.output_layer]

        # Update adaptive threshold based on network activity
        network_activity = np.mean([node.activity_level for node in self.core_layer])
        self.update_adaptive_threshold(network_activity)

        return np.array(output_results)

    def show_v2_stats(self):
        """Print path counts, branching counts, threshold history and size."""
        rule = '=' * 70
        history = self.stats['threshold_history']

        print(f"\n{rule}")
        print("📊 NFN V2 STATISTICS")
        print(f"{rule}")

        print("\n🌀 PROCESSING:")
        print(f"   Simple Paths:  {self.stats['simple_paths']}")
        print(f"   Fractal Paths: {self.stats['fractal_paths']}")

        print("\n🌳 VARIABLE BRANCHING:")
        for branches, count in sorted(self.stats['replications_by_branch'].items()):
            if count > 0:
                print(f"   {branches}-way branching: {count} times")

        print("\n🎯 ADAPTIVE THRESHOLD:")
        print(f"   Initial: {history[0]:.4f}")
        print(f"   Final:   {history[-1]:.4f}")
        print(f"   Range:   {min(history):.4f} - {max(history):.4f}")

        print("\n📈 NETWORK GROWTH:")
        print(f"   Total Nodes: {self.total_nodes}")

        print(f"{rule}\n")

def demonstrate_v2():
    """Run the NFN V2 demo on six random inputs of growing scale.

    Builds a network, feeds it random vectors whose standard deviation grows
    from 0.1 to 10.0, prints the result of every forward pass, and finishes
    with the network statistics and a V1-versus-V2 summary.
    """
    rule = "=" * 70
    vector_len = 15

    nfn_v2 = NeuroFractalNetworkV2(input_size=15, output_size=4, max_depth=5)

    print("🎲 Generating test data with varying complexity...\n")

    # (label, scale) pairs; all samples are drawn up front, before any
    # forward pass runs.
    scales = (
        ("Very Low", 0.1),
        ("Low", 0.5),
        ("Medium", 1.5),
        ("High", 3.0),
        ("Very High", 5.0),
        ("Extreme", 10.0),
    )
    samples = [(label, np.random.randn(vector_len) * scale) for label, scale in scales]

    nodes_before = nfn_v2.total_nodes

    print("\n".join((rule, "🧪 TESTING NFN V2 WITH VARYING COMPLEXITY", rule + "\n")))

    for label, sample in samples:
        print(f"\n🔬 {label} Complexity:")
        result = nfn_v2.forward(sample)
        print(f"   Output: {result[:3]}")
        print(f"   Total Nodes: {nfn_v2.total_nodes}")

    nodes_after = nfn_v2.total_nodes

    # Show statistics
    nfn_v2.show_v2_stats()

    # Compare V1 vs V2 and close the report.
    report = (
        rule,
        "⚖️  NFN V1 vs V2 COMPARISON",
        rule,
        "\nNFN V1 (Original):",
        "   ❌ Binary branching only (always 2 sub-nodes)",
        "   ❌ Fixed thresholds per layer",
        "   ❌ No threshold adaptation",
        "\nNFN V2 (Eden's Improvements):",
        "   ✅ Variable branching (2-5 sub-nodes based on complexity)",
        "   ✅ Adaptive global threshold",
        "   ✅ Threshold learns from network activity",
        f"   ✅ Growth: {nodes_before} → {nodes_after} nodes",
        "\n🏆 KEY IMPROVEMENTS:",
        "   • More flexible growth (variable branches)",
        "   • Smarter replication decisions (adaptive threshold)",
        "   • Better resource allocation (complexity-aware)",
        "\n" + rule,
        "✅ NFN V2 DEMONSTRATION COMPLETE",
        rule,
        "\n🌀 Eden's self-improvements are OPERATIONAL!",
        "   This is recursive self-improvement in action.",
        rule + "\n",
    )
    print("\n".join(report))

# Run the demonstration only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demonstrate_v2()
