"""
Reflective Monitor - Component 8
Self-observation, contradiction detection, and reasoning trace analysis
"""

import numpy as np
from typing import Dict, List, Optional, Any, Tuple
import yaml
import logging
import json
from pathlib import Path
from datetime import datetime
from collections import defaultdict, deque
import networkx as nx

logger = logging.getLogger(__name__)

class ReasoningTrace:
    """Captures one end-to-end reasoning chain for later inspection."""

    def __init__(self, trace_id: str, query: str):
        # Identity and input of this chain.
        self.trace_id = trace_id
        self.query = query
        self.timestamp = datetime.now().isoformat()
        # Per-step records, their confidences, contributing components,
        # and any contradictions detected along the way.
        self.steps: List[Dict] = []
        self.confidence_scores: List[float] = []
        self.components_used: List[str] = []
        self.contradictions: List[Dict] = []
        # Set by finalize(); None while the trace is still active.
        self.final_confidence: Optional[float] = None

    def add_step(self, component: str, action: str, output: Any, confidence: float):
        """Append one reasoning step and track the component that produced it."""
        record = {
            'component': component,
            'action': action,
            'output': str(output)[:200],  # Truncate for storage
            'confidence': confidence,
            'timestamp': datetime.now().isoformat(),
        }
        self.steps.append(record)
        self.confidence_scores.append(confidence)
        if component not in self.components_used:
            self.components_used.append(component)

    def add_contradiction(self, source1: str, source2: str, description: str):
        """Log a contradiction between two sources; causal ones rank as severe."""
        severity = 'high' if 'causal' in description.lower() else 'medium'
        self.contradictions.append({
            'source1': source1,
            'source2': source2,
            'description': description,
            'severity': severity,
        })

    def finalize(self):
        """Derive the chain's final confidence from its recorded steps."""
        if not self.confidence_scores:
            # No steps recorded: fall back to a neutral confidence.
            self.final_confidence = 0.5
            return
        # Each contradiction shaves 0.1 off the mean step confidence,
        # floored at zero.
        penalty = 0.1 * len(self.contradictions)
        self.final_confidence = max(0.0, np.mean(self.confidence_scores) - penalty)

    def to_dict(self) -> Dict:
        """Serialize the trace for storage or transport."""
        return {
            'trace_id': self.trace_id,
            'query': self.query,
            'timestamp': self.timestamp,
            'steps': self.steps,
            'components_used': self.components_used,
            'contradictions': self.contradictions,
            'final_confidence': self.final_confidence,
            'total_steps': len(self.steps),
        }

class ReflectiveMonitor:
    """
    Self-observation layer for Eden.

    Monitors all components, records reasoning traces, detects
    contradictions between component outputs and within knowledge graphs,
    and periodically assesses Eden's own reliability.
    """

    def __init__(self, config_path: str = "/Eden/CONFIG/phi_fractal_config.yaml"):
        """
        Args:
            config_path: Path to the YAML configuration file.  If it lacks
                a 'reflective_monitor' section, defaults are injected.
        """
        with open(config_path) as f:
            config = yaml.safe_load(f)

        # Inject defaults so configs written before this component existed
        # still load cleanly.
        if 'reflective_monitor' not in config:
            config['reflective_monitor'] = {
                'confidence_threshold': 0.6,
                'contradiction_threshold': 0.3,
                'trace_history_size': 1000,
                'self_assessment_interval': 100
            }

        self.config = config['reflective_monitor']
        self.confidence_threshold = self.config['confidence_threshold']
        self.contradiction_threshold = self.config['contradiction_threshold']
        self.trace_history_size = self.config['trace_history_size']
        self.assessment_interval = self.config['self_assessment_interval']

        # Active (un-finalized) reasoning traces, keyed by trace id.
        self.traces: Dict[str, ReasoningTrace] = {}
        # Bounded history of finalized trace dicts (oldest evicted first).
        self.trace_history: deque = deque(maxlen=self.trace_history_size)

        # Monotonic counter appended to trace ids so two traces started
        # within the same clock tick still receive distinct ids.
        self._trace_counter = 0

        # Per-component performance statistics.
        self.component_performance: Dict[str, Dict] = defaultdict(lambda: {
            'calls': 0,
            'avg_confidence': 0.0,
            'contradictions': 0,
            'corrections': 0
        })

        # Correction history (when a human corrects Eden).
        self.corrections: List[Dict] = []

        # Rolling self-assessment metrics, updated periodically.
        self.self_assessment: Dict = {
            'overall_confidence': 0.7,
            'reliability_score': 0.8,
            'contradiction_rate': 0.0,
            'correction_rate': 0.0,
            'last_assessment': None
        }

        # Finalized traces since the last self-assessment.
        self.decisions_since_assessment = 0

        # Persistence location for reflection state.
        self.reflection_path = Path("/Eden/MEMORY/reflection")
        self.reflection_path.mkdir(parents=True, exist_ok=True)

        # Restore any previously saved state.
        self._load_state()

        logger.info("ReflectiveMonitor initialized - Eden can now observe herself")

    def _load_state(self):
        """Load previous reflection state from disk, if present."""
        state_file = self.reflection_path / "reflection_state.json"

        if state_file.exists():
            try:
                with open(state_file) as f:
                    state = json.load(f)
                self.self_assessment = state.get('self_assessment', self.self_assessment)
                # Re-wrap in a defaultdict so unseen components still get
                # zeroed stat records on first access.
                self.component_performance = defaultdict(lambda: {
                    'calls': 0, 'avg_confidence': 0.0,
                    'contradictions': 0, 'corrections': 0
                }, state.get('component_performance', {}))
                logger.info("Loaded reflection state")
            except Exception as e:
                # Corrupt/unreadable state is non-fatal: start fresh.
                logger.error(f"Failed to load reflection state: {e}")

    def start_trace(self, query: str) -> str:
        """
        Begin monitoring a reasoning chain

        Args:
            query: Input query being processed

        Returns:
            Trace ID for this reasoning chain
        """
        # Counter suffix prevents id collisions when two traces start
        # within the same timestamp resolution.
        self._trace_counter += 1
        trace_id = f"trace_{datetime.now().timestamp()}_{self._trace_counter}"
        trace = ReasoningTrace(trace_id, query)
        self.traces[trace_id] = trace

        logger.debug(f"Started reasoning trace: {trace_id}")
        return trace_id

    def record_step(self,
                   trace_id: str,
                   component: str,
                   action: str,
                   output: Any,
                   confidence: float):
        """
        Record a single step in the reasoning chain

        Args:
            trace_id: ID of active trace
            component: Which component performed this step
            action: What action was taken
            output: Result of the action
            confidence: Component's confidence in this output
        """
        if trace_id not in self.traces:
            # Unknown trace: log and drop rather than raising mid-reasoning.
            logger.warning(f"Unknown trace_id: {trace_id}")
            return

        trace = self.traces[trace_id]
        trace.add_step(component, action, output, confidence)

        # Update the component's running-average confidence.
        perf = self.component_performance[component]
        perf['calls'] += 1
        perf['avg_confidence'] = (
            (perf['avg_confidence'] * (perf['calls'] - 1) + confidence) / perf['calls']
        )

    def detect_contradictions(self,
                            trace_id: str,
                            knowledge_graph: Optional[nx.DiGraph] = None,
                            component_outputs: Optional[Dict] = None) -> List[Dict]:
        """
        Detect contradictions in reasoning or beliefs

        Args:
            trace_id: Current reasoning trace
            knowledge_graph: Knowledge graph to check
            component_outputs: Outputs from different components

        Returns:
            List of detected contradictions
        """
        contradictions = []

        if trace_id not in self.traces:
            return contradictions

        trace = self.traces[trace_id]

        # Pairwise check of component outputs for contradictory language.
        if component_outputs:
            outputs_list = list(component_outputs.items())
            for i in range(len(outputs_list)):
                for j in range(i + 1, len(outputs_list)):
                    comp1, out1 = outputs_list[i]
                    comp2, out2 = outputs_list[j]

                    if self._outputs_contradict(out1, out2):
                        contradiction = {
                            'source1': comp1,
                            'source2': comp2,
                            'description': f"{comp1} and {comp2} produced contradictory outputs"
                        }
                        contradictions.append(contradiction)
                        trace.add_contradiction(comp1, comp2, contradiction['description'])

                        # Both components share the blame in the stats.
                        self.component_performance[comp1]['contradictions'] += 1
                        self.component_performance[comp2]['contradictions'] += 1

        # Check knowledge graph for logical contradictions.
        # Explicit None check: an empty nx graph is falsy (len == 0), but
        # "no graph supplied" is the condition we actually mean here.
        if knowledge_graph is not None:
            kg_contradictions = self._find_graph_contradictions(knowledge_graph)
            for contradiction in kg_contradictions:
                trace.add_contradiction(
                    contradiction['node1'],
                    contradiction['node2'],
                    contradiction['description']
                )
                contradictions.append(contradiction)

        return contradictions

    def _outputs_contradict(self, output1: Any, output2: Any) -> bool:
        """Simple keyword heuristic to detect contradictory outputs."""
        # Compare lowercase string renderings.
        str1 = str(output1).lower()
        str2 = str(output2).lower()

        # Antonym pairs; a pair appearing split across the two outputs
        # (in either order) counts as a contradiction.
        contradiction_pairs = [
            ('yes', 'no'),
            ('true', 'false'),
            ('possible', 'impossible'),
            ('likely', 'unlikely'),
            ('safe', 'dangerous')
        ]

        for word1, word2 in contradiction_pairs:
            if word1 in str1 and word2 in str2:
                return True
            if word2 in str1 and word1 in str2:
                return True

        return False

    def _find_graph_contradictions(self, graph: nx.DiGraph) -> List[Dict]:
        """Find logical contradictions in a knowledge graph."""
        contradictions = []

        # Two-node cycles (A causes B, B causes A) are flagged as circular
        # causation.  Cycle enumeration failures are logged, not swallowed
        # silently (was a bare `except: pass`).
        try:
            cycles = list(nx.simple_cycles(graph))
        except Exception as exc:
            logger.debug(f"Cycle detection failed: {exc}")
            cycles = []
        for cycle in cycles:
            if len(cycle) == 2:
                contradictions.append({
                    'node1': cycle[0],
                    'node2': cycle[1],
                    'description': f'Circular causation detected between {cycle[0]} and {cycle[1]}'
                })

        # Opposing edge pair: u CAUSES v while v PREVENTS u.
        for u, v, data in graph.edges(data=True):
            if graph.has_edge(v, u):
                rel1 = data.get('relation_type', 'RELATED')
                rel2 = graph[v][u].get('relation_type', 'RELATED')

                if rel1 == 'CAUSES' and rel2 == 'PREVENTS':
                    contradictions.append({
                        'node1': u,
                        'node2': v,
                        'description': f'{u} both causes and prevents {v}'
                    })

        return contradictions

    def finalize_trace(self, trace_id: str) -> Dict:
        """
        Complete a reasoning trace and compute final confidence

        Args:
            trace_id: ID of trace to finalize

        Returns:
            Trace summary with confidence scores
        """
        if trace_id not in self.traces:
            return {'error': 'Unknown trace_id'}

        trace = self.traces[trace_id]
        trace.finalize()

        # Store in bounded history.
        self.trace_history.append(trace.to_dict())

        # Cap active traces in memory; lexicographic min of the id strings
        # approximates the oldest trace (ids are timestamp-prefixed).
        if len(self.traces) > 100:
            oldest = min(self.traces.keys())
            del self.traces[oldest]

        # Trigger periodic self-assessment.
        self.decisions_since_assessment += 1
        if self.decisions_since_assessment >= self.assessment_interval:
            self._perform_self_assessment()

        return trace.to_dict()

    def record_correction(self,
                         trace_id: str,
                         component: str,
                         user_feedback: str,
                         correct_answer: str):
        """
        Record when human corrects Eden's output

        Args:
            trace_id: Which reasoning trace had the error
            component: Which component was corrected
            user_feedback: Human's feedback
            correct_answer: What the correct output should have been
        """
        correction = {
            'timestamp': datetime.now().isoformat(),
            'trace_id': trace_id,
            'component': component,
            'feedback': user_feedback,
            'correct_answer': correct_answer
        }

        self.corrections.append(correction)
        self.component_performance[component]['corrections'] += 1

        # Keep the correction rate current without waiting for the next
        # periodic self-assessment.
        total_decisions = sum(p['calls'] for p in self.component_performance.values())
        if total_decisions > 0:
            self.self_assessment['correction_rate'] = len(self.corrections) / total_decisions

        logger.info(f"Correction recorded for {component}")

    def _perform_self_assessment(self):
        """
        Periodic self-assessment of Eden's performance.

        Recomputes overall confidence, contradiction/correction rates, and
        a combined reliability score, then persists state to disk.
        """
        logger.info("Performing self-assessment...")

        # Overall confidence = mean final confidence of recent traces.
        # Explicit None check: a final confidence of 0.0 is a legitimate
        # (if grim) data point and must not be filtered out as falsy.
        if self.trace_history:
            recent_traces = list(self.trace_history)[-100:]
            confidences = [t['final_confidence'] for t in recent_traces
                           if t['final_confidence'] is not None]
            if confidences:
                self.self_assessment['overall_confidence'] = np.mean(confidences)

        # Contradiction and correction rates per decision.
        total_decisions = sum(p['calls'] for p in self.component_performance.values())
        total_contradictions = sum(p['contradictions'] for p in self.component_performance.values())

        if total_decisions > 0:
            self.self_assessment['contradiction_rate'] = total_contradictions / total_decisions
            self.self_assessment['correction_rate'] = len(self.corrections) / total_decisions

        # Reliability: high confidence + low contradictions + low
        # corrections, clamped to [0, 1].
        reliability = (
            self.self_assessment['overall_confidence'] * 0.5 +
            (1 - self.self_assessment['contradiction_rate']) * 0.3 +
            (1 - self.self_assessment['correction_rate']) * 0.2
        )
        self.self_assessment['reliability_score'] = max(0.0, min(1.0, reliability))
        self.self_assessment['last_assessment'] = datetime.now().isoformat()

        # Reset counter and persist.
        self.decisions_since_assessment = 0
        self._save_state()

        logger.info(f"Self-assessment complete. Reliability: {reliability:.2f}")

    def get_reasoning_explanation(self, trace_id: str) -> Dict:
        """
        Explain Eden's reasoning for a specific decision

        Args:
            trace_id: Which decision to explain

        Returns:
            Human-readable explanation of reasoning chain
        """
        # Look in active traces first, then the finalized history.
        if trace_id in self.traces:
            trace = self.traces[trace_id]
            trace_dict = trace.to_dict()
        else:
            matches = [t for t in self.trace_history if t['trace_id'] == trace_id]
            if not matches:
                return {'error': 'Trace not found'}
            trace_dict = matches[0]

        explanation = {
            'query': trace_dict['query'],
            'confidence': trace_dict['final_confidence'],
            'reasoning_chain': [],
            'contradictions': trace_dict['contradictions'],
            'components_involved': trace_dict['components_used']
        }

        for step in trace_dict['steps']:
            explanation['reasoning_chain'].append({
                'step': f"{step['component']} performed {step['action']}",
                'confidence': step['confidence'],
                'result': step['output'][:100]  # Truncate
            })

        return explanation

    def should_eden_be_confident(self) -> bool:
        """
        Should Eden express high confidence based on self-assessment?
        """
        return self.self_assessment['reliability_score'] > 0.7

    def get_weak_components(self) -> List[Tuple[str, float]]:
        """
        Identify which components need improvement

        Returns:
            List of (component_name, performance_score) for weak components,
            sorted weakest first.
        """
        weak_components = []

        for component, perf in self.component_performance.items():
            if perf['calls'] < 10:
                continue  # Not enough data

            # Net score: mean confidence minus weighted contradiction and
            # correction rates.  (The previous formula scaled confidence by
            # 0.5, capping scores at 0.5, so the 0.6 threshold flagged
            # every component regardless of actual performance.)
            contradiction_rate = perf['contradictions'] / perf['calls']
            correction_rate = perf['corrections'] / perf['calls']
            score = (
                perf['avg_confidence'] -
                contradiction_rate * 0.3 -
                correction_rate * 0.2
            )

            if score < 0.6:  # Weak performance
                weak_components.append((component, score))

        weak_components.sort(key=lambda x: x[1])
        return weak_components

    def _save_state(self):
        """Save reflection state and recent traces to disk."""
        state_file = self.reflection_path / "reflection_state.json"

        state = {
            'self_assessment': self.self_assessment,
            # Plain dict: defaultdict's factory is not JSON-serializable.
            'component_performance': dict(self.component_performance),
            'last_updated': datetime.now().isoformat()
        }

        with open(state_file, 'w') as f:
            json.dump(state, f, indent=2)

        # Save the most recent traces for post-mortem inspection.
        traces_file = self.reflection_path / "recent_traces.json"
        with open(traces_file, 'w') as f:
            json.dump(list(self.trace_history)[-100:], f, indent=2)

    def get_metrics(self) -> Dict:
        """Get a snapshot of all reflection metrics."""
        return {
            'self_assessment': self.self_assessment,
            'component_performance': dict(self.component_performance),
            'total_traces': len(self.trace_history),
            'active_traces': len(self.traces),
            'total_corrections': len(self.corrections),
            'weak_components': self.get_weak_components()
        }
