"""
Curiosity Driver
Active exploration and intrinsic motivation system
"""

import numpy as np
from typing import List, Dict, Tuple, Optional
import yaml
import logging
import json
from pathlib import Path
from datetime import datetime, timedelta
from collections import deque
import networkx as nx

# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)

class CuriosityDriver:
    """Drives active exploration through intrinsic motivation.

    A topic's curiosity score is a weighted blend of three signals in [0, 1]:

    * uncertainty - one minus the confidence recorded for the topic (0.8 for
      unseen topics), reduced when the knowledge graph already holds a
      semantically similar node;
    * novelty - decays with the number of recorded visits and with semantic
      similarity to the most recently explored topics;
    * progress - the slope of recently recorded learning gains.

    Explorations are limited by a per-calendar-day budget, and per-topic
    statistics are persisted as JSON under ``log_path``.
    """

    # Used for any key missing from the config file's 'curiosity' section.
    _DEFAULT_CONFIG = {
        'uncertainty_weight': 0.4,
        'novelty_weight': 0.4,
        'progress_weight': 0.2,
        'exploration_budget': 100,
    }
    _ENCODER_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
    _STATE_FILENAME = "curiosity_state.json"
    _STOPWORDS = frozenset({
        'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
        'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were',
    })

    def __init__(self, config_path: str = "/Eden/CONFIG/phi_fractal_config.yaml"):
        """Load settings, prepare bookkeeping and restore persisted state.

        Args:
            config_path: YAML file whose optional ``curiosity`` section may
                override any key of ``_DEFAULT_CONFIG``.
        """
        with open(config_path, encoding="utf-8") as f:
            # An empty YAML document loads as None; treat it as "no settings".
            config = yaml.safe_load(f) or {}

        # Merge over defaults so a partial 'curiosity' section cannot raise KeyError.
        self.config = {**self._DEFAULT_CONFIG, **(config.get('curiosity') or {})}
        self.uncertainty_weight = self.config['uncertainty_weight']
        self.novelty_weight = self.config['novelty_weight']
        self.progress_weight = self.config['progress_weight']
        self.exploration_budget = self.config['exploration_budget']

        # State tracking
        self.explored_regions: Dict[str, Dict] = {}
        self.exploration_history = deque(maxlen=1000)
        self.learning_progress = deque(maxlen=100)

        # Sentence encoder is loaded on first use and then reused (see _get_encoder).
        self._encoder = None
        self._encoder_failed = False

        # Daily exploration count; _load_state may restore today's value.
        self.today_explorations = 0
        self.last_reset_date = datetime.now().date()

        # Paths
        self.log_path = Path("/Eden/MEMORY/curiosity_logs")
        self.log_path.mkdir(parents=True, exist_ok=True)

        self._load_state()

        logger.info("CuriosityDriver initialized")

    # ------------------------------------------------------------------ #
    # Persistence
    # ------------------------------------------------------------------ #

    def _load_state(self):
        """Restore explored regions (and today's budget usage) from disk.

        A missing, unreadable or malformed state file is logged and ignored,
        leaving the in-memory state untouched.
        """
        state_file = self.log_path / self._STATE_FILENAME
        if not state_file.exists():
            return

        try:
            with open(state_file, encoding="utf-8") as f:
                state = json.load(f)
        except (OSError, ValueError) as e:
            logger.error("Failed to load curiosity state: %s", e)
            return

        if not isinstance(state, dict):
            logger.error("Failed to load curiosity state: unexpected top-level type %s",
                         type(state).__name__)
            return

        regions = state.get('explored_regions', {})
        if isinstance(regions, dict):
            self.explored_regions = regions
            logger.info("Loaded %d explored regions", len(self.explored_regions))

        # Only carry the budget counter over if it was saved today.
        saved_count = state.get('today_explorations')
        if (state.get('budget_date') == self.last_reset_date.isoformat()
                and isinstance(saved_count, int)):
            self.today_explorations = saved_count

    def _save_state(self):
        """Write exploration state to disk atomically.

        The JSON is written to a temporary sibling file and then moved over
        the real file, so an interrupted write cannot leave a truncated
        state file behind. I/O errors propagate to the caller.
        """
        state_file = self.log_path / self._STATE_FILENAME
        tmp_file = state_file.with_name(state_file.name + ".tmp")

        state = {
            'explored_regions': self.explored_regions,
            'last_updated': datetime.now().isoformat(),
            # In-session count; exploration_history is capped at 1000 entries.
            'total_explorations': len(self.exploration_history),
            'today_explorations': self.today_explorations,
            'budget_date': self.last_reset_date.isoformat(),
        }

        with open(tmp_file, 'w', encoding="utf-8") as f:
            json.dump(state, f, indent=2)
        # Path.replace uses os.replace: atomic on POSIX within one filesystem.
        tmp_file.replace(state_file)

    # ------------------------------------------------------------------ #
    # Small helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _topic_key(topic) -> str:
        """Normalise a topic (any object, e.g. an int node id) to its dict key."""
        return str(topic).lower().strip()

    @staticmethod
    def _node_label(graph, node) -> str:
        """Return a node's 'label' attribute (or the node id) as a string."""
        return str(graph.nodes[node].get('label', node))

    @staticmethod
    def _cosine(a, b) -> float:
        """Cosine similarity of two vectors; 0.0 if either has zero norm."""
        denom = float(np.linalg.norm(a) * np.linalg.norm(b))
        if denom == 0.0:
            return 0.0
        return float(np.dot(a, b) / denom)

    def _get_encoder(self):
        """Return the cached sentence encoder, loading it on first use.

        Returns None when sentence-transformers cannot be imported or the
        model cannot be loaded; that failure is remembered so later calls
        skip the expensive retry and semantic signals are simply disabled.
        """
        if getattr(self, '_encoder_failed', False):
            return None
        encoder = getattr(self, '_encoder', None)
        if encoder is not None:
            return encoder
        try:
            from sentence_transformers import SentenceTransformer
            encoder = SentenceTransformer(self._ENCODER_MODEL)
        except Exception as e:  # optional dependency / model download: best effort
            logger.warning("Sentence encoder unavailable, semantic signals disabled: %s", e)
            self._encoder_failed = True
            return None
        self._encoder = encoder
        return encoder

    def _reset_daily_counter_if_needed(self) -> None:
        """Zero the exploration counter when the calendar day has changed."""
        today = datetime.now().date()
        if today > self.last_reset_date:
            self.today_explorations = 0
            self.last_reset_date = today

    # ------------------------------------------------------------------ #
    # Curiosity scoring
    # ------------------------------------------------------------------ #

    def compute_curiosity(self,
                          query: str,
                          context: Dict,
                          knowledge_graph: Optional["nx.DiGraph"] = None) -> float:
        """
        Compute curiosity score for a query/topic

        Args:
            query: Topic or question to evaluate
            context: Additional context (currently unused by the estimators)
            knowledge_graph: Current knowledge graph; nodes may carry an
                'embedding' attribute used for semantic coverage

        Returns:
            Curiosity score (0-1, higher = more curious)
        """
        uncertainty = self._estimate_uncertainty(query, context, knowledge_graph)
        novelty = self._estimate_novelty(query, context)
        progress = self._estimate_progress(query, context)

        curiosity = (
            self.uncertainty_weight * uncertainty +
            self.novelty_weight * novelty +
            self.progress_weight * progress
        )

        return float(np.clip(curiosity, 0.0, 1.0))

    def _estimate_uncertainty(self,
                              query: str,
                              context: Dict,
                              knowledge_graph: Optional["nx.DiGraph"]) -> float:
        """Estimate uncertainty about a topic, in [0, 1].

        Starts from ``1 - recorded confidence`` (0.8 for unseen topics) and
        scales it down by up to half according to the best cosine similarity
        between the query and any graph node carrying an 'embedding'.
        """
        region = self.explored_regions.get(self._topic_key(query))
        if region is not None:
            uncertainty = 1.0 - region.get('confidence', 0.5)
        else:
            uncertainty = 0.8  # novel topic - high uncertainty

        if knowledge_graph is not None and len(knowledge_graph) > 0:
            encoder = self._get_encoder()
            if encoder is not None:
                try:
                    query_emb = encoder.encode(query)
                    max_similarity = 0.0
                    for _, attrs in knowledge_graph.nodes(data=True):
                        emb = attrs.get('embedding')
                        if emb is None:
                            continue
                        sim = self._cosine(query_emb, np.asarray(emb, dtype=float))
                        max_similarity = max(max_similarity, sim)
                    # High similarity = low uncertainty
                    uncertainty *= (1.0 - max_similarity * 0.5)
                except Exception as e:  # e.g. embedding dimension mismatch
                    logger.debug("Graph coverage check failed for %r: %s", query, e)

        return float(np.clip(uncertainty, 0.0, 1.0))

    def _estimate_novelty(self, query: str, context: Dict) -> float:
        """Estimate novelty of a topic, in [0, 1].

        Decays as ``1 / (1 + log1p(visits))`` for known topics and is further
        reduced by up to 70% according to semantic similarity with the last
        20 explored topics.
        """
        region = self.explored_regions.get(self._topic_key(query))
        if region is not None:
            visits = region.get('visits', 0)
            novelty = 1.0 / (1.0 + np.log1p(visits))
        else:
            novelty = 1.0  # completely new

        if self.exploration_history:
            encoder = self._get_encoder()
            if encoder is not None:
                try:
                    recent_topics = [h['topic'] for h in list(self.exploration_history)[-20:]]
                    query_emb = encoder.encode(query)
                    recent_embs = encoder.encode(recent_topics)
                    max_recent_similarity = max(
                        self._cosine(query_emb, emb) for emb in recent_embs
                    )
                    novelty *= (1.0 - max_recent_similarity * 0.7)
                except Exception as e:
                    logger.debug("Semantic novelty check failed for %r: %s", query, e)

        return float(np.clip(novelty, 0.0, 1.0))

    def _estimate_progress(self, query: str, context: Dict) -> float:
        """Estimate learning progress from the trend of recent gains, in [0, 1].

        Fits an ordinary least-squares line to the last (up to) 10 recorded
        learning gains and maps its slope through ``0.5 + tanh(2*slope)/2``;
        0.5 means "no trend / not enough data" (fewer than 3 samples).
        """
        if len(self.learning_progress) < 3:
            return 0.5  # Unknown

        y = np.asarray(list(self.learning_progress)[-10:], dtype=float)

        if np.std(y) < 0.01:
            slope = 0.0
        else:
            # OLS slope = sum(xc*yc) / sum(xc*xc); both sums use the same
            # normalisation (at least 3 samples, so sum(xc*xc) > 0).
            xc = np.arange(len(y), dtype=float)
            xc -= xc.mean()
            slope = float(np.dot(xc, y - y.mean()) / np.dot(xc, xc))

        progress = 0.5 + np.tanh(slope * 2) * 0.5
        return float(np.clip(progress, 0.0, 1.0))

    # ------------------------------------------------------------------ #
    # Goal generation
    # ------------------------------------------------------------------ #

    def generate_exploration_goals(self,
                                   knowledge_graph: "nx.DiGraph",
                                   recent_queries: List[str],
                                   max_goals: int = 5) -> List[Dict]:
        """
        Generate prioritized exploration goals

        Args:
            knowledge_graph: Current knowledge graph (may be None or empty)
            recent_queries: Recent user queries (for context)
            max_goals: Maximum goals to return

        Returns:
            List of exploration goals ranked by curiosity (highest first),
            never more than the remaining daily budget. Empty once the daily
            budget is exhausted.
        """
        self._reset_daily_counter_if_needed()

        remaining_budget = self.exploration_budget - self.today_explorations
        if remaining_budget <= 0:
            logger.info("Daily exploration budget exhausted")
            return []

        candidates: List[Dict] = []
        if knowledge_graph is not None and len(knowledge_graph) > 0:
            candidates.extend(self._neighborhood_goals(knowledge_graph))
            if len(knowledge_graph) > 10:
                candidates.extend(self._gap_goals(knowledge_graph))
        if recent_queries:
            candidates.extend(self._trend_goals(knowledge_graph, recent_queries))

        candidates.sort(key=lambda g: g['curiosity'], reverse=True)

        goals = candidates[:max(0, min(max_goals, remaining_budget))]
        logger.info("Generated %d exploration goals", len(goals))
        return goals

    def _neighborhood_goals(self, graph: "nx.DiGraph") -> List[Dict]:
        """Strategy 1: unexplored neighbours of curious nodes (first 20 nodes)."""
        goals = []
        for node in list(graph.nodes())[:20]:
            label = self._node_label(graph, node)
            if self._was_recently_explored(label, hours=24):
                continue

            curiosity = self.compute_curiosity(label, context={}, knowledge_graph=graph)
            if curiosity <= 0.5:
                continue

            # A node can be both successor and predecessor; dedupe, keep order.
            neighbors = list(dict.fromkeys(
                [*graph.successors(node), *graph.predecessors(node)]
            ))
            unexplored = [
                n_label
                for n_label in (self._node_label(graph, n) for n in neighbors[:3])
                if not self._was_recently_explored(n_label, hours=24)
            ]
            if unexplored:
                goals.append({
                    'type': 'explore_neighborhood',
                    'center_node': label,
                    'unexplored_neighbors': unexplored,
                    'curiosity': curiosity,
                    'reasoning': f"High uncertainty about {label}",
                })
        return goals

    def _gap_goals(self, graph: "nx.DiGraph") -> List[Dict]:
        """Strategy 2: the five lowest-degree nodes that have fewer than 2 edges."""
        goals = []
        degrees = dict(graph.degree())
        sparse_nodes = sorted(degrees.items(), key=lambda item: item[1])[:5]

        for node, degree in sparse_nodes:
            if degree >= 2:
                continue
            label = self._node_label(graph, node)
            if self._was_recently_explored(label, hours=48):
                continue

            curiosity = self.compute_curiosity(label, context={}, knowledge_graph=graph)
            goals.append({
                'type': 'fill_gap',
                'topic': label,
                'current_connections': degree,
                'curiosity': curiosity,
                'reasoning': f"Sparse knowledge: only {degree} connections",
            })
        return goals

    def _trend_goals(self, knowledge_graph, recent_queries: List[str]) -> List[Dict]:
        """Strategy 3: follow the most frequent themes in recent user queries."""
        goals = []
        for theme in self._extract_themes(recent_queries):
            if self._was_recently_explored(theme, hours=12):
                continue

            curiosity = self.compute_curiosity(
                theme,
                context={'recent': True},
                knowledge_graph=knowledge_graph,
            )
            goals.append({
                'type': 'follow_trend',
                'theme': theme,
                'related_queries': recent_queries[:3],
                'curiosity': curiosity,
                'reasoning': f"User interest trend: {theme}",
            })
        return goals

    def _was_recently_explored(self, topic: str, hours: int = 24) -> bool:
        """Return True if the topic's last visit is less than ``hours`` ago.

        Unknown topics, missing timestamps and unparseable/incomparable
        timestamps all count as "not recently explored".
        """
        region = self.explored_regions.get(self._topic_key(topic))
        last_visit = region.get('last_visit') if region else None
        if not last_visit:
            return False

        try:
            last_visit_time = datetime.fromisoformat(last_visit)
            hours_since = (datetime.now() - last_visit_time).total_seconds() / 3600
        except (TypeError, ValueError):
            return False
        return hours_since < hours

    def _extract_themes(self, queries: List[str], max_themes: int = 3) -> List[str]:
        """Return the most frequent non-stopword words (longer than 3 chars)."""
        from collections import Counter
        import re

        words = re.findall(r'\b\w+\b', ' '.join(queries).lower())
        counter = Counter(w for w in words if len(w) > 3 and w not in self._STOPWORDS)
        return [word for word, _ in counter.most_common(max_themes)]

    # ------------------------------------------------------------------ #
    # Recording and metrics
    # ------------------------------------------------------------------ #

    def record_exploration(self,
                           topic: str,
                           outcome: str,
                           confidence: float,
                           learning_gain: float):
        """Record an exploration attempt and persist the updated state.

        Updates the topic's visit count, last-visit time and confidence,
        appends to the in-memory history and learning-progress windows,
        charges one unit of today's budget and saves state to disk.
        """
        self._reset_daily_counter_if_needed()

        topic_key = self._topic_key(topic)
        now = datetime.now().isoformat()
        confidence = float(confidence)  # keep the state JSON-serialisable

        region = self.explored_regions.setdefault(topic_key, {
            'visits': 0,
            'first_visit': now,
            'confidence': 0.5,
        })
        region['visits'] = region.get('visits', 0) + 1
        region['last_visit'] = now
        region['confidence'] = confidence

        self.exploration_history.append({
            'timestamp': now,
            'topic': topic,
            'outcome': outcome,
            'confidence': confidence,
            'learning_gain': learning_gain,
        })
        self.learning_progress.append(learning_gain)
        self.today_explorations += 1

        self._save_state()

        logger.info("Exploration recorded: %s (gain: %.2f)", topic, learning_gain)

    def get_metrics(self) -> Dict:
        """Get curiosity system metrics.

        Neutral values are reported when nothing has been explored yet; the
        budget string always reflects the configured daily budget.
        """
        regions = self.explored_regions
        recent_gains = list(self.learning_progress)[-20:]

        return {
            'total_explorations': len(self.exploration_history),
            'unique_topics': len(regions),
            'avg_confidence': (
                float(np.mean([r.get('confidence', 0.5) for r in regions.values()]))
                if regions else 0.5
            ),
            'avg_learning_gain': float(np.mean(recent_gains)) if recent_gains else 0.0,
            'exploration_budget_used': f"{self.today_explorations}/{self.exploration_budget}",
            'most_explored': (
                max(regions, key=lambda k: regions[k].get('visits', 0))
                if regions else None
            ),
        }
