#!/usr/bin/env python3
"""
Planning System for Eden
Solve: Tower of Hanoi, Blocksworld, Pathfinding
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
import heapq
import numpy as np
from tqdm import tqdm
import random
import itertools

# =============================================================================
# PLANNING PROBLEMS
# =============================================================================

class TowerOfHanoi:
    """Three-peg Tower of Hanoi puzzle.

    A state is a tuple of three tuples, one per peg.  Each peg lists its
    disk sizes from bottom to top, so the last element is the disk that can
    be moved.  An action is a ``(from_peg, to_peg)`` pair of peg indices.
    """

    def __init__(self, n_disks=3):
        """Place all ``n_disks`` on peg 0; the goal has them all on peg 2."""
        full_stack = tuple(range(n_disks, 0, -1))  # largest disk at the bottom
        self.n_disks = n_disks
        self.initial = (full_stack, (), ())
        self.goal = ((), (), full_stack)

    def get_actions(self, state):
        """Return the legal ``(from_peg, to_peg)`` moves available in *state*."""
        pegs = range(3)
        moves = []
        for src in pegs:
            source = state[src]
            if source:
                top = source[-1]
                for dst in pegs:
                    target = state[dst]
                    # A disk may land on an empty peg or on a larger disk.
                    if dst != src and (not target or top < target[-1]):
                        moves.append((src, dst))
        return moves

    def apply_action(self, state, action):
        """Return the new state after moving the top disk as *action* says."""
        src, dst = action
        pegs = [list(peg) for peg in state]
        pegs[dst].append(pegs[src].pop())
        return tuple(map(tuple, pegs))

    def is_goal(self, state):
        """Return True when *state* equals the goal configuration."""
        return self.goal == state

    def heuristic(self, state):
        """Return the number of disks that are not yet on peg 2."""
        on_goal_peg = len(state[2])
        return self.n_disks - on_goal_peg

class Blocksworld:
    """Blocksworld planning problem.

    A state is a dict mapping each block to what it rests on: another block
    or the string ``'table'``.  An action is a ``(block, destination)`` pair.
    The goal is a dict of the same form; blocks absent from it are
    unconstrained.
    """

    def __init__(self, initial_state, goal_state):
        """Store the start configuration and the goal configuration."""
        self.initial = initial_state
        self.goal = goal_state

    def get_actions(self, state):
        """Return every legal ``(block, destination)`` move from *state*.

        Only clear blocks (nothing resting on them) can be moved, either to
        the table or onto another clear block.  Actions are produced in the
        iteration order of *state*, so the result is identical on every run
        (iterating a set of names would vary with string hash randomisation).
        """
        # Blocks that currently have something on top of them.
        covered = set(state.values()) - {'table'}
        # Keep dict order rather than building a set, for determinism.
        clear_blocks = [block for block in state if block not in covered]

        actions = []
        for block in clear_blocks:
            support = state[block]

            # Moving to the table only makes sense if not already there.
            if support != 'table':
                actions.append((block, 'table'))

            # Stack onto any other clear block.
            for target in clear_blocks:
                if target != block and target != support:
                    actions.append((block, target))

        return actions

    def apply_action(self, state, action):
        """Return a copy of *state* with the block moved; *state* is untouched."""
        block, destination = action
        new_state = state.copy()
        new_state[block] = destination
        return new_state

    def is_goal(self, state):
        """Return True when every block named in the goal is where it should be."""
        return all(state.get(block) == support
                   for block, support in self.goal.items())

    def heuristic(self, state):
        """Return the number of goal blocks not resting on their goal support."""
        return sum(1 for block, support in self.goal.items()
                   if state.get(block) != support)

class GridPathfinding:
    """Four-connected pathfinding on a 2D occupancy grid.

    A state is a ``(row, col)`` pair indexing ``grid[row][col]``; a cell is
    free when its value equals 0.  An action is simply the neighbouring cell
    to step into.
    """

    # Offsets tried, in order, when expanding a cell.
    _STEPS = ((0, 1), (0, -1), (1, 0), (-1, 0))

    def __init__(self, grid, start, goal):
        """Remember the grid, its dimensions, and the start and goal cells."""
        self.grid = grid
        self.height = len(grid)
        self.width = len(grid[0])
        self.start = start
        self.goal = goal

    def get_actions(self, state):
        """Return the free in-bounds neighbours of *state* as ``(row, col)`` pairs."""
        row, col = state
        candidates = ((row + d_row, col + d_col) for d_row, d_col in self._STEPS)
        return [
            (r, c)
            for r, c in candidates
            if 0 <= r < self.height and 0 <= c < self.width and self.grid[r][c] == 0
        ]

    def apply_action(self, state, action):
        """Return the next state, which is the destination cell itself."""
        return action

    def is_goal(self, state):
        """Return True when *state* is the goal cell."""
        return state == self.goal

    def heuristic(self, state):
        """Return the Manhattan distance from *state* to the goal cell."""
        goal_row, goal_col = self.goal[0], self.goal[1]
        return abs(state[0] - goal_row) + abs(state[1] - goal_col)

# =============================================================================
# A* SEARCH PLANNER
# =============================================================================

class AStarPlanner:
    """A* search over a planning problem with unit-cost actions.

    The problem object must provide ``get_actions(state)``,
    ``apply_action(state, action)``, ``is_goal(state)`` and
    ``heuristic(state)``, and expose its start state as either ``initial``
    or ``start``.
    """

    def __init__(self, problem, max_steps=10000):
        """Bind the planner to *problem*.

        Args:
            problem: Object implementing the planning interface above.
            max_steps: Upper bound on the number of queue entries popped
                before the search gives up.
        """
        self.problem = problem
        self.max_steps = max_steps
        self.counter = itertools.count()  # Unique tiebreaker

    @staticmethod
    def _state_key(state):
        """Return a hashable key identifying *state* in the closed set."""
        if isinstance(state, dict):
            try:
                # Order-independent and, unlike sorting the items, does not
                # require the keys to be mutually comparable (e.g. int and
                # str block names in the same state).
                return frozenset(state.items())
            except TypeError:
                # Unhashable values: fall back to a canonical text form.
                return repr(sorted(state.items(), key=repr))
        # Text form keeps unhashable states (lists, arrays) usable as keys.
        return str(state)

    def search(self):
        """Run A* from the problem's start state.

        Returns:
            ``(path, steps)`` where ``path`` is the list of actions leading
            to a goal state (empty if the start already is one), or ``None``
            if the queue was exhausted or ``max_steps`` was reached, and
            ``steps`` is the number of queue entries popped.
        """
        problem = self.problem
        start = problem.initial if hasattr(problem, 'initial') else problem.start

        # Priority queue entries: (f_score, unique_id, state, g_score, path).
        # The unique id settles f-score ties so states are never compared.
        open_set = [(problem.heuristic(start), next(self.counter), start, 0, [])]
        closed_set = set()
        steps = 0

        while open_set and steps < self.max_steps:
            steps += 1
            _, _, current, g_score, path = heapq.heappop(open_set)

            if problem.is_goal(current):
                return path, steps

            # Skip states that were already expanded via an earlier entry.
            state_key = self._state_key(current)
            if state_key in closed_set:
                continue
            closed_set.add(state_key)

            new_g = g_score + 1  # every action costs one step
            for action in problem.get_actions(current):
                new_state = problem.apply_action(current, action)
                new_f = new_g + problem.heuristic(new_state)
                heapq.heappush(
                    open_set,
                    (new_f, next(self.counter), new_state, new_g, path + [action]),
                )

        return None, steps

# =============================================================================
# TESTING
# =============================================================================

def _print_banner(title):
    """Print a blank line, then *title* framed by two 70-character '=' rules."""
    rule = "=" * 70
    print("\n" + rule)
    print(title)
    print(rule)


def test_planning():
    """Run A* on four benchmark problems and print a report.

    The problems are Tower of Hanoi with 3 and 4 disks (which must be solved
    in the optimal 7 and 15 moves), one Blocksworld instance and one grid
    pathfinding instance (which only need to be solved).

    Returns:
        True if all four checks passed, False otherwise.
    """
    _print_banner("TESTING PLANNING CAPABILITIES")

    results = []

    # Test 1: Tower of Hanoi (3 disks)
    _print_banner("TEST 1: Tower of Hanoi (3 disks)")

    problem = TowerOfHanoi(3)
    planner = AStarPlanner(problem)
    solution, steps = planner.search()

    print(f"Initial: {problem.initial}")
    print(f"Goal: {problem.goal}")

    # Failure is signalled by None; an empty plan is a valid solution when
    # the start state already satisfies the goal, so do not test truthiness.
    if solution is not None:
        print(f"✅ SOLVED in {len(solution)} moves ({steps} states explored)")
        print(f"Optimal: {2**3 - 1} = 7 moves")
        results.append(len(solution) == 7)
    else:
        print("❌ FAILED")
        results.append(False)

    # Test 2: Tower of Hanoi (4 disks)
    _print_banner("TEST 2: Tower of Hanoi (4 disks)")

    problem = TowerOfHanoi(4)
    planner = AStarPlanner(problem)
    solution, steps = planner.search()

    if solution is not None:
        print(f"✅ SOLVED in {len(solution)} moves ({steps} states explored)")
        print(f"Optimal: {2**4 - 1} = 15 moves")
        results.append(len(solution) == 15)
    else:
        print("❌ FAILED")
        results.append(False)

    # Test 3: Blocksworld
    _print_banner("TEST 3: Blocksworld")

    initial = {'A': 'table', 'B': 'A', 'C': 'table'}
    goal = {'A': 'B', 'B': 'table', 'C': 'A'}

    print(f"Initial: {initial}")
    print(f"Goal: {goal}")

    problem = Blocksworld(initial, goal)
    planner = AStarPlanner(problem)
    solution, steps = planner.search()

    if solution is not None:
        print(f"✅ SOLVED in {len(solution)} moves")
        print(f"Solution: {solution}")
        results.append(True)
    else:
        print("❌ FAILED")
        results.append(False)

    # Test 4: Pathfinding
    _print_banner("TEST 4: Grid Pathfinding")

    grid = np.array([
        [0, 0, 0, 1, 0],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [1, 0, 0, 0, 0]
    ])

    print("Grid (0=free, 1=blocked):")
    print(grid)

    problem = GridPathfinding(grid, (0, 0), (4, 4))
    planner = AStarPlanner(problem)
    solution, steps = planner.search()

    if solution is not None:
        print(f"✅ SOLVED - Path length: {len(solution)}")
        results.append(True)
    else:
        print("❌ FAILED")
        results.append(False)

    # Summary
    _print_banner("RESULTS SUMMARY")

    passed = sum(results)
    total = len(results)

    print(f"\nScore: {passed}/{total} ({100*passed//total}%)")

    if passed == total:
        print("\n✅ PERFECT! All planning tests passed!")
        print("\nEden can now:")
        print("  1. Solve Tower of Hanoi optimally")
        print("  2. Plan in Blocksworld")
        print("  3. Find optimal paths in grids")
        print("  4. Use A* search for goal-directed planning")
    elif passed >= total * 0.75:
        print("\n✅ SUCCESS! Most planning tests passed!")
    else:
        print("\n⚠️ Needs improvement")

    return passed == total

def main():
    """Run the planning self-test and print a status block for the result."""
    all_passed = test_planning()

    rule = "=" * 70
    print("\n" + rule)
    print("PLANNING SYSTEM STATUS")
    print(rule)

    # Nothing further is reported unless every test passed.
    if not all_passed:
        return

    status_lines = (
        "\n✅ Planning capability: WORKING",
        "\nA* search implemented for:",
        "  - Tower of Hanoi (optimal)",
        "  - Blocksworld (blocks manipulation)",
        "  - Grid pathfinding (navigation)",
        "\n✅ CAPABILITY #4 COMPLETE",
    )
    for line in status_lines:
        print(line)

# Run the planning self-test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
