#!/usr/bin/env python3
"""
Eden Meta-Cognition Loop (MCL)
Top Priority Module: Enables Eden to evaluate, select, and learn from reasoning strategies

Purpose: Learn which reasoning strategies work best for different problem types
Algorithms: Thompson sampling + meta-gradient updates
Location: /Eden/CORE/phi_fractal/meta_cognition/meta_loop.py
"""

import json
import logging
import random
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import numpy as np


# ============================================================================
# CONFIGURATION
# ============================================================================

# Root directory for all Eden state; every path below is derived from it.
EDEN_ROOT = Path("/Eden/CORE")
# Log file for this module (attached as a FileHandler by basicConfig below).
META_LOG_PATH = EDEN_ROOT / "logs" / "phi_meta.log"
# JSON snapshot of the meta-policy; default state_path of MetaPolicyManager.
META_STATE_PATH = EDEN_ROOT / "phi_fractal" / "meta_cognition" / "meta_policy_state.json"
# Append-only episode history, one JSON object per line.
STRATEGY_HISTORY_PATH = EDEN_ROOT / "phi_fractal" / "meta_cognition" / "strategy_history.jsonl"

# Create directories if they don't exist.
# NOTE: this runs at import time and needs write access under EDEN_ROOT.
# It must precede logging.basicConfig below, because logging.FileHandler
# opens META_LOG_PATH as soon as it is constructed. STRATEGY_HISTORY_PATH
# lives in the same directory as META_STATE_PATH, so it is covered too.
META_STATE_PATH.parent.mkdir(parents=True, exist_ok=True)
META_LOG_PATH.parent.mkdir(parents=True, exist_ok=True)

# Setup logging: INFO and above go to both the log file and stderr.
# NOTE(review): basicConfig configures the *root* logger at import time, so
# importing this module changes logging for the whole process (and, per the
# logging docs, it does nothing if the root logger already has handlers).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - MCL - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(META_LOG_PATH),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


# ============================================================================
# DATA STRUCTURES
# ============================================================================

class StrategyType(Enum):
    """Available reasoning strategies.

    Member values are the string keys of every per-strategy dict in the
    meta-policy state and of the persisted JSON state/history, so renaming or
    removing a value breaks compatibility with previously saved state.
    Declaration order is also the order of per-strategy entries in a freshly
    initialized state.
    """
    DIRECT_REASONING = "direct_reasoning"  # Straightforward logical approach
    ANALOGICAL = "analogical"  # Use similar past problems
    DECOMPOSITION = "decomposition"  # Break into subproblems
    CAUSAL_CHAIN = "causal_chain"  # Trace cause-effect relationships
    COUNTERFACTUAL = "counterfactual"  # "What if" reasoning
    CONSTRAINT_SATISFACTION = "constraint_satisfaction"  # Find solution meeting constraints
    PATTERN_MATCHING = "pattern_matching"  # Recognize known patterns
    SIMULATION = "simulation"  # Mental simulation/rollout


@dataclass
class Observation:
    """Input observation/problem handed to the meta-cognition loop.

    Attributes:
        content: Free-text description of the task.
        task_type: Optional task label (e.g. "arithmetic"); not read by the
            strategy selection or placeholder strategies in this module.
        domain: Optional domain label; likewise not read by selection.
        complexity: Heuristic difficulty score. Selection and the placeholder
            strategies compare it against thresholds between 0.3 and 0.7, so
            it is presumably meant to lie in [0, 1].
        metadata: Arbitrary extra information. It is serialized with the
            episode history via json.dumps, so values should be
            JSON-serializable.
    """
    content: str
    task_type: Optional[str] = None
    domain: Optional[str] = None
    complexity: float = 0.5
    # default_factory gives each instance its own dict and keeps the
    # annotation truthful; explicit metadata=None is still accepted below.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Backward compatibility: callers may still pass metadata=None.
        if self.metadata is None:
            self.metadata = {}


@dataclass
class BeliefState:
    """Current belief/knowledge state supplied alongside an Observation.

    Attributes:
        confidence: Overall confidence; not read by the strategies or the
            selection logic in this module.
        uncertainty_estimate: Uncertainty estimate; likewise not read here.
        relevant_schemas: Names of known schemas/patterns. The analogical and
            pattern-matching placeholder strategies raise their confidence
            when this list is non-empty and report its length.
        active_goals: Current goals; not read in this module.
    """
    confidence: float = 0.5
    uncertainty_estimate: float = 0.5
    # default_factory gives each instance its own list and keeps the
    # annotations truthful; explicit None is still accepted below.
    relevant_schemas: List[str] = field(default_factory=list)
    active_goals: List[str] = field(default_factory=list)

    def __post_init__(self):
        # Backward compatibility: explicit None is normalized to an empty list.
        if self.relevant_schemas is None:
            self.relevant_schemas = []
        if self.active_goals is None:
            self.active_goals = []


@dataclass
class StrategyResult:
    """Result of executing a single reasoning strategy.

    Attributes:
        strategy: The strategy that produced this result.
        success: Whether the strategy reported success.
        confidence: Confidence reported by the strategy.
        execution_time: Seconds spent running the strategy.
        outcome: Strategy-specific outcome payload (not interpreted by the
            selection logic in this module).
        reasoning_trace: Human-readable steps the strategy reported.
        metadata: Extra strategy-specific information.
    """
    strategy: StrategyType
    success: bool
    confidence: float
    execution_time: float
    outcome: Any
    reasoning_trace: List[str]
    # default_factory gives each result its own dict and keeps the
    # annotation truthful; explicit metadata=None is still accepted below.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Backward compatibility: callers may still pass metadata=None.
        if self.metadata is None:
            self.metadata = {}


@dataclass
class MetaPolicyState:
    """Meta-policy weights and statistics, persisted as JSON via asdict().

    Every dict is keyed by StrategyType value strings.
    """
    # Base weight multiplied into each strategy's selection score; nudged up
    # on success and down on failure.
    strategy_weights: Dict[str, float]
    # Number of recorded outcomes per strategy.
    strategy_counts: Dict[str, int]
    # How many of those recorded outcomes were successes.
    strategy_successes: Dict[str, int]
    # Running mean of the confidence reported by each strategy.
    strategy_avg_confidence: Dict[str, float]
    # ISO-8601 timestamp, set at initialization and on every save.
    last_updated: str
    # Number of outcomes recorded so far (incremented once per update).
    total_episodes: int = 0


# ============================================================================
# STRATEGY BANK
# ============================================================================

class StrategyBank:
    """Repository of reasoning strategies with execution templates.

    Each strategy is a method taking ``(Observation, BeliefState)`` and
    returning a dict with any of the keys ``success``, ``confidence``,
    ``outcome``, ``reasoning_trace`` and ``metadata``. ``execute_strategy``
    looks the method up by StrategyType and wraps the dict in a
    StrategyResult, filling defaults for missing keys.

    The current implementations are placeholders: each picks a fixed or
    context-dependent confidence and reports success with exactly that
    probability (``random.random() < confidence``).
    """

    def __init__(self):
        # Dispatch table: StrategyType -> bound implementation method.
        self.strategies = {
            StrategyType.DIRECT_REASONING: self._direct_reasoning,
            StrategyType.ANALOGICAL: self._analogical_reasoning,
            StrategyType.DECOMPOSITION: self._decomposition,
            StrategyType.CAUSAL_CHAIN: self._causal_chain,
            StrategyType.COUNTERFACTUAL: self._counterfactual,
            StrategyType.CONSTRAINT_SATISFACTION: self._constraint_satisfaction,
            StrategyType.PATTERN_MATCHING: self._pattern_matching,
            StrategyType.SIMULATION: self._simulation
        }

    def execute_strategy(self, strategy: StrategyType, observation: Observation,
                        belief_state: BeliefState) -> StrategyResult:
        """Execute one reasoning strategy and package its output.

        Args:
            strategy: Strategy to run; looked up in ``self.strategies``.
            observation: Problem being solved.
            belief_state: Current beliefs, passed through to the strategy.

        Returns:
            A StrategyResult. If the lookup or the strategy itself raises,
            the error is logged with its traceback and a failed result
            (confidence 0.0, trace ``["Error: ..."]``) is returned instead,
            so one broken strategy cannot abort the caller's loop.
        """
        # perf_counter is a monotonic clock, so the measured duration cannot
        # be distorted by wall-clock adjustments the way datetime.now() can.
        start = time.perf_counter()

        try:
            result = self.strategies[strategy](observation, belief_state)
            execution_time = time.perf_counter() - start

            return StrategyResult(
                strategy=strategy,
                success=result.get('success', False),
                confidence=result.get('confidence', 0.5),
                execution_time=execution_time,
                outcome=result.get('outcome'),
                reasoning_trace=result.get('reasoning_trace', []),
                metadata=result.get('metadata', {})
            )
        except Exception as e:
            # Deliberate catch-all boundary around arbitrary strategy code;
            # logger.exception also records the traceback for debugging.
            logger.exception("Strategy %s failed: %s", strategy.value, e)
            return StrategyResult(
                strategy=strategy,
                success=False,
                confidence=0.0,
                # Time actually spent before the failure.
                execution_time=time.perf_counter() - start,
                outcome=None,
                reasoning_trace=[f"Error: {str(e)}"]
            )

    # Strategy implementations (lightweight placeholders - extend with real logic)

    def _direct_reasoning(self, obs: Observation, belief: BeliefState) -> Dict:
        """Straightforward logical approach.

        Confidence is 0.7 for complexity < 0.5, otherwise 0.5.
        """
        trace = [
            "Analyzing problem directly",
            f"Task: {obs.content[:100]}",
            "Applying logical inference"
        ]

        # Simulate reasoning
        confidence = 0.7 if obs.complexity < 0.5 else 0.5
        success = random.random() < confidence

        return {
            'success': success,
            'confidence': confidence,
            'outcome': f"Direct solution attempt: {'success' if success else 'failed'}",
            'reasoning_trace': trace,
            'metadata': {'strategy_type': 'direct'}
        }

    def _analogical_reasoning(self, obs: Observation, belief: BeliefState) -> Dict:
        """Use similar past problems.

        Confidence is 0.8 when the belief state has relevant schemas, else 0.4.
        """
        trace = [
            "Searching for similar problems in memory",
            f"Found {len(belief.relevant_schemas)} relevant schemas",
            "Adapting previous solution"
        ]

        # Higher success if we have relevant schemas
        confidence = 0.8 if belief.relevant_schemas else 0.4
        success = random.random() < confidence

        return {
            'success': success,
            'confidence': confidence,
            'outcome': f"Analogical solution: {'adapted from past' if success else 'no good analogy'}",
            'reasoning_trace': trace,
            'metadata': {'schemas_used': len(belief.relevant_schemas)}
        }

    def _decomposition(self, obs: Observation, belief: BeliefState) -> Dict:
        """Break into subproblems.

        Confidence is 0.8 for complexity > 0.6, otherwise 0.5.
        """
        trace = [
            "Decomposing problem into subproblems",
            "Identified 3 subproblems",
            "Solving each independently",
            "Combining solutions"
        ]

        # Better for complex problems
        confidence = 0.8 if obs.complexity > 0.6 else 0.5
        success = random.random() < confidence

        return {
            'success': success,
            'confidence': confidence,
            'outcome': f"Decomposed solution: {'combined successfully' if success else 'subproblem failed'}",
            'reasoning_trace': trace,
            'metadata': {'subproblems': 3}
        }

    def _causal_chain(self, obs: Observation, belief: BeliefState) -> Dict:
        """Trace cause-effect relationships (fixed confidence 0.65)."""
        trace = [
            "Building causal chain",
            "A → B → C → D (goal)",
            "Validating each causal link"
        ]

        confidence = 0.65
        success = random.random() < confidence

        return {
            'success': success,
            'confidence': confidence,
            'outcome': f"Causal chain: {'validated' if success else 'broken link'}",
            'reasoning_trace': trace
        }

    def _counterfactual(self, obs: Observation, belief: BeliefState) -> Dict:
        """What-if reasoning (fixed confidence 0.6)."""
        trace = [
            "Generating counterfactual scenarios",
            "What if X was different?",
            "Comparing outcomes"
        ]

        confidence = 0.6
        success = random.random() < confidence

        return {
            'success': success,
            'confidence': confidence,
            'outcome': f"Counterfactual: {'found better path' if success else 'no improvement'}",
            'reasoning_trace': trace
        }

    def _constraint_satisfaction(self, obs: Observation, belief: BeliefState) -> Dict:
        """Find solution meeting constraints (fixed confidence 0.7)."""
        trace = [
            "Identifying constraints",
            "Searching solution space",
            "Validating constraints"
        ]

        confidence = 0.7
        success = random.random() < confidence

        return {
            'success': success,
            'confidence': confidence,
            'outcome': f"Constraint sat: {'feasible solution' if success else 'no feasible solution'}",
            'reasoning_trace': trace
        }

    def _pattern_matching(self, obs: Observation, belief: BeliefState) -> Dict:
        """Recognize known patterns.

        Confidence is 0.75 when the belief state has relevant schemas, else 0.3.
        """
        trace = [
            "Scanning for known patterns",
            f"Pattern database: {len(belief.relevant_schemas)} patterns",
            "Matching..."
        ]

        confidence = 0.75 if belief.relevant_schemas else 0.3
        success = random.random() < confidence

        return {
            'success': success,
            'confidence': confidence,
            'outcome': f"Pattern match: {'found' if success else 'no match'}",
            'reasoning_trace': trace
        }

    def _simulation(self, obs: Observation, belief: BeliefState) -> Dict:
        """Mental simulation/rollout (fixed confidence 0.65)."""
        trace = [
            "Running mental simulation",
            "Simulating 5 steps ahead",
            "Evaluating outcomes"
        ]

        confidence = 0.65
        success = random.random() < confidence

        return {
            'success': success,
            'confidence': confidence,
            'outcome': f"Simulation: {'predicted success' if success else 'predicted failure'}",
            'reasoning_trace': trace
        }


# ============================================================================
# META-POLICY MANAGER
# ============================================================================

class MetaPolicyManager:
    """Manages strategy selection using Thompson sampling and updates weights.

    The learned statistics live in a MetaPolicyState that is loaded from
    ``state_path`` on construction and written back only by ``save_state``.
    """

    # Step size of the additive weight update in update_from_result.
    LEARNING_RATE = 0.1
    # Floor applied after each update so no strategy's weight can reach zero
    # or go negative.
    MIN_WEIGHT = 0.1

    def __init__(self, state_path: Path = META_STATE_PATH):
        """Load state from *state_path*, or start fresh if absent/unreadable."""
        self.state_path = state_path
        self.state = self._load_or_initialize_state()

    @staticmethod
    def _default_state() -> MetaPolicyState:
        """Build a fresh state with uniform weights for every StrategyType."""
        strategies = [s.value for s in StrategyType]
        return MetaPolicyState(
            strategy_weights={s: 1.0 for s in strategies},
            strategy_counts={s: 0 for s in strategies},
            strategy_successes={s: 0 for s in strategies},
            strategy_avg_confidence={s: 0.5 for s in strategies},
            last_updated=datetime.now().isoformat(),
            total_episodes=0
        )

    @staticmethod
    def _backfill_missing_strategies(state: MetaPolicyState) -> None:
        """Add default entries for StrategyTypes absent from *state*'s dicts."""
        for strategy in StrategyType:
            name = strategy.value
            state.strategy_weights.setdefault(name, 1.0)
            state.strategy_counts.setdefault(name, 0)
            state.strategy_successes.setdefault(name, 0)
            state.strategy_avg_confidence.setdefault(name, 0.5)

    def _load_or_initialize_state(self) -> MetaPolicyState:
        """Load existing state or create a new one with uniform weights.

        A loaded state is back-filled with defaults for any StrategyType it
        lacks (e.g. a strategy added after the file was written), so every
        strategy stays selectable and all per-strategy dicts share its keys.
        An unreadable or malformed file is logged and replaced by a fresh
        in-memory state.
        """
        if not self.state_path.exists():
            return self._default_state()

        try:
            with open(self.state_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            state = MetaPolicyState(**data)
            self._backfill_missing_strategies(state)
        except (OSError, ValueError, TypeError, AttributeError) as e:
            # ValueError covers json.JSONDecodeError; TypeError covers a
            # non-mapping document or unexpected/missing dataclass fields;
            # AttributeError covers per-strategy fields that are not dicts.
            logger.error(f"Failed to load state: {e}, initializing new")
            return self._default_state()

        logger.info(f"Loaded meta-policy state: {state.total_episodes} episodes")
        return state

    def save_state(self):
        """Persist the meta-policy state to ``state_path`` as indented JSON.

        The JSON is written to a sibling ``.tmp`` file that then replaces the
        target, so an interrupted write cannot leave a truncated state file
        behind (the loader would discard it and reset all learned statistics).
        """
        self.state.last_updated = datetime.now().isoformat()
        self.state_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.state_path.with_name(self.state_path.name + '.tmp')
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(asdict(self.state), f, indent=2)
        # Path.replace uses os.replace: an atomic rename on POSIX when both
        # paths are on the same filesystem (they share a directory here).
        tmp_path.replace(self.state_path)
        logger.info(f"Saved meta-policy state: {self.state.total_episodes} episodes")

    def select_strategies(self, n: int = 3, observation: Optional[Observation] = None) -> List[StrategyType]:
        """
        Select top N strategies using Thompson sampling.

        For each strategy a success rate is sampled from its
        Beta(successes + 1, failures + 1) posterior (uniform Beta(1, 1) prior)
        and multiplied by the strategy's base weight and average confidence.
        If an observation is given, decomposition is boosted (x1.3) for
        complexity > 0.7 and direct reasoning (x1.2) for complexity < 0.3.

        Args:
            n: Maximum number of strategies to return; values below 1 yield [].
            observation: Optional problem used for the context adjustments.

        Returns:
            Selected strategies ordered by descending sampled score.
        """
        known_names = {s.value for s in StrategyType}
        strategy_scores = {}

        for strategy_name, base_weight in self.state.strategy_weights.items():
            # Persisted state may mention strategies that no longer exist;
            # StrategyType(name) below would raise ValueError for them.
            if strategy_name not in known_names:
                continue

            # Beta distribution parameters (Bayesian success rate estimation).
            # Clamped at 0 so inconsistent persisted counts cannot produce a
            # non-positive Beta parameter, which np.random.beta rejects.
            successes = max(0, self.state.strategy_successes.get(strategy_name, 0))
            failures = max(0, self.state.strategy_counts.get(strategy_name, 0) - successes)

            # Add pseudo-counts for exploration (Beta(1,1) prior)
            alpha = successes + 1
            beta = failures + 1

            # Sample from posterior
            sampled_success_rate = np.random.beta(alpha, beta)

            # Weight by average confidence and base weight
            avg_confidence = self.state.strategy_avg_confidence.get(strategy_name, 0.5)
            score = sampled_success_rate * base_weight * avg_confidence

            # Context-aware adjustment (if observation provided)
            if observation is not None:
                if observation.complexity > 0.7 and strategy_name == StrategyType.DECOMPOSITION.value:
                    score *= 1.3  # Prefer decomposition for complex problems
                if observation.complexity < 0.3 and strategy_name == StrategyType.DIRECT_REASONING.value:
                    score *= 1.2  # Prefer direct for simple problems

            strategy_scores[strategy_name] = score

        # Select top N; max(n, 0) stops a negative n from slicing off the tail.
        top_strategies = sorted(strategy_scores.items(), key=lambda x: x[1], reverse=True)[:max(n, 0)]
        selected = [StrategyType(name) for name, _ in top_strategies]

        logger.info(f"Selected strategies: {[s.value for s in selected]}")
        return selected

    def update_from_result(self, result: StrategyResult):
        """Record *result* and adjust its strategy's statistics and weight.

        Increments the strategy's count (and its success count on success),
        folds ``result.confidence`` into the running mean confidence, moves
        the base weight up by ``LEARNING_RATE * confidence`` on success or
        down by ``LEARNING_RATE * (1 - confidence)`` on failure (floored at
        ``MIN_WEIGHT``), and increments ``total_episodes``. Nothing is written
        to disk here; call save_state for that.
        """
        strategy_name = result.strategy.value

        # Update counts
        self.state.strategy_counts[strategy_name] = self.state.strategy_counts.get(strategy_name, 0) + 1

        if result.success:
            self.state.strategy_successes[strategy_name] = self.state.strategy_successes.get(strategy_name, 0) + 1

        # Incremental mean: new = (old * (count - 1) + x) / count.
        # On the first observation (count == 1) the 0.5 prior drops out.
        old_conf = self.state.strategy_avg_confidence.get(strategy_name, 0.5)
        count = self.state.strategy_counts[strategy_name]
        new_conf = (old_conf * (count - 1) + result.confidence) / count
        self.state.strategy_avg_confidence[strategy_name] = new_conf

        # Update weights using meta-gradient (simple version)
        current_weight = self.state.strategy_weights.get(strategy_name, 1.0)

        if result.success:
            # Increase weight proportional to confidence
            new_weight = current_weight + self.LEARNING_RATE * result.confidence
        else:
            # NOTE(review): a high-confidence failure is penalized *less* than
            # a low-confidence one; confirm this is the intended signal.
            new_weight = current_weight - self.LEARNING_RATE * (1 - result.confidence)

        # Keep weights positive; log the value actually stored.
        new_weight = max(self.MIN_WEIGHT, new_weight)
        self.state.strategy_weights[strategy_name] = new_weight

        self.state.total_episodes += 1

        logger.info(f"Updated {strategy_name}: weight={new_weight:.3f}, success_rate={self.get_success_rate(strategy_name):.3f}")

    def get_success_rate(self, strategy_name: str) -> float:
        """Return successes / count for a strategy, or 0.5 if it was never used."""
        count = self.state.strategy_counts.get(strategy_name, 0)
        if count == 0:
            return 0.5  # Unknown
        successes = self.state.strategy_successes.get(strategy_name, 0)
        return successes / count

    def get_report(self) -> Dict[str, Any]:
        """Summarize per-strategy weight, usage and success statistics.

        Per-strategy entries missing from the state (possible with a
        hand-edited or partial file) are reported with the defaults used at
        initialization instead of raising KeyError.
        """
        state = self.state
        report = {
            'total_episodes': state.total_episodes,
            'last_updated': state.last_updated,
            'strategies': {}
        }

        for strategy_name, weight in state.strategy_weights.items():
            report['strategies'][strategy_name] = {
                'weight': weight,
                'count': state.strategy_counts.get(strategy_name, 0),
                'successes': state.strategy_successes.get(strategy_name, 0),
                'success_rate': self.get_success_rate(strategy_name),
                'avg_confidence': state.strategy_avg_confidence.get(strategy_name, 0.5)
            }

        return report


# ============================================================================
# META-COGNITION LOOP
# ============================================================================

class MetaCognitionLoop:
    """
    Main MCL class: evaluates strategies, selects best, executes, and learns.

    Each call to ``process`` is one episode. Episodes are kept in memory in
    ``self.history`` (this instance only) and appended as JSON lines to
    STRATEGY_HISTORY_PATH; the meta-policy state is saved every 10th episode.
    """

    def __init__(self):
        """Create the strategy bank and load the persisted meta-policy."""
        self.strategy_bank = StrategyBank()
        self.meta_policy = MetaPolicyManager()
        # In-memory episode log for this instance; grows without bound.
        self.history = []

        logger.info("🧠 Meta-Cognition Loop initialized")
        logger.info(f"📊 Loaded {self.meta_policy.state.total_episodes} past episodes")

    def process(self, observation: Observation, belief_state: BeliefState,
                n_candidates: int = 3, execute_best: bool = True) -> Dict[str, Any]:
        """
        Main processing loop:
        1. Select candidate strategies
        2. Simulate/evaluate each
        3. Choose best strategy
        4. Execute (optional)
        5. Update meta-policy

        Args:
            observation: Problem to solve.
            belief_state: Current beliefs, passed to every strategy.
            n_candidates: Number of strategies to try (capped at the number
                of available strategies).
            execute_best: Accepted for API compatibility; it currently has no
                effect, and the simulated result of the chosen strategy is
                used either way.

        Returns:
            Dict with the chosen strategy, all candidates, success,
            confidence, reasoning trace, outcome and ``meta_policy_updated``.

        Raises:
            ValueError: If no candidate strategies were selected (e.g.
                ``n_candidates`` < 1).
        """
        logger.info("=" * 70)
        logger.info(f"🎯 NEW TASK: {observation.content[:60]}...")
        logger.info("=" * 70)

        # Step 1: Select candidate strategies
        candidate_strategies = self.meta_policy.select_strategies(n_candidates, observation)
        if not candidate_strategies:
            # Without this, choosing the best result fails with an opaque
            # IndexError/ValueError on an empty list.
            raise ValueError(
                f"No candidate strategies selected (n_candidates={n_candidates})"
            )
        logger.info(f"📋 Candidate strategies: {[s.value for s in candidate_strategies]}")

        # Step 2: Simulate each strategy (lightweight evaluation)
        results = []
        for strategy in candidate_strategies:
            logger.info(f"🔬 Simulating {strategy.value}...")
            result = self.strategy_bank.execute_strategy(strategy, observation, belief_state)
            results.append(result)
            logger.info(f"   → Success: {result.success}, Confidence: {result.confidence:.3f}")

        # Step 3: Choose best strategy, scored by confidence if successful,
        # else 0. max() returns the first of equally scored results, so ties
        # (e.g. every candidate failed) go to the earliest, highest-ranked one.
        best_result = max(results, key=lambda r: r.confidence if r.success else 0)

        logger.info(f"✨ CHOSEN STRATEGY: {best_result.strategy.value}")
        logger.info(f"   Confidence: {best_result.confidence:.3f}")

        # Step 4: Execute. Not implemented yet: in production this would run
        # the chosen strategy for real when execute_best is True; for now the
        # simulated result above is used regardless.

        # Step 5: Update meta-policy
        # NOTE(review): only the chosen result is recorded, and the choice
        # favours successful candidates, so the other candidates' failures are
        # discarded and recorded success rates are biased upward. Recording
        # every entry of `results` would give Thompson sampling unbiased counts
        # (but would also change what total_episodes counts).
        self.meta_policy.update_from_result(best_result)

        # Save to history
        episode = {
            'timestamp': datetime.now().isoformat(),
            'observation': asdict(observation),
            'candidates': [r.strategy.value for r in results],
            'chosen': best_result.strategy.value,
            'success': best_result.success,
            'confidence': best_result.confidence,
            'reasoning_trace': best_result.reasoning_trace
        }
        self.history.append(episode)
        self._save_episode(episode)

        # Periodic state save
        if len(self.history) % 10 == 0:
            self.meta_policy.save_state()

        return {
            'chosen_strategy': best_result.strategy.value,
            'all_candidates': [r.strategy.value for r in results],
            'success': best_result.success,
            'confidence': best_result.confidence,
            'reasoning_trace': best_result.reasoning_trace,
            'outcome': best_result.outcome,
            'meta_policy_updated': True
        }

    def _save_episode(self, episode: Dict):
        """Append *episode* as one JSON line to STRATEGY_HISTORY_PATH.

        Best-effort: a non-JSON-serializable value (e.g. in the observation's
        metadata) or a write error is logged and the line is skipped, so it
        cannot abort process() after the meta-policy was already updated.
        The episode remains in ``self.history`` either way.
        """
        # Serialize before opening the file so a bad value never leaves a
        # partial line behind.
        try:
            line = json.dumps(episode)
        except (TypeError, ValueError) as e:
            logger.error(f"Failed to serialize episode for history: {e}")
            return

        try:
            with open(STRATEGY_HISTORY_PATH, 'a', encoding='utf-8') as f:
                f.write(line + '\n')
        except OSError as e:
            logger.error(f"Failed to append episode to {STRATEGY_HISTORY_PATH}: {e}")

    def get_meta_report(self) -> Dict[str, Any]:
        """Get a comprehensive meta-cognition report.

        ``total_episodes`` counts only episodes processed by this instance;
        the persisted lifetime count is ``meta_policy['total_episodes']``.
        """
        return {
            'mcl_status': 'operational',
            'total_episodes': len(self.history),
            'recent_history': self.history[-5:] if self.history else [],
            'meta_policy': self.meta_policy.get_report()
        }


# ============================================================================
# COMMAND-LINE INTERFACE
# ============================================================================

def test_mcl():
    """Run the Meta-Cognition Loop on four sample problems and print a report.

    Uses a default MetaCognitionLoop, so it reads and updates the real
    meta-policy state file and appends to the real history file configured
    at the top of this module.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("🧠 EDEN META-COGNITION LOOP - TEST MODE")
    print(rule + "\n")

    mcl = MetaCognitionLoop()

    # Each sample problem paired with the belief state it is solved under.
    scenarios = [
        (
            Observation(
                content="How can I optimize this recursive function for factorial?",
                task_type="code_optimization",
                domain="programming",
                complexity=0.6,
                metadata={'language': 'python'},
            ),
            BeliefState(confidence=0.7, relevant_schemas=["recursion", "optimization"]),
        ),
        (
            Observation(
                content="What causes the stock market to crash?",
                task_type="causal_analysis",
                domain="economics",
                complexity=0.8,
            ),
            BeliefState(confidence=0.4, relevant_schemas=[], uncertainty_estimate=0.7),
        ),
        (
            Observation(
                content="2 + 2 = ?",
                task_type="arithmetic",
                domain="math",
                complexity=0.1,
            ),
            BeliefState(confidence=0.95, relevant_schemas=["arithmetic"]),
        ),
        (
            Observation(
                content="Design a distributed system that can handle 1M requests/second",
                task_type="system_design",
                domain="engineering",
                complexity=0.9,
            ),
            BeliefState(confidence=0.5, relevant_schemas=["distributed_systems"], uncertainty_estimate=0.6),
        ),
    ]

    for problem, belief in scenarios:
        outcome = mcl.process(problem, belief, n_candidates=3, execute_best=True)

        print(f"\n📝 Problem: {problem.content}")
        print(f"✅ Chosen: {outcome['chosen_strategy']}")
        print(f"💯 Confidence: {outcome['confidence']:.3f}")
        print(f"🎯 Success: {outcome['success']}")
        print(f"💭 Reasoning: {outcome['reasoning_trace'][0]}")
        print()

    # Final per-strategy summary (only strategies that were actually used).
    print("\n" + rule)
    print("📊 META-COGNITION REPORT")
    print(rule)

    summary = mcl.get_meta_report()
    print(f"\nTotal episodes: {summary['total_episodes']}")
    print("\nStrategy Performance:")

    for name, stats in summary['meta_policy']['strategies'].items():
        if stats['count'] <= 0:
            continue
        print(f"\n  {name}:")
        print(f"    Weight: {stats['weight']:.3f}")
        print(f"    Uses: {stats['count']}")
        print(f"    Success rate: {stats['success_rate']:.1%}")
        print(f"    Avg confidence: {stats['avg_confidence']:.3f}")

    # Persist the learned policy and point at the output files.
    mcl.meta_policy.save_state()
    print(f"\n💾 State saved to: {META_STATE_PATH}")
    print(f"📜 History saved to: {STRATEGY_HISTORY_PATH}")
    print("\n✅ Meta-Cognition Loop test complete!\n")


def main():
    """Command-line entry point.

    ``test`` runs the sample scenarios, ``report`` prints the current
    meta-policy report as JSON, and anything else (including no argument)
    prints usage help.
    """
    import sys

    command = sys.argv[1] if len(sys.argv) > 1 else None

    if command == "test":
        test_mcl()
        return

    if command == "report":
        # Show current meta-policy state
        print(json.dumps(MetaCognitionLoop().get_meta_report(), indent=2))
        return

    usage_lines = (
        "Eden Meta-Cognition Loop",
        "\nUsage:",
        "  python3 meta_loop.py test       # Run test scenarios",
        "  python3 meta_loop.py report     # Show current meta-policy state",
        "\nIntegration:",
        "  from meta_loop import MetaCognitionLoop, Observation, BeliefState",
        "  mcl = MetaCognitionLoop()",
        "  result = mcl.process(observation, belief_state)",
    )
    for line in usage_lines:
        print(line)


if __name__ == "__main__":
    main()
