"""
Meta-Cognition Loop
System that thinks about its own thinking and learns how to learn
"""

import json
import logging
import time
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import yaml

logger = logging.getLogger(__name__)

@dataclass
class Strategy:
    """Represents a problem-solving strategy"""
    name: str  # unique identifier, used for lookup and JSON persistence
    description: str  # human-readable summary of what the strategy does
    execute_fn: Optional[Callable] = None  # runtime callable; not persisted, so None after loading from disk
    domain: str = "general"  # problem domain this strategy targets ('general' acts as a fallback)
    expected_success_rate: float = 0.5  # EMA of observed success, updated after each execution
    avg_execution_time: float = 1.0  # EMA of observed execution time (units set by callers — presumably seconds, confirm)
    usage_count: int = 0  # total number of recorded executions
    success_count: int = 0  # how many of those executions succeeded
@dataclass
class Simulation:
    """Result of strategy simulation"""
    strategy_name: str  # name of the Strategy that was simulated
    predicted_success: float  # estimated success probability, clamped to [0.1, 0.95]
    predicted_cost: float  # normalized cost estimate, capped at 1.0
    predicted_outcome: str  # human-readable summary of the prediction
    confidence: float  # trust in the prediction; grows with the strategy's usage_count
    utility: float = 0.0  # filled in during ranking: risk-adjusted expected value minus cost

class MetaCognitionLoop:
    """Meta-learning system that selects and adapts strategies.

    Keeps a library of Strategy records, mentally simulates the top
    candidates for each problem, selects the highest-utility one, and
    learns from actual outcomes via exponential-moving-average updates.
    Learned statistics (not the callables) are persisted as JSON.
    """

    # Fallbacks for any key missing from the YAML 'meta_cognition' section.
    _DEFAULT_CONFIG = {
        'num_strategies_per_query': 5,
        'simulation_depth': 3,
        'confidence_threshold': 0.7,
        'policy_update_rate': 0.1,
        'min_usage_before_trust': 3
    }

    def __init__(self, config_path: str = "/Eden/CONFIG/phi_fractal_config.yaml"):
        """Load configuration and previously learned strategy policies.

        Args:
            config_path: YAML config file. Its 'meta_cognition' section is
                optional and may be partial; missing keys fall back to
                _DEFAULT_CONFIG. (Previously a present-but-partial section
                raised KeyError, and an empty file raised TypeError.)
        """
        with open(config_path) as f:
            # safe_load returns None for an empty file; treat as "no config".
            config = yaml.safe_load(f) or {}

        # Merge user settings over defaults so partial sections still work.
        self.config = {**self._DEFAULT_CONFIG, **config.get('meta_cognition', {})}
        self.num_strategies = self.config['num_strategies_per_query']
        self.confidence_threshold = self.config['confidence_threshold']
        self.update_rate = self.config['policy_update_rate']
        self.min_usage = self.config['min_usage_before_trust']

        # Strategy library. execute_fn is transient: callables are not
        # serialized, so they must be re-registered every run.
        self.strategies: List[Strategy] = []
        self.policy_path = Path("/Eden/MEMORY/policies/learned_strategies.json")
        self.policy_path.parent.mkdir(parents=True, exist_ok=True)
        self._load_policies()

        # Decision log consulted by get_metrics().
        self.simulation_history: List[Dict] = []

        # Rolling window of |predicted - actual| success errors; maxlen
        # replaces the manual pop(0) bookkeeping of the old list.
        self.prediction_errors: deque = deque(maxlen=100)

        logger.info(f"MetaCognitionLoop initialized with {len(self.strategies)} strategies")

    def _load_policies(self):
        """Load saved strategy policies from self.policy_path, if present.

        Corrupt or unreadable files are logged and ignored so that startup
        never fails on bad persisted state.
        """
        if not self.policy_path.exists():
            logger.info("No existing policies, starting fresh")
            return

        try:
            with open(self.policy_path) as f:
                policies = json.load(f)

            for policy in policies:
                self.strategies.append(Strategy(
                    name=policy['name'],
                    description=policy['description'],
                    execute_fn=None,  # callables must be re-registered at runtime
                    domain=policy.get('domain', 'general'),
                    expected_success_rate=policy.get('expected_success_rate', 0.5),
                    avg_execution_time=policy.get('avg_execution_time', 1.0),
                    usage_count=policy.get('usage_count', 0),
                    success_count=policy.get('success_count', 0)
                ))

            logger.info(f"Loaded {len(self.strategies)} strategies")
        except Exception as e:
            logger.error(f"Failed to load policies: {e}")

    def register_strategy(self,
                         name: str,
                         execute_fn: Callable,
                         description: str = "",
                         domain: str = "general"):
        """Register (or re-attach) a strategy's executable function.

        If a strategy with this name already exists (e.g. loaded from disk
        without its callable), only its execute_fn is replaced so learned
        statistics are preserved.
        """
        existing = next((s for s in self.strategies if s.name == name), None)

        if existing:
            existing.execute_fn = execute_fn
            logger.info(f"Strategy updated: {name}")
        else:
            self.strategies.append(Strategy(
                name=name,
                description=description,
                execute_fn=execute_fn,
                domain=domain
            ))
            logger.info(f"Strategy registered: {name} ({domain})")

    def select_strategy(self,
                       problem: Dict,
                       context: Optional[Dict] = None) -> Dict:
        """
        Select best strategy using meta-reasoning

        Args:
            problem: Problem dict with keys:
                - description: str
                - domain: str (optional)
                - complexity: float 0-1 (optional)
            context: Additional context dict (recognized keys:
                'time_limit', 'value', 'risk_tolerance')

        Returns:
            Dict with keys 'strategy' (Strategy or None), 'simulation'
            (best Simulation), 'alternatives' (remaining ranked
            Simulations), 'confidence' and 'reasoning'.
        """
        if context is None:
            context = {}

        # Generate candidates (deduplicated, best-scored first).
        candidates = self._generate_candidates(problem, context)

        if not candidates:
            logger.warning("No strategies available")
            return {
                'strategy': None,
                'simulation': None,
                'alternatives': [],
                'confidence': 0.0,
                'reasoning': "No strategies registered"
            }

        # Mentally simulate only the top-N candidates to bound cost.
        simulations = [
            self._simulate_strategy(candidate, problem, context)
            for candidate in candidates[:self.num_strategies]
        ]

        # Rank by utility and pick the winner.
        ranked = self._rank_by_utility(simulations, context)
        best = ranked[0]
        selected_strategy = next(s for s in self.strategies if s.name == best.strategy_name)

        # Record the decision for later metrics/introspection.
        self.simulation_history.append({
            'timestamp': datetime.now().isoformat(),
            'problem_description': problem.get('description', '')[:200],
            'problem_domain': problem.get('domain', 'general'),
            'candidates': [s.strategy_name for s in simulations],
            'selected': best.strategy_name,
            'confidence': best.confidence,
            'predicted_success': best.predicted_success
        })

        logger.info(f"Selected: {best.strategy_name} (confidence: {best.confidence:.2f})")

        return {
            'strategy': selected_strategy,
            'simulation': best,
            'alternatives': ranked[1:],
            'confidence': best.confidence,
            'reasoning': f"Selected '{best.strategy_name}' with {best.predicted_success:.1%} predicted success"
        }

    def _generate_candidates(self, problem: Dict, context: Dict) -> List[Strategy]:
        """Return executable strategies for the problem, best-scored first.

        Domain-specific strategies are listed before general fallbacks; the
        combined list is then sorted by historical success rate plus a
        UCB-style exploration bonus.
        """
        if not self.strategies:
            return []

        problem_domain = problem.get('domain', 'general')

        # Domain-specific strategies
        domain_strategies = [s for s in self.strategies
                            if s.domain == problem_domain and s.execute_fn is not None]

        # General strategies
        general_strategies = [s for s in self.strategies
                            if s.domain == 'general' and s.execute_fn is not None]

        # Deduplicate while preserving domain-first order. Bug fix: when the
        # problem's domain is itself 'general', both filters match the same
        # strategies and each one used to appear twice, wasting the
        # num_strategies simulation budget.
        seen_ids = set()
        all_candidates = []
        for s in domain_strategies + general_strategies:
            if id(s) not in seen_ids:
                seen_ids.add(id(s))
                all_candidates.append(s)

        def score_strategy(s: Strategy) -> float:
            # Empirical success rate; max() guards the never-used case.
            base_score = s.success_count / max(s.usage_count, 1)

            if s.usage_count < self.min_usage:
                # Flat bonus for strategies without enough evidence yet.
                exploration_bonus = 0.3
            else:
                # UCB-style bonus; max() guards division by zero when
                # min_usage_before_trust is configured as 0.
                total_uses = sum(st.usage_count for st in self.strategies) + 1
                exploration_bonus = 0.1 * np.sqrt(
                    np.log(total_uses) / max(s.usage_count, 1))

            return base_score + exploration_bonus

        all_candidates.sort(key=score_strategy, reverse=True)
        return all_candidates

    def _simulate_strategy(self,
                          strategy: Strategy,
                          problem: Dict,
                          context: Dict) -> Simulation:
        """Mentally simulate a strategy (doesn't actually execute it).

        Combines the strategy's historical success rate with penalties for
        problem complexity, time-limit infeasibility and novelty, clamping
        the final prediction to [0.1, 0.95].
        """
        base_success = strategy.expected_success_rate

        # Harder problems reduce predicted success (up to -30%).
        complexity = problem.get('complexity', 0.5)
        complexity_penalty = complexity * 0.3

        # Strategies slower than the time limit are heavily discounted.
        time_limit = context.get('time_limit', float('inf'))
        time_feasible = 1.0 if strategy.avg_execution_time < time_limit else 0.3

        # Penalize strategies we have little evidence about.
        novelty_penalty = 0.2 if strategy.usage_count < self.min_usage else 0.0

        # Combined prediction, clamped to [0.1, 0.95].
        predicted_success = base_success * (1 - complexity_penalty) * time_feasible
        predicted_success = max(0.1, min(0.95, predicted_success - novelty_penalty))

        # Cost is normalized execution time scaled by complexity, capped at 1.
        predicted_cost = min(1.0, (strategy.avg_execution_time / 10.0) * (1 + complexity))

        # Confidence in the prediction grows with usage, capped at 0.9.
        if strategy.usage_count == 0:
            confidence = 0.3
        elif strategy.usage_count < self.min_usage:
            confidence = 0.3 + 0.2 * (strategy.usage_count / self.min_usage)
        else:
            confidence = min(0.9, 0.5 + 0.05 * np.log1p(strategy.usage_count))

        return Simulation(
            strategy_name=strategy.name,
            predicted_success=predicted_success,
            predicted_cost=predicted_cost,
            predicted_outcome=f"Success: {predicted_success:.1%}, Cost: {predicted_cost:.2f}",
            confidence=confidence
        )

    def _rank_by_utility(self,
                        simulations: List[Simulation],
                        context: Dict) -> List[Simulation]:
        """Score each simulation's utility in place and return them sorted.

        Utility = risk-adjusted expected value minus predicted cost. A risk
        tolerance below 0.5 shrinks the expected value of low-confidence
        predictions.
        """
        problem_value = context.get('value', 1.0)
        risk_tolerance = context.get('risk_tolerance', 0.5)

        for sim in simulations:
            expected_value = sim.predicted_success * problem_value

            # Risk-averse callers blend toward confidence-weighted value.
            if risk_tolerance < 0.5:
                confidence_weight = (1 - risk_tolerance) * 0.5
                expected_value *= (1 - confidence_weight + confidence_weight * sim.confidence)

            sim.utility = expected_value - sim.predicted_cost

        simulations.sort(key=lambda s: s.utility, reverse=True)
        return simulations

    def update_strategy(self,
                       strategy_name: str,
                       actual_success: bool,
                       actual_execution_time: float,
                       predicted_success: Optional[float] = None):
        """Update strategy based on actual outcome (learning step!).

        Success rate and execution time move by an exponential moving
        average with rate self.update_rate; updated policies are persisted.

        Args:
            strategy_name: Name of the executed strategy.
            actual_success: Whether the execution succeeded.
            actual_execution_time: Observed execution time (same units as
                Strategy.avg_execution_time).
            predicted_success: Optional earlier prediction, used to track
                prediction error for get_metrics().
        """
        strategy = next((s for s in self.strategies if s.name == strategy_name), None)
        if not strategy:
            logger.warning(f"Strategy not found: {strategy_name}")
            return

        # Update counts
        strategy.usage_count += 1
        if actual_success:
            strategy.success_count += 1

        # EMA update of the success rate toward the observed outcome.
        actual_rate = 1.0 if actual_success else 0.0
        strategy.expected_success_rate = (
            (1 - self.update_rate) * strategy.expected_success_rate +
            self.update_rate * actual_rate
        )

        # EMA update of the execution-time estimate.
        strategy.avg_execution_time = (
            (1 - self.update_rate) * strategy.avg_execution_time +
            self.update_rate * actual_execution_time
        )

        # Track |predicted - actual|; deque(maxlen=100) drops old entries.
        if predicted_success is not None:
            self.prediction_errors.append(abs(predicted_success - actual_rate))

        # Save updated policies
        self._save_policies()

        logger.info(
            f"Strategy '{strategy_name}' updated: "
            f"{strategy.success_count}/{strategy.usage_count} "
            f"({strategy.expected_success_rate:.1%})"
        )

    def _save_policies(self):
        """Persist strategy statistics (not callables) to disk as JSON.

        Disk errors are logged rather than raised so a failed save cannot
        abort the learning update that triggered it (consistent with the
        best-effort error handling in _load_policies).
        """
        policies = [
            {
                'name': s.name,
                'description': s.description,
                'domain': s.domain,
                'expected_success_rate': s.expected_success_rate,
                'avg_execution_time': s.avg_execution_time,
                'usage_count': s.usage_count,
                'success_count': s.success_count
            }
            for s in self.strategies
        ]

        try:
            with open(self.policy_path, 'w') as f:
                json.dump(policies, f, indent=2)
        except OSError as e:
            logger.error(f"Failed to save policies: {e}")

    def get_metrics(self) -> Dict:
        """Get meta-cognition metrics as plain JSON-serializable scalars."""
        if not self.strategies:
            return {
                'total_strategies': 0,
                'avg_success_rate': 0.0,
                'total_decisions': 0,
                'prediction_accuracy': 0.5
            }

        # Accuracy = 1 - mean absolute prediction error; 0.5 when unknown.
        if self.prediction_errors:
            avg_error = sum(self.prediction_errors) / len(self.prediction_errors)
            prediction_accuracy = 1.0 - avg_error
        else:
            prediction_accuracy = 0.5

        return {
            'total_strategies': len(self.strategies),
            # float() avoids leaking np.float64 into the returned dict.
            'avg_success_rate': float(np.mean([s.expected_success_rate for s in self.strategies])),
            'total_decisions': len(self.simulation_history),
            'prediction_accuracy': prediction_accuracy
        }
