"""
Causal Reasoning Scaffold
Hypothesis generation, intervention design, and causal inference
"""

import networkx as nx
import numpy as np
from typing import List, Dict, Optional, Any
import yaml
import logging
import pickle
from pathlib import Path
from datetime import datetime

# Module-level logger named after this module; no handlers or levels are configured here.
logger = logging.getLogger(__name__)

class CausalScaffold:
    """Builds and maintains causal models of the world.

    State kept by an instance:

    * ``causal_graph`` -- ``networkx.DiGraph`` whose nodes are variable names.
      Edges are hypothesised causal links carrying a ``confidence`` attribute
      and an ``evidence`` tag (``'correlation'`` or ``'intervention'``). New
      edges are only added when the graph stays acyclic.
    * ``observations`` -- list of ``{'timestamp', 'variables', 'context'}`` dicts.
    * ``interventions`` -- list of recorded intervention experiments.

    The model is persisted with ``pickle`` to
    ``/Eden/MEMORY/causal_models/current_model.pkl``.
    """

    # Fallbacks for any key missing from the ``causal`` section of the config.
    _DEFAULT_CAUSAL_CONFIG = {
        'intervention_confidence_min': 0.7,
        'max_graph_size': 500,
        'counterfactual_samples': 100,
    }

    def __init__(self, config_path: str = "/Eden/CONFIG/phi_fractal_config.yaml"):
        """Load configuration and any previously saved causal model.

        Args:
            config_path: YAML file whose optional ``causal`` section may set
                ``intervention_confidence_min``, ``max_graph_size`` and
                ``counterfactual_samples``. Missing keys use the class defaults.

        Raises:
            OSError: If ``config_path`` cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
        """
        with open(config_path, encoding='utf-8') as f:
            # An empty YAML file loads as None; treat it as "no settings".
            config = yaml.safe_load(f) or {}

        # Merge user settings over the defaults so a partial ``causal`` section
        # does not raise KeyError below.
        self.config = {**self._DEFAULT_CAUSAL_CONFIG, **(config.get('causal') or {})}
        self.intervention_confidence_min = self.config['intervention_confidence_min']
        # NOTE: max_graph_size is read here but is not enforced anywhere in this class.
        self.max_graph_size = self.config['max_graph_size']
        self.counterfactual_samples = self.config['counterfactual_samples']

        # Causal graph (kept acyclic by the edge-adding methods)
        self.causal_graph = nx.DiGraph()

        # Observation history
        self.observations: List[Dict] = []

        # Intervention history
        self.interventions: List[Dict] = []

        # Paths
        self.model_path = Path("/Eden/MEMORY/causal_models")
        self.model_path.mkdir(parents=True, exist_ok=True)

        # Load existing model
        self._load_model()

        logger.info("CausalScaffold initialized")

    def _load_model(self):
        """Restore graph, observations and interventions from ``current_model.pkl``.

        If the file does not exist, the empty model is kept. Any load failure is
        logged and the empty model is kept.

        Security note: ``pickle.load`` can execute arbitrary code. The model file
        must only ever be written by this class (see ``_save_model``), never
        taken from an untrusted source.
        """
        model_file = self.model_path / "current_model.pkl"
        if not model_file.exists():
            return

        try:
            with open(model_file, 'rb') as f:
                data = pickle.load(f)
            self.causal_graph = data.get('graph', nx.DiGraph())
            self.observations = data.get('observations', [])
            self.interventions = data.get('interventions', [])
            logger.info(f"Loaded causal model: {len(self.causal_graph.nodes)} variables")
        except Exception as e:  # best-effort: a bad model file must not block startup
            logger.error(f"Failed to load causal model: {e}")

    def add_observation(self, variables: Dict[str, Any], context: Optional[str] = None):
        """
        Add an observed data point.

        New variable names become graph nodes (``type='observed'``). Every 20th
        observation triggers ``_update_causal_structure``.

        Args:
            variables: Dict of variable_name -> value (copied, so later changes
                to the caller's dict do not rewrite history)
            context: Optional description
        """
        now = datetime.now().isoformat()
        self.observations.append({
            'timestamp': now,
            'variables': dict(variables),
            'context': context,
        })

        # Add variables to graph if new
        for var_name in variables:
            if not self.causal_graph.has_node(var_name):
                self.causal_graph.add_node(var_name, type='observed', first_seen=now)

        # Periodically update causal structure
        if len(self.observations) % 20 == 0:
            self._update_causal_structure()

        logger.debug(f"Observation added: {len(variables)} variables")

    def generate_hypotheses(self,
                           target_variable: str,
                           candidate_causes: Optional[List[str]] = None) -> List[Dict]:
        """
        Generate causal hypotheses for a target variable.

        plausibility = 0.4 * |correlation| + 0.3 * temporal_order
                       + 0.3 * structural_plausibility

        Args:
            target_variable: Variable to explain (effect)
            candidate_causes: Optional list of potential causes. Defaults to every
                other node in the graph. Unknown names and the target itself are
                skipped.

        Returns:
            List of hypotheses sorted by plausibility (highest first); empty if
            the target is not in the graph.
        """
        if target_variable not in self.causal_graph:
            logger.warning(f"Target {target_variable} not in causal graph")
            return []

        if candidate_causes is None:
            candidate_causes = list(self.causal_graph.nodes())

        hypotheses = []
        for cause in candidate_causes:
            # A variable cannot cause itself in a DAG; unknown variables are skipped.
            if cause == target_variable or cause not in self.causal_graph:
                continue

            correlation = self._compute_correlation(cause, target_variable)
            temporal_order = 0.5  # Placeholder: no temporal-precedence evidence is used yet
            structural_plausibility = self._check_structural_plausibility(cause, target_variable)

            plausibility = (
                0.4 * correlation +
                0.3 * temporal_order +
                0.3 * structural_plausibility
            )

            hypotheses.append({
                'cause': cause,
                'effect': target_variable,
                'plausibility': plausibility,
                'correlation': correlation,
                'already_in_graph': self.causal_graph.has_edge(cause, target_variable),
                'suggested_intervention': f"Manipulate {cause}, observe {target_variable}"
            })

        hypotheses.sort(key=lambda h: h['plausibility'], reverse=True)

        logger.info(f"Generated {len(hypotheses)} hypotheses for {target_variable}")
        return hypotheses

    def _compute_correlation(self, var1: str, var2: str) -> float:
        """Return |Pearson correlation| of two variables over their joint observations.

        Only observations containing both variables are used. Values are
        converted with ``float()`` (booleans become 0.0/1.0), and pairs where
        either value is non-numeric or non-finite are skipped.

        Returns 0.0 when fewer than two usable pairs remain, or when either
        series is constant. Pearson correlation is undefined in that case, and
        ``np.corrcoef`` would otherwise return NaN.
        """
        values1 = []
        values2 = []
        for obs in self.observations:
            variables = obs['variables']
            if var1 not in variables or var2 not in variables:
                continue
            try:
                v1 = float(variables[var1])
                v2 = float(variables[var2])
            except (TypeError, ValueError, OverflowError):
                continue  # non-numeric value: skip this pair
            if np.isfinite(v1) and np.isfinite(v2):
                values1.append(v1)
                values2.append(v2)

        if len(values1) < 2:
            return 0.0

        a = np.asarray(values1)
        b = np.asarray(values2)
        if a.min() == a.max() or b.min() == b.max():
            return 0.0  # zero variance: correlation undefined

        with np.errstate(invalid='ignore', divide='ignore'):
            corr = float(np.corrcoef(a, b)[0, 1])
        # Guard against residual NaN (e.g. variance underflow) so callers can sort safely.
        return abs(corr) if np.isfinite(corr) else 0.0

    def _edge_keeps_dag(self, cause: str, effect: str) -> bool:
        """Return True if adding ``cause -> effect`` leaves the graph acyclic.

        Works on a copy, so the live graph is not modified. Also returns False
        when the current graph already contains a cycle.
        """
        temp_graph = self.causal_graph.copy()
        temp_graph.add_edge(cause, effect)
        return nx.is_directed_acyclic_graph(temp_graph)

    def _check_structural_plausibility(self, cause: str, effect: str) -> float:
        """Return 1.0 if adding this edge keeps the graph acyclic, else 0.0."""
        return 1.0 if self._edge_keeps_dag(cause, effect) else 0.0

    def design_intervention(self, hypothesis: Dict) -> Dict:
        """
        Design an intervention to test a causal hypothesis.

        Args:
            hypothesis: Hypothesis dict from generate_hypotheses() (needs
                'cause' and 'effect')

        Returns:
            Intervention plan with HIGH/LOW manipulation steps, the configured
            minimum confidence, and a safety assessment
        """
        cause = hypothesis['cause']
        effect = hypothesis['effect']

        return {
            'type': 'controlled_intervention',
            'hypothesis': hypothesis,
            'cause_variable': cause,
            'effect_variable': effect,
            'steps': [
                {
                    'action': f"Set {cause} to HIGH",
                    'observe': [effect],
                    'expected': f"{effect} should increase if causal"
                },
                {
                    'action': f"Set {cause} to LOW",
                    'observe': [effect],
                    'expected': f"{effect} should decrease if causal"
                },
                {
                    'action': "Compare outcomes",
                    'observe': [],
                    'expected': "Determine if causal link exists"
                }
            ],
            'confidence_required': self.intervention_confidence_min,
            'safety_check': self._assess_intervention_safety(cause, effect)
        }

    def _assess_intervention_safety(self, cause: str, effect: str) -> Dict:
        """Rate the risk of manipulating ``cause`` by how many graph descendants it has.

        Fewer than 5 descendants is 'low', fewer than 10 is 'medium', and
        anything else is 'high'. ``effect`` is currently unused.
        """
        descendants = set()
        if self.causal_graph.has_node(cause):
            descendants = nx.descendants(self.causal_graph, cause)

        n_affected = len(descendants)
        if n_affected < 5:
            risk_level = 'low'
        elif n_affected < 10:
            risk_level = 'medium'
        else:
            risk_level = 'high'

        return {
            'safe': n_affected < 10,
            'affected_variables': list(descendants),
            'risk_level': risk_level,
            'recommendation': 'Safe to proceed' if n_affected < 10 else 'Proceed with caution'
        }

    def simulate_counterfactual(self,
                               target_variable: str,
                               intervention: Dict[str, Any],
                               num_samples: Optional[int] = None) -> List[Dict]:
        """
        Simulate counterfactual: "What if X had been different?"

        Each sample fixes the intervened variables. It then walks the graph in
        topological order:
        - root nodes are drawn from their observed values;
        - other nodes are drawn from observations that match their parents'
          sampled values.

        Args:
            target_variable: Variable to predict
            intervention: Dict of variable -> counterfactual value
            num_samples: Number of simulation samples (default: the configured
                counterfactual_samples)

        Returns:
            One dict per sample with the full 'sample' state, its
            'target_value' (None if never assigned) and the 'intervention'
        """
        if num_samples is None:
            num_samples = self.counterfactual_samples

        # The graph does not change while sampling, so order it once.
        try:
            topo_order = list(nx.topological_sort(self.causal_graph))
        except nx.NetworkXUnfeasible:
            # Cyclic graph: fall back to insertion order. Parents not yet
            # sampled are then left out of the conditioning.
            logger.warning("Causal graph contains a cycle; propagating in insertion order")
            topo_order = list(self.causal_graph.nodes())

        outcomes = []
        for _ in range(num_samples):
            state = intervention.copy()  # intervened values are held fixed

            for node in topo_order:
                if node in state:
                    continue

                parents = list(self.causal_graph.predecessors(node))
                if parents:
                    parent_values = {p: state.get(p) for p in parents if p in state}
                    state[node] = self._sample_conditional(node, parent_values)
                else:
                    state[node] = self._sample_from_observations(node)

            outcomes.append({
                'sample': state,
                'target_value': state.get(target_variable),
                'intervention': intervention
            })

        logger.info(f"Simulated {len(outcomes)} counterfactual outcomes")
        return outcomes

    def _sample_from_observations(self, variable: str) -> Any:
        """Draw one observed value of ``variable`` uniformly at random (None if never seen).

        Indexes the list directly rather than using ``np.random.choice``. That
        function converts values to a numpy array, which changes their types
        (e.g. mixed int/str become strings) and rejects list values.
        """
        values = [
            obs['variables'][variable]
            for obs in self.observations
            if variable in obs['variables']
        ]
        if not values:
            return None
        return values[np.random.randint(len(values))]

    def _sample_conditional(self, variable: str, parent_values: Dict) -> Any:
        """Draw ``variable`` from observations whose parent values match ``parent_values``.

        Parents whose value is None place no constraint. If no observation
        matches, fall back to the unconditional distribution.
        """
        matching = [
            obs['variables'][variable]
            for obs in self.observations
            if variable in obs['variables']
            and all(
                obs['variables'].get(p) == v
                for p, v in parent_values.items()
                if v is not None
            )
        ]
        if not matching:
            return self._sample_from_observations(variable)
        return matching[np.random.randint(len(matching))]

    def record_intervention(self,
                           intervention_plan: Dict,
                           outcome: Dict,
                           success: bool):
        """
        Record the result of an intervention experiment and update the graph.

        On success:
        - an existing edge gains +0.1 confidence (capped at 0.95);
        - otherwise a new edge is added at confidence 0.7, but only if it keeps
          the graph acyclic.

        On failure, an existing edge's confidence is halved, and the edge is
        removed if the result drops below 0.3.

        The model is saved to disk afterwards.
        """
        cause = intervention_plan['cause_variable']
        effect = intervention_plan['effect_variable']

        self.interventions.append({
            'timestamp': datetime.now().isoformat(),
            'hypothesis': intervention_plan['hypothesis'],
            'outcome': outcome,
            'success': success,
            'causal_link_confirmed': success
        })

        has_edge = self.causal_graph.has_edge(cause, effect)
        if success:
            if has_edge:
                edge = self.causal_graph[cause][effect]
                edge['confidence'] = min(0.95, edge.get('confidence', 0.5) + 0.1)
                logger.info(f"Causal link confirmed: {cause} → {effect}")
            elif self._edge_keeps_dag(cause, effect):
                self.causal_graph.add_edge(
                    cause,
                    effect,
                    confidence=0.7,
                    evidence='intervention',
                    confirmed=datetime.now().isoformat()
                )
                logger.info(f"Causal link confirmed: {cause} → {effect}")
            else:
                # simulate_counterfactual() relies on a topological order, so
                # the graph must stay acyclic.
                logger.warning(
                    f"Causal link {cause} → {effect} confirmed but not added: "
                    f"it would create a cycle"
                )
        elif has_edge:
            new_conf = self.causal_graph[cause][effect].get('confidence', 0.5) * 0.5
            if new_conf < 0.3:
                self.causal_graph.remove_edge(cause, effect)
                logger.info(f"Causal link rejected: {cause} ↛ {effect}")
            else:
                self.causal_graph[cause][effect]['confidence'] = new_conf

        self._save_model()

    def _update_causal_structure(self):
        """Add correlation-based edges between strongly correlated variables.

        Runs only once at least 10 observations exist. For each unordered pair
        whose |correlation| exceeds 0.7, it adds ``var1 -> var2`` (in node
        insertion order) at confidence 0.5. A pair is skipped if it is already
        linked in either direction, or if the edge would create a cycle.
        """
        if len(self.observations) < 10:
            return

        logger.info("Updating causal structure...")

        variables = list(self.causal_graph.nodes())
        for i, var1 in enumerate(variables):
            for var2 in variables[i + 1:]:
                corr = self._compute_correlation(var1, var2)
                already_linked = (self.causal_graph.has_edge(var1, var2)
                                  or self.causal_graph.has_edge(var2, var1))
                if corr > 0.7 and not already_linked and self._edge_keeps_dag(var1, var2):
                    self.causal_graph.add_edge(
                        var1,
                        var2,
                        confidence=0.5,
                        evidence='correlation'
                    )

    def _save_model(self):
        """Save the model to disk, keeping the last 1000 observations and 100 interventions.

        Writes to a temporary file and then renames it over
        ``current_model.pkl``. A crash mid-write therefore cannot leave a
        truncated model for ``_load_model`` to choke on.
        """
        model_file = self.model_path / "current_model.pkl"
        tmp_file = model_file.with_name(model_file.name + ".tmp")

        data = {
            'graph': self.causal_graph,
            'observations': self.observations[-1000:],
            'interventions': self.interventions[-100:],
            'last_updated': datetime.now().isoformat()
        }

        try:
            with open(tmp_file, 'wb') as f:
                pickle.dump(data, f)
            tmp_file.replace(model_file)
        except Exception:
            if tmp_file.exists():
                tmp_file.unlink()
            raise

    def get_metrics(self) -> Dict:
        """Get causal reasoning metrics.

        Returns:
            - 'variables', 'causal_links': graph node and edge counts;
            - 'avg_confidence': mean edge confidence (0.5 assumed for edges
              without one; 0.0 when there are no edges);
            - 'observations', 'interventions_performed': history lengths;
            - 'success_rate': fraction of successful interventions.

            Observation and intervention counts are reported even when the
            graph is empty.
        """
        confidences = [
            data.get('confidence', 0.5)
            for _, _, data in self.causal_graph.edges(data=True)
        ]
        successful = sum(1 for rec in self.interventions if rec.get('success', False))

        return {
            'variables': len(self.causal_graph.nodes),
            'causal_links': len(self.causal_graph.edges),
            'avg_confidence': float(np.mean(confidences)) if confidences else 0.0,
            'observations': len(self.observations),
            'interventions_performed': len(self.interventions),
            'success_rate': successful / len(self.interventions) if self.interventions else 0.0
        }
