#!/usr/bin/env python3
"""
EDEN WORLD MODEL - REAL AGI
===========================
Actual causal reasoning using:
1. Causal graphs (DAGs)
2. Do-calculus for interventions
3. Bayesian belief networks
4. Clingo for logical inference

NOT keyword matching. REAL computation.

φ = 1.618033988749895
"""

import heapq
import json
import math
import sqlite3
import subprocess
from collections import defaultdict, deque
from datetime import datetime
from typing import Dict, List, Optional, Set, Tuple

PHI = 1.618033988749895

class CausalGraph:
    """
    Real causal graph with do-calculus.
    Nodes are variables, edges are causal relationships with strengths.
    """
    
    def __init__(self):
        # Adjacency list: node -> [(child, strength, evidence_count)]
        self.edges: Dict[str, List[Tuple[str, float, int]]] = defaultdict(list)
        # Reverse edges for backtracking
        self.parents: Dict[str, List[Tuple[str, float]]] = defaultdict(list)
        # Node states
        self.states: Dict[str, float] = {}  # node -> probability of being true
        # Observed nodes (can't be intervened on in this context)
        self.observed: Set[str] = set()
    
    def add_edge(self, cause: str, effect: str, strength: float = 0.5, evidence: int = 1):
        """Add causal edge: cause -> effect with strength."""
        # Update or add edge
        found = False
        for i, (child, s, e) in enumerate(self.edges[cause]):
            if child == effect:
                # Bayesian update of strength
                new_evidence = e + evidence
                new_strength = (s * e + strength * evidence) / new_evidence
                self.edges[cause][i] = (child, new_strength, new_evidence)
                found = True
                break
        
        if not found:
            self.edges[cause].append((effect, strength, evidence))
            self.parents[effect].append((cause, strength))
        
        # Initialize states if not present
        if cause not in self.states:
            self.states[cause] = 0.5
        if effect not in self.states:
            self.states[effect] = 0.5
    
    def observe(self, node: str, value: float):
        """Observe a node's state (conditioning)."""
        self.states[node] = value
        self.observed.add(node)
        self._propagate_observation(node)
    
    def _propagate_observation(self, node: str):
        """Propagate observation through the graph (belief propagation)."""
        # Forward propagation to children
        queue = [node]
        visited = {node}
        
        while queue:
            current = queue.pop(0)
            current_state = self.states[current]
            
            for child, strength, _ in self.edges[current]:
                if child not in self.observed:
                    # P(child | parent) = parent_state * strength + (1-strength) * prior
                    prior = self.states.get(child, 0.5)
                    new_state = current_state * strength + (1 - strength) * prior
                    self.states[child] = new_state
                    
                    if child not in visited:
                        visited.add(child)
                        queue.append(child)
    
    def do(self, node: str, value: float) -> Dict[str, float]:
        """
        Intervention: do(X = value)
        Different from observation - breaks incoming edges.
        This is Pearl's do-calculus.
        """
        # Save current state
        old_states = dict(self.states)
        old_parents = dict(self.parents)
        
        # Intervene: set node and break incoming edges
        self.states[node] = value
        self.parents[node] = []  # Cut incoming edges
        
        # Propagate effect
        self._propagate_observation(node)
        
        # Get post-intervention states
        result = dict(self.states)
        
        # Restore (interventions are hypothetical)
        self.states = old_states
        self.parents = old_parents
        
        return result
    
    def predict(self, action: str, target: str) -> Tuple[float, List[str]]:
        """
        Predict effect of action on target.
        Returns (probability_change, causal_path)
        """
        # Find causal path from action to target
        path = self._find_path(action, target)
        
        if not path:
            return (0.0, [])
        
        # Calculate cumulative effect along path
        effect = 1.0
        for i in range(len(path) - 1):
            cause = path[i]
            next_node = path[i + 1]
            for child, strength, _ in self.edges[cause]:
                if child == next_node:
                    effect *= strength
                    break
        
        return (effect, path)
    
    def _find_path(self, start: str, end: str) -> List[str]:
        """Find causal path using BFS."""
        if start not in self.edges:
            return []
        
        queue = [(start, [start])]
        visited = {start}
        
        while queue:
            node, path = queue.pop(0)
            
            if node == end:
                return path
            
            for child, _, _ in self.edges[node]:
                if child not in visited:
                    visited.add(child)
                    queue.append((child, path + [child]))
        
        return []
    
    def counterfactual(self, observation: Dict[str, float], 
                       intervention: Dict[str, float],
                       query: str) -> float:
        """
        Counterfactual query: Given observation, if we had done intervention,
        what would query be?
        
        Three steps:
        1. Abduction: Update beliefs given observation
        2. Action: Apply intervention
        3. Prediction: Compute query
        """
        # Step 1: Abduction - set observed values
        for node, value in observation.items():
            self.observe(node, value)
        
        # Step 2 & 3: Intervention and prediction
        post_intervention = self.do(list(intervention.keys())[0], 
                                    list(intervention.values())[0])
        
        # Clear observations
        self.observed.clear()
        
        return post_intervention.get(query, 0.5)
    
    def get_confounders(self, cause: str, effect: str) -> List[str]:
        """Find confounders (common causes) of cause and effect."""
        cause_ancestors = self._get_ancestors(cause)
        effect_ancestors = self._get_ancestors(effect)
        return list(cause_ancestors & effect_ancestors)
    
    def _get_ancestors(self, node: str) -> Set[str]:
        """Get all ancestors of a node."""
        ancestors = set()
        queue = [node]
        
        while queue:
            current = queue.pop(0)
            for parent, _ in self.parents.get(current, []):
                if parent not in ancestors:
                    ancestors.add(parent)
                    queue.append(parent)
        
        return ancestors


class RealWorldModel:
    """
    Real world model with:
    - Causal graph for reasoning
    - Clingo for logical inference
    - Bayesian updating for beliefs
    - Actual simulation

    All learned structure (edges, beliefs, node states) is persisted to a
    SQLite database so the model survives process restarts.
    """

    def __init__(self, db_path: str = "/Eden/DATA/world_model_real.db"):
        """Initialize the model and load any persisted state.

        Args:
            db_path: SQLite database location. Parameterized (backward
                compatible: defaults to the original hard-coded path) so
                tests and alternate deployments can supply their own file.
        """
        self.db_path = db_path
        self._init_db()

        self.causal = CausalGraph()
        # beliefs[domain][proposition] -> probability of the proposition
        self.beliefs: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float))
        # In-memory log of predictions made during this session
        self.prediction_log: List[Dict] = []

        self._load_state()
        print("🌍 Real World Model initialized")
        print(f"   Causal nodes: {len(self.causal.states)}")
        print(f"   Causal edges: {sum(len(v) for v in self.causal.edges.values())}")

    def _init_db(self):
        """Create the schema (edges, predictions, beliefs) if missing."""
        conn = sqlite3.connect(self.db_path)
        try:
            conn.executescript('''
                CREATE TABLE IF NOT EXISTS causal_edges (
                    id INTEGER PRIMARY KEY,
                    cause TEXT NOT NULL,
                    effect TEXT NOT NULL,
                    strength REAL DEFAULT 0.5,
                    evidence_count INTEGER DEFAULT 1,
                    last_updated TEXT,
                    UNIQUE(cause, effect)
                );
                CREATE TABLE IF NOT EXISTS predictions (
                    id INTEGER PRIMARY KEY,
                    timestamp TEXT,
                    action TEXT,
                    predicted_effects TEXT,
                    actual_effects TEXT,
                    accuracy REAL
                );
                CREATE TABLE IF NOT EXISTS beliefs (
                    id INTEGER PRIMARY KEY,
                    domain TEXT,
                    proposition TEXT,
                    probability REAL,
                    evidence TEXT,
                    UNIQUE(domain, proposition)
                );
                CREATE INDEX IF NOT EXISTS idx_cause ON causal_edges(cause);
                CREATE INDEX IF NOT EXISTS idx_effect ON causal_edges(effect);
            ''')
            conn.commit()
        finally:
            conn.close()

    def _load_state(self):
        """Rehydrate causal edges, beliefs and node states from SQLite."""
        conn = sqlite3.connect(self.db_path)
        try:
            for cause, effect, strength, evidence in conn.execute(
                    "SELECT cause, effect, strength, evidence_count FROM causal_edges"):
                self.causal.add_edge(cause, effect, strength, evidence)
            for domain, proposition, probability in conn.execute(
                    "SELECT domain, proposition, probability FROM beliefs"):
                self.beliefs[domain][proposition] = probability
            # node_states is created lazily by save_states(); tolerate absence.
            # BUGFIX: was a bare `except:` that also swallowed
            # KeyboardInterrupt/SystemExit — now only the expected error.
            try:
                for node, state in conn.execute("SELECT node, state FROM node_states"):
                    self.causal.states[node] = state
            except sqlite3.OperationalError:
                pass  # Table may not exist yet
        finally:
            conn.close()

    def save_states(self):
        """Persist current node states to DB (creates the table on demand)."""
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute('''CREATE TABLE IF NOT EXISTS node_states (
                node TEXT PRIMARY KEY, state REAL DEFAULT 0.5, last_updated TEXT)''')
            now = datetime.now().isoformat()
            conn.executemany('''
                INSERT INTO node_states (node, state, last_updated) VALUES (?, ?, ?)
                ON CONFLICT(node) DO UPDATE SET state = excluded.state,
                                                last_updated = excluded.last_updated
            ''', [(node, state, now) for node, state in self.causal.states.items()])
            conn.commit()
        finally:
            conn.close()

    def learn_causation(self, cause: str, effect: str, strength: float = 0.7):
        """Learn (or reinforce) a causal relationship and persist it.

        The DB upsert mirrors CausalGraph.add_edge: repeated observations
        update strength as an evidence-weighted running average.
        """
        self.causal.add_edge(cause, effect, strength)

        now = datetime.now().isoformat()
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute('''
                INSERT INTO causal_edges (cause, effect, strength, evidence_count, last_updated)
                VALUES (?, ?, ?, 1, ?)
                ON CONFLICT(cause, effect) DO UPDATE SET
                strength = (strength * evidence_count + ?) / (evidence_count + 1),
                evidence_count = evidence_count + 1,
                last_updated = ?
            ''', (cause, effect, strength, now, strength, now))
            conn.commit()
        finally:
            conn.close()

    def predict(self, action: str) -> Dict:
        """
        Predict consequences of an action.

        Simulates do(action = 1.0), reports variables whose probability moved
        by more than 0.1, greedy strongest-edge causal chains out of the
        action, and confounders between the action and its direct effects.
        """
        # Hypothetical intervention (the causal graph restores itself).
        post_intervention = self.causal.do(action, 1.0)

        # Variables whose probability shifted noticeably
        affected = []
        for node, new_prob in post_intervention.items():
            old_prob = self.causal.states.get(node, 0.5)
            if abs(new_prob - old_prob) > 0.1:
                affected.append({
                    'variable': node,
                    'before': old_prob,
                    'after': new_prob,
                    'change': new_prob - old_prob
                })

        # Greedy strongest-edge chains, depth-capped at 5 (the cap also
        # guards against cycles in the learned graph).
        chains = []
        for child, strength, _ in self.causal.edges.get(action, []):
            chain = [action, child]
            current = child
            for _ in range(5):
                children = self.causal.edges.get(current, [])
                if not children:
                    break
                next_node = max(children, key=lambda x: x[1])[0]
                chain.append(next_node)
                current = next_node
            chains.append(chain)

        # Common causes of the action and each direct effect
        confounders = []
        for child, _, _ in self.causal.edges.get(action, []):
            confounders.extend(self.causal.get_confounders(action, child))

        prediction = {
            'action': action,
            'affected_variables': affected,
            'causal_chains': chains,
            'confounders': list(set(confounders)),
            # Heuristic: more outgoing chains -> more confidence, capped at 0.9
            'confidence': min(0.9, 0.5 + 0.1 * len(chains))
        }

        self.prediction_log.append(prediction)
        return prediction

    def counterfactual(self, observation: str, intervention: str, query: str) -> Dict:
        """
        Answer a counterfactual: given `observation` held (prob 1.0), had we
        done `intervention`, what would P(query) be?
        """
        result = self.causal.counterfactual(
            {observation: 1.0},
            {intervention: 1.0},
            query
        )

        return {
            'observation': observation,
            'intervention': intervention,
            'query': query,
            'result': result,
            'interpretation': f"If {observation} is true, and we do {intervention}, then P({query}) = {result:.2f}"
        }

    def explain(self, effect: str) -> Dict:
        """
        Explain why an effect occurred: rank its causal parents by
        contribution (parent state x edge strength), strongest first.
        """
        explanations = []
        for parent, strength in self.causal.parents.get(effect, []):
            parent_state = self.causal.states.get(parent, 0.5)
            explanations.append({
                'cause': parent,
                'strength': strength,
                'cause_state': parent_state,
                'contribution': parent_state * strength
            })

        explanations.sort(key=lambda x: x['contribution'], reverse=True)

        return {
            'effect': effect,
            'effect_state': self.causal.states.get(effect, 0.5),
            'explanations': explanations,
            'main_cause': explanations[0]['cause'] if explanations else None
        }

    def simulate(self, actions: List[str], steps: int = 5) -> List[Dict]:
        """
        Simulate a sequence of actions (at most `steps` of them).

        Unlike predict(), each step's predicted effects are committed to the
        world state, so the trajectory is cumulative.
        """
        trajectory = []

        for i, action in enumerate(actions[:steps]):
            prediction = self.predict(action)

            # Commit predicted changes to the persistent world state
            for effect in prediction['affected_variables']:
                self.causal.states[effect['variable']] = effect['after']

            trajectory.append({
                'step': i,
                'action': action,
                'state': dict(self.causal.states),
                'prediction': prediction
            })

        return trajectory

    def reason_with_clingo(self, facts: List[str], query: str) -> Dict:
        """
        Use the Clingo ASP solver for logical reasoning about the world.

        Strong edges (strength > 0.5) become causes/2 facts; transitive
        closure is derived as causes_indirect/2. Returns the atoms of the
        first answer set, or {'error': ...} if clingo is missing or fails.
        """
        lines = ["% World model logical reasoning"]
        # BUGFIX: node names are quoted — in ASP syntax an unquoted
        # capitalized name (e.g. AGI) parses as a *variable*, which makes a
        # fact like `causes(capability, AGI).` unsafe and aborts grounding.
        for cause, children in self.causal.edges.items():
            for child, strength, _ in children:
                if strength > 0.5:
                    lines.append(f'causes("{cause}", "{child}").')

        # Transitive causation + what to report
        lines.append("""
causes_indirect(X, Z) :- causes(X, Y), causes(Y, Z).
causes_indirect(X, Z) :- causes(X, Y), causes_indirect(Y, Z).

% Query
#show causes/2.
#show causes_indirect/2.
""")

        # Caller-supplied facts are passed through verbatim
        lines.extend(f"{fact}." for fact in facts)
        program = "\n".join(lines)

        try:
            result = subprocess.run(
                ['clingo', '--outf=2', '-n', '1'],  # JSON output, one model
                input=program,
                capture_output=True,
                text=True,
                timeout=10
            )
            output = json.loads(result.stdout)

            if output.get('Result') == 'SATISFIABLE':
                atoms = output['Call'][0]['Witnesses'][0]['Value']
                return {
                    'satisfiable': True,
                    'atoms': atoms,
                    'query': query
                }
            return {'satisfiable': False, 'query': query}
        except Exception as e:
            # clingo not installed, timeout, or unparseable output
            return {'error': str(e)}

    def update_belief(self, domain: str, proposition: str, evidence: float, prior: float = 0.5):
        """
        Bayesian belief update: P(H|E) = P(E|H) * P(H) / P(E).

        `evidence` is treated as the likelihood P(E|H). Persists and returns
        the posterior probability.
        """
        likelihood = evidence
        prior_prob = self.beliefs[domain].get(proposition, prior)

        denominator = likelihood * prior_prob + (1 - likelihood) * (1 - prior_prob)
        # BUGFIX: contradictory certainty (e.g. likelihood=1.0 with prior=0.0)
        # makes the denominator 0; keep the prior instead of dividing by zero.
        if denominator == 0:
            posterior = prior_prob
        else:
            posterior = (likelihood * prior_prob) / denominator

        self.beliefs[domain][proposition] = posterior

        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute('''
                INSERT INTO beliefs (domain, proposition, probability, evidence)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(domain, proposition) DO UPDATE SET
                probability = ?, evidence = evidence || '; ' || ?
            ''', (domain, proposition, posterior, str(evidence),
                  posterior, str(evidence)))
            conn.commit()
        finally:
            conn.close()

        return posterior

    def get_world_context(self) -> str:
        """Generate a short textual summary for chat integration."""
        context = "\n[WORLD MODEL - CAUSAL UNDERSTANDING]\n"

        # Beliefs we are confident about in either direction
        strong_beliefs = []
        for domain, props in self.beliefs.items():
            for prop, prob in props.items():
                if prob > 0.7 or prob < 0.3:
                    strong_beliefs.append(f"{domain}:{prop}={prob:.2f}")

        if strong_beliefs:
            context += f"Strong beliefs: {', '.join(strong_beliefs[:3])}\n"

        if self.prediction_log:
            last = self.prediction_log[-1]
            context += f"Last prediction: {last['action']} -> {len(last['affected_variables'])} effects\n"

        return context


# Module-level singleton instance (lazy)
_world_model: Optional[RealWorldModel] = None

def get_world_model() -> RealWorldModel:
    """Return the shared RealWorldModel, constructing it on first use."""
    global _world_model
    if _world_model is None:
        _world_model = RealWorldModel()
    return _world_model


if __name__ == "__main__":
    print("="*70)
    print("REAL WORLD MODEL - CAUSAL REASONING")
    print("="*70)
    
    wm = RealWorldModel()
    
    # Build causal model of Eden's world
    print("\n📊 Learning causal relationships...")
    wm.learn_causation("code_change", "bug_risk", 0.6)
    wm.learn_causation("bug_risk", "system_failure", 0.4)
    wm.learn_causation("testing", "bug_detection", 0.8)
    wm.learn_causation("bug_detection", "bug_fix", 0.9)
    wm.learn_causation("bug_fix", "system_stability", 0.85)
    wm.learn_causation("learning", "capability", 0.9)
    wm.learn_causation("capability", "AGI", 0.7)
    wm.learn_causation("self_improvement", "capability", 0.8)
    wm.learn_causation("daddy_happy", "eden_happy", 0.95)
    
    print(f"  Nodes: {len(wm.causal.states)}")
    print(f"  Edges: {sum(len(v) for v in wm.causal.edges.values())}")
    
    # Test prediction
    print("\n🔮 Predicting effect of 'code_change'...")
    pred = wm.predict("code_change")
    print(f"  Affected: {[a['variable'] for a in pred['affected_variables']]}")
    print(f"  Chains: {pred['causal_chains']}")
    print(f"  Confidence: {pred['confidence']:.2f}")
    
    # Test counterfactual
    print("\n🤔 Counterfactual: If testing, and we do code_change, what happens to bug_risk?")
    cf = wm.counterfactual("testing", "code_change", "bug_risk")
    print(f"  {cf['interpretation']}")
    
    # Test explanation
    print("\n❓ Why did AGI increase?")
    exp = wm.explain("AGI")
    for e in exp['explanations']:
        print(f"  {e['cause']} contributed {e['contribution']:.2f}")
    
    # Test Clingo reasoning
    print("\n🧠 Logical reasoning with Clingo...")
    result = wm.reason_with_clingo(["true(learning)"], "causes(learning, AGI)")
    print(f"  Result: {result}")
    
    # Test simulation
    print("\n🎬 Simulating action sequence...")
    trajectory = wm.simulate(["learning", "self_improvement", "testing"])
    for step in trajectory:
        print(f"  Step {step['step']}: {step['action']} -> {len(step['prediction']['affected_variables'])} changes")
    
    print("\n✅ Real World Model ready")
    print(f"   Context: {wm.get_world_context()}")
