"""
Hierarchical Planner
Decomposes complex goals into executable subgoal trees
"""

import networkx as nx
import numpy as np
from typing import List, Dict, Optional, Any, Callable
import yaml
import logging
import json
from pathlib import Path
from datetime import datetime
from enum import Enum

# Module-level logger named after this module; no handlers or levels are
# configured here.
logger = logging.getLogger(__name__)

class GoalStatus(Enum):
    """Status of a goal in the hierarchy.

    The member values are the strings that end up in the persisted plan:
    ``Goal.to_dict`` emits ``status.value`` and
    ``HierarchicalPlanner._load_plans`` rebuilds the member from that string.
    """
    PENDING = "pending"  # initial status assigned in Goal.__init__
    IN_PROGRESS = "in_progress"  # set by HierarchicalPlanner.start_goal
    COMPLETED = "completed"  # set by complete_goal(..., success=True)
    FAILED = "failed"  # set by complete_goal(..., success=False)
    # Not assigned anywhere in this file; it only has a symbol in the tree view.
    BLOCKED = "blocked"

class Goal:
    """One node of the goal hierarchy.

    Holds the goal's identity (``name``, ``description``), its position in
    the tree (``level``, ``parent`` name), and its lifecycle bookkeeping
    (``status``, ``progress``, timestamps, ``attempts``, free-form
    ``metadata``).
    """

    # Fields copied verbatim by ``to_dict``, split around ``status`` (which is
    # emitted as its enum value) so the resulting key order stays stable.
    _FIELDS_BEFORE_STATUS = ('name', 'description', 'level', 'parent')
    _FIELDS_AFTER_STATUS = (
        'progress', 'created', 'started', 'completed', 'attempts', 'metadata',
    )

    def __init__(self,
                 name: str,
                 description: str,
                 level: int = 0,
                 parent: Optional[str] = None):
        """Create a pending goal with zero progress.

        Args:
            name: Identifier of the goal.
            description: Human-readable text of the goal.
            level: Depth of the goal in the tree (0 by default).
            parent: Name of the parent goal, or None.
        """
        # Identity and position in the tree.
        self.name, self.description = name, description
        self.level, self.parent = level, parent

        # Lifecycle state: every goal starts out pending with no progress.
        self.status = GoalStatus.PENDING
        self.progress = 0.0

        # ISO-8601 timestamps; only the creation time is known up front.
        self.created = datetime.now().isoformat()
        self.started = self.completed = None

        self.attempts = 0
        self.metadata = {}

    def to_dict(self) -> Dict:
        """Return the goal as a plain dict, with ``status`` as its string value.

        The ``metadata`` entry is the goal's own dict object, not a copy.
        """
        snapshot = {
            field: getattr(self, field)
            for field in self._FIELDS_BEFORE_STATUS
        }
        snapshot['status'] = self.status.value
        for field in self._FIELDS_AFTER_STATUS:
            snapshot[field] = getattr(self, field)
        return snapshot

class HierarchicalPlanner:
    """Plans and executes hierarchical goal structures.

    Goals live both in ``self.goals`` (name -> Goal) and as nodes of the
    directed graph ``self.goal_tree``, whose edges point from a parent goal
    to each of its subgoals.  ``create_plan``, ``start_goal``,
    ``update_progress`` and ``complete_goal`` persist the current state to
    ``<plans_path>/current_plan.json`` after modifying it.
    """

    # Fallback settings used for every key the config file does not provide.
    _DEFAULT_CONFIG = {
        'max_depth': 5,
        'max_children_per_goal': 7,
        'replan_threshold': 0.3,
        'progress_update_rate': 0.1,
    }

    def __init__(self,
                 config_path: str = "/Eden/CONFIG/phi_fractal_config.yaml",
                 plans_path: str = "/Eden/MEMORY/plans"):
        """Read the configuration and load any previously saved plan.

        Args:
            config_path: YAML file whose optional ``hierarchical_planner``
                mapping overrides the built-in defaults key by key.
            plans_path: Directory holding ``current_plan.json``; it is
                created if missing.

        Raises:
            OSError: If the config file cannot be opened.
            TypeError: If the YAML document or its ``hierarchical_planner``
                entry is not a mapping.
        """
        with open(config_path, encoding="utf-8") as f:
            # safe_load returns None for an empty document.
            config = yaml.safe_load(f) or {}
        if not isinstance(config, dict):
            raise TypeError(
                f"Config root must be a mapping, got {type(config).__name__}"
            )

        section = config.get('hierarchical_planner') or {}
        if not isinstance(section, dict):
            raise TypeError(
                "'hierarchical_planner' config must be a mapping, "
                f"got {type(section).__name__}"
            )

        # Merge over the defaults so a partial section does not raise KeyError.
        self.config = {**self._DEFAULT_CONFIG, **section}
        self.max_depth = self.config['max_depth']
        self.max_children = self.config['max_children_per_goal']
        self.replan_threshold = self.config['replan_threshold']
        self.update_rate = self.config['progress_update_rate']

        self.goal_tree = nx.DiGraph()
        self.goals: Dict[str, Goal] = {}
        self.decomposition_strategies: Dict[str, Callable] = {}
        self.plan_history: List[Dict] = []

        self.plans_path = Path(plans_path)
        self.plans_path.mkdir(parents=True, exist_ok=True)

        self._load_plans()

        logger.info("HierarchicalPlanner initialized")

    def _load_plans(self):
        """Load the saved plan from disk, if one exists.

        The file is parsed into temporary structures first and merged into
        the planner only when the whole file was read successfully, so a
        corrupt file never leaves a half-loaded plan behind.  Failures are
        logged and otherwise ignored (best effort).
        """
        current_plan = self.plans_path / "current_plan.json"
        if not current_plan.exists():
            return

        try:
            with open(current_plan, encoding="utf-8") as f:
                data = json.load(f)

            loaded: Dict[str, Goal] = {}
            for goal_data in data.get('goals', []):
                goal = Goal(
                    name=goal_data['name'],
                    description=goal_data['description'],
                    level=goal_data['level'],
                    parent=goal_data.get('parent')
                )
                goal.status = GoalStatus(goal_data['status'])
                goal.progress = goal_data['progress']
                # Restore the bookkeeping fields that to_dict() persists;
                # older files may lack them, hence the fallbacks.
                goal.created = goal_data.get('created', goal.created)
                goal.started = goal_data.get('started')
                goal.completed = goal_data.get('completed')
                goal.attempts = goal_data.get('attempts', 0)
                goal.metadata = goal_data.get('metadata') or {}
                loaded[goal.name] = goal

            edges = []
            for edge in data.get('edges', []):
                parent_name, child_name = edge[0], edge[1]
                known = all(
                    n in loaded or n in self.goals
                    for n in (parent_name, child_name)
                )
                if known:
                    edges.append((parent_name, child_name))
                else:
                    # add_edge would silently create a node with no Goal
                    # behind it, which breaks later lookups in self.goals.
                    logger.warning(
                        "Skipping edge with unknown goal: %s -> %s",
                        parent_name, child_name,
                    )
        except Exception as e:
            # Boundary handler: a bad plan file must not prevent start-up.
            logger.error("Failed to load plans: %s", e)
            return

        for name, goal in loaded.items():
            self.goals[name] = goal
            self.goal_tree.add_node(name, goal=goal)
        self.goal_tree.add_edges_from(edges)

        logger.info("Loaded plan: %d goals", len(self.goals))

    def create_plan(self,
                   root_goal: str,
                   description: str,
                   context: Optional[Dict] = None) -> Dict:
        """Create a hierarchical plan for a high-level goal.

        Args:
            root_goal: Name of the root goal; subgoal names derive from it.
            description: Text of the root goal, used to pick a decomposition.
            context: Optional dict handed to the decomposition step and
                stored in the plan history record.

        Returns:
            Dict with ``root``, ``total_goals`` (all goals known to the
            planner), ``tree`` (text rendering of the new plan) and
            ``next_actions``.
        """
        if context is None:
            context = {}

        # NOTE(review): if ``root_goal`` already exists its Goal object is
        # replaced while existing children are kept and not re-decomposed.
        root = Goal(name=root_goal, description=description, level=0)
        self.goals[root_goal] = root
        self.goal_tree.add_node(root_goal, goal=root)

        self._decompose_goal(root_goal, context, depth=0)

        self.plan_history.append({
            'timestamp': datetime.now().isoformat(),
            'root_goal': root_goal,
            # Count only this plan's subgoals, not the root itself or goals
            # that belong to other plans.
            'total_subgoals': len(nx.descendants(self.goal_tree, root_goal)),
            'max_depth': self._compute_depth(root_goal),
            'context': context
        })

        self._save_plans()

        logger.info("Plan created: %s with %d goals", root_goal, len(self.goals))

        return {
            'root': root_goal,
            'total_goals': len(self.goals),
            'tree': self._get_tree_view(root_goal),
            'next_actions': self._get_next_actions()
        }

    def _decompose_goal(self, goal_name: str, context: Dict, depth: int):
        """Recursively decompose a goal into subgoals.

        Stops at ``self.max_depth``, at goals that already have children,
        and when the decomposition is empty or wider than
        ``self.max_children``.  Subgoals are named ``<goal_name>_sub<i>``
        with ``i`` starting at 1.
        """
        if depth >= self.max_depth:
            return

        goal = self.goals[goal_name]

        # Already decomposed (e.g. loaded from disk): leave it alone.
        if self.goal_tree.out_degree(goal_name) > 0:
            return

        subgoals = self._default_decomposition(goal, context)
        if not subgoals:
            return
        if len(subgoals) > self.max_children:
            logger.warning(
                "Not decomposing %s: %d subgoals exceed max_children_per_goal=%d",
                goal_name, len(subgoals), self.max_children,
            )
            return

        for i, subgoal_desc in enumerate(subgoals, start=1):
            subgoal_name = f"{goal_name}_sub{i}"

            subgoal = Goal(
                name=subgoal_name,
                description=subgoal_desc,
                level=depth + 1,
                parent=goal_name
            )

            self.goals[subgoal_name] = subgoal
            self.goal_tree.add_node(subgoal_name, goal=subgoal)
            self.goal_tree.add_edge(goal_name, subgoal_name)

            self._decompose_goal(subgoal_name, context, depth + 1)

    def _default_decomposition(self, goal: Goal, context: Dict) -> List[str]:
        """Return subgoal descriptions for ``goal`` (empty list = leaf).

        Goals at level 2 or deeper are never decomposed.  Above that, the
        first of the keywords "learn", "build", "optimize" found in the
        lower-cased description selects a three-step template; a level-0
        goal without any keyword gets a generic three-step template, and
        everything else is a leaf.  ``context`` is currently unused.
        """
        # The earlier leaf-pattern and ``level >= 3`` checks were redundant:
        # every branch already produced [] for level >= 2.
        if goal.level >= 2:
            return []

        description = goal.description.lower()

        # Checked in order; the first matching keyword wins.
        templates = (
            ("learn", [
                "Research and gather information",
                "Practice and experiment",
                "Validate understanding",
            ]),
            ("build", [
                "Design and plan",
                "Implement core functionality",
                "Test and refine",
            ]),
            ("optimize", [
                "Measure current performance",
                "Identify bottlenecks",
                "Apply improvements",
            ]),
        )
        for keyword, steps in templates:
            if keyword in description:
                return list(steps)

        # Generic decomposition only at level 0.
        if goal.level == 0:
            return [
                "Analyze requirements",
                "Execute main task",
                "Verify completion"
            ]
        return []

    def _compute_depth(self, goal_name: str) -> int:
        """Return the number of edges on the longest path below ``goal_name``."""
        return max(
            (1 + self._compute_depth(child)
             for child in self.goal_tree.successors(goal_name)),
            default=0,
        )

    def _get_tree_view(self, root: str, indent: int = 0) -> str:
        """Return an indented, one-goal-per-line rendering of the subtree.

        Each line shows a status symbol, the description and the progress
        as a whole percentage; children are indented two spaces per level.
        """
        goal = self.goals[root]
        status_symbol = {
            GoalStatus.PENDING: "○",
            GoalStatus.IN_PROGRESS: "◐",
            GoalStatus.COMPLETED: "●",
            GoalStatus.FAILED: "✗",
            GoalStatus.BLOCKED: "⊗"
        }

        symbol = status_symbol.get(goal.status, "?")
        parts = ["  " * indent + f"{symbol} {goal.description} ({goal.progress:.0%})\n"]
        parts.extend(
            self._get_tree_view(child, indent + 1)
            for child in self.goal_tree.successors(root)
        )
        return "".join(parts)

    def _get_next_actions(self) -> List[Dict]:
        """Return the pending leaf goals that can be worked on now.

        A leaf qualifies when it is PENDING and its parent (if any) is
        PENDING or IN_PROGRESS.  Results are sorted by descending priority.
        """
        next_actions = []

        for goal_name, goal in self.goals.items():
            # Only pending leaf goals (no children) are directly executable.
            if self.goal_tree.out_degree(goal_name) != 0:
                continue
            if goal.status != GoalStatus.PENDING:
                continue

            # A goal whose parent is unknown is treated like a root instead
            # of raising KeyError.
            parent = self.goals.get(goal.parent) if goal.parent else None
            if parent is not None and parent.status not in (
                GoalStatus.IN_PROGRESS, GoalStatus.PENDING
            ):
                continue

            next_actions.append({
                'goal': goal_name,
                'description': goal.description,
                'level': goal.level,
                'priority': self._compute_priority(goal)
            })

        next_actions.sort(key=lambda a: a['priority'], reverse=True)
        return next_actions

    def _compute_priority(self, goal: Goal) -> float:
        """Return a priority score: 70% shallowness, 30% share of descendants."""
        # Guard against a configured max_depth of 0 (division by zero).
        if self.max_depth > 0:
            level_priority = (self.max_depth - goal.level) / self.max_depth
        else:
            level_priority = 0.0

        if self.goal_tree.has_node(goal.name):
            descendants = len(nx.descendants(self.goal_tree, goal.name))
        else:
            descendants = 0
        blocking_priority = descendants / max(len(self.goals), 1)

        return 0.7 * level_priority + 0.3 * blocking_priority

    def start_goal(self, goal_name: str):
        """Mark a goal as started and auto-start its pending ancestors.

        Unknown goal names are ignored.
        """
        goal = self.goals.get(goal_name)
        if goal is None:
            return

        now = datetime.now().isoformat()
        goal.status = GoalStatus.IN_PROGRESS
        goal.started = now
        goal.attempts += 1

        # Auto-start the whole parent hierarchy, not just the direct parent.
        # ``seen`` protects against parent cycles in hand-edited plan files.
        seen = {goal_name}
        ancestor_name = goal.parent
        while ancestor_name and ancestor_name in self.goals and ancestor_name not in seen:
            seen.add(ancestor_name)
            ancestor = self.goals[ancestor_name]
            if ancestor.status == GoalStatus.PENDING:
                ancestor.status = GoalStatus.IN_PROGRESS
                ancestor.started = now
            ancestor_name = ancestor.parent

        self._save_plans()
        logger.info("Goal started: %s", goal_name)

    def update_progress(self, goal_name: str, progress: float):
        """Set a goal's progress (clamped to [0, 1]) and update its ancestors.

        Unknown goal names are ignored.
        """
        goal = self.goals.get(goal_name)
        if goal is None:
            return

        goal.progress = max(0.0, min(1.0, progress))
        self._propagate_progress(goal_name)
        self._save_plans()

    def _propagate_progress(self, goal_name: str):
        """Recompute each ancestor's progress as the mean of its children."""
        seen = set()
        current = self.goals[goal_name]

        while current.parent and current.parent in self.goals and current.parent not in seen:
            parent_name = current.parent
            seen.add(parent_name)  # guards against parent cycles

            children = list(self.goal_tree.successors(parent_name))
            if not children:
                break

            parent = self.goals[parent_name]
            parent.progress = sum(self.goals[c].progress for c in children) / len(children)
            current = parent

    def complete_goal(self, goal_name: str, success: bool = True):
        """Mark a goal as completed or failed.

        On success the goal's progress becomes 1.0, that progress is
        propagated to its ancestors, and the parent is completed too once
        all of its children are completed.  Unknown goal names are ignored.
        """
        goal = self.goals.get(goal_name)
        if goal is None:
            return

        if success:
            goal.status = GoalStatus.COMPLETED
            goal.progress = 1.0
            goal.completed = datetime.now().isoformat()
            # Keep ancestor progress in sync with the jump to 100%.
            self._propagate_progress(goal_name)
            logger.info("Goal completed: %s", goal_name)
        else:
            goal.status = GoalStatus.FAILED
            logger.warning("Goal failed: %s", goal_name)

        if goal.parent:
            self._check_parent_completion(goal.parent)

        self._save_plans()

    def _check_parent_completion(self, parent_name: str):
        """Complete ``parent_name`` when every one of its children is completed."""
        parent = self.goals.get(parent_name)
        if parent is None or parent_name not in self.goal_tree:
            return
        # Do not re-complete: that would overwrite the completion timestamp.
        if parent.status == GoalStatus.COMPLETED:
            return

        children = list(self.goal_tree.successors(parent_name))
        if not children:
            return

        if all(self.goals[c].status == GoalStatus.COMPLETED for c in children):
            self.complete_goal(parent_name, success=True)

    def _save_plans(self):
        """Save the current plan to ``current_plan.json`` in ``plans_path``.

        The JSON is fully serialised before anything touches the disk and is
        then written to a temporary file that replaces the target, so a
        serialisation error or an interrupted write cannot truncate the
        existing plan.
        """
        current_plan = self.plans_path / "current_plan.json"

        data = {
            'goals': [goal.to_dict() for goal in self.goals.values()],
            'edges': [list(edge) for edge in self.goal_tree.edges()],
            'last_updated': datetime.now().isoformat()
        }
        payload = json.dumps(data, indent=2)

        tmp_plan = current_plan.with_name(current_plan.name + ".tmp")
        tmp_plan.write_text(payload, encoding="utf-8")
        tmp_plan.replace(current_plan)

    def get_metrics(self) -> Dict:
        """Return goal counts per status, the completion rate and average depth.

        ``avg_depth`` is the mean tree depth over all goals without a
        parent.  The same keys are returned whether or not any goals exist.
        """
        if not self.goals:
            return {
                'total_goals': 0,
                'completed': 0,
                'in_progress': 0,
                'failed': 0,
                'pending': 0,
                'completion_rate': 0.0,
                'avg_depth': 0
            }

        # Single pass over the goals instead of one pass per status.
        counts = {status: 0 for status in GoalStatus}
        for goal in self.goals.values():
            counts[goal.status] += 1

        roots = [name for name, goal in self.goals.items() if goal.parent is None]
        avg_depth = sum(self._compute_depth(r) for r in roots) / len(roots) if roots else 0

        completed = counts[GoalStatus.COMPLETED]

        return {
            'total_goals': len(self.goals),
            'completed': completed,
            'in_progress': counts[GoalStatus.IN_PROGRESS],
            'failed': counts[GoalStatus.FAILED],
            'pending': counts[GoalStatus.PENDING],
            'completion_rate': completed / len(self.goals),
            'avg_depth': avg_depth
        }
