"""
A* Pathfinding - Optimal path search with heuristics
Used in: Games, robotics, navigation, logistics
"""
import heapq
import numpy as np
from typing import Tuple, List, Set, Dict, Optional
import time

class Node:
    """Single entry in the A* open list.

    Holds a grid ``position`` together with ``g`` (path cost accumulated from
    the start), ``h`` (heuristic estimate of the remaining cost) and their sum
    ``f``. ``parent`` points at the node this one was reached from, so a
    finished route can be recovered by following parents back to the start.
    """

    def __init__(self, position: Tuple[int, int], g: float, h: float, parent=None):
        self.position = position
        self.g = g  # accumulated cost from the start
        self.h = h  # estimated remaining cost to the goal
        self.f = g + h  # priority key used by heapq
        self.parent = parent

    def __lt__(self, other):
        # heapq only ever asks "<"; rank nodes purely by estimated total cost.
        return other.f > self.f

def manhattan_distance(pos1: Tuple[int, int], pos2: Tuple[int, int]) -> float:
    """Return the L1 (taxicab) distance between two grid cells.

    For unit-cost 4-directional movement this equals the shortest possible
    route length on an empty grid, so it never overestimates.
    """
    row_gap = abs(pos1[0] - pos2[0])
    col_gap = abs(pos1[1] - pos2[1])
    return row_gap + col_gap

def euclidean_distance(pos1: Tuple[int, int], pos2: Tuple[int, int]) -> float:
    """Return the straight-line (L2) distance between two grid cells.

    It is never larger than the Manhattan distance, so it is also a lower
    bound for 4-directional movement, just a looser one.
    """
    d_row = pos1[0] - pos2[0]
    d_col = pos1[1] - pos2[1]
    return np.sqrt(d_row ** 2 + d_col ** 2)

def _resolve_heuristic(heuristic):
    """Return the distance function selected by *heuristic*.

    Accepts the names 'manhattan' or 'euclidean', or any callable taking
    two (row, col) positions and returning a number.

    Raises:
        ValueError: If *heuristic* is neither a known name nor callable.
    """
    if callable(heuristic):
        return heuristic
    if heuristic == 'manhattan':
        return manhattan_distance
    if heuristic == 'euclidean':
        return euclidean_distance
    raise ValueError(
        f"unknown heuristic {heuristic!r}; "
        "expected 'manhattan', 'euclidean' or a callable"
    )


def _reconstruct_path(node: 'Node') -> List[Tuple[int, int]]:
    """Walk parent links from *node* back to the start; return start-to-node order."""
    path = []
    while node is not None:
        path.append(node.position)
        node = node.parent
    path.reverse()
    return path


def astar(
    grid: np.ndarray,
    start: Tuple[int, int],
    goal: Tuple[int, int],
    heuristic='manhattan'
) -> Tuple[Optional[List[Tuple[int, int]]], int]:
    """
    A* pathfinding on a 4-connected grid where every move costs 1.

    Args:
        grid: 2D array where 0=free, 1=obstacle (any value other than 1 is
            treated as passable).
        start: Starting (row, col) position.
        goal: Goal (row, col) position.
        heuristic: 'manhattan', 'euclidean', or a callable
            ``h(pos, goal) -> float``. Cells are never reopened once
            expanded, so the returned path is only guaranteed shortest when
            the heuristic is consistent (both built-in ones are).

    Returns:
        ``(path, nodes_explored)``: ``path`` is the list of positions from
        start to goal inclusive, or ``None`` if no path exists;
        ``nodes_explored`` is the number of distinct cells expanded.
        If start or goal is out of bounds or on an obstacle the result is
        ``(None, 0)``.

    Raises:
        ValueError: If ``heuristic`` is an unknown name.
    """
    h_func = _resolve_heuristic(heuristic)
    # Normalise to tuples so equality, hashing and numpy indexing all agree
    # even if the caller passed lists.
    start, goal = tuple(start), tuple(goal)
    rows, cols = grid.shape

    def passable(pos: Tuple[int, int]) -> bool:
        """True if *pos* lies inside the grid and is not an obstacle."""
        return 0 <= pos[0] < rows and 0 <= pos[1] < cols and grid[pos] != 1

    if not (passable(start) and passable(goal)):
        return None, 0

    # Heap of Nodes ordered by f = g + h. heapq has no decrease-key, so an
    # improved route is pushed as a new entry and the outdated duplicate is
    # discarded when it is eventually popped.
    open_heap = [Node(start, 0, h_func(start, goal))]
    best_g: Dict[Tuple[int, int], float] = {start: 0}
    closed: Set[Tuple[int, int]] = set()
    nodes_explored = 0

    while open_heap:
        current = heapq.heappop(open_heap)
        if current.position in closed:
            continue  # stale duplicate of an already-expanded cell
        nodes_explored += 1

        if current.position == goal:
            return _reconstruct_path(current), nodes_explored

        closed.add(current.position)
        row, col = current.position
        step_g = current.g + 1  # every move costs 1

        for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            neighbor = (row + d_row, col + d_col)
            if neighbor in closed or not passable(neighbor):
                continue
            if step_g >= best_g.get(neighbor, float('inf')):
                continue  # already queued with an equal or cheaper route
            best_g[neighbor] = step_g
            heapq.heappush(
                open_heap,
                Node(neighbor, step_g, h_func(neighbor, goal), current),
            )

    return None, nodes_explored  # goal unreachable

def create_maze(
    size: int = 50,
    obstacle_prob: float = 0.2,
    seed: Optional[int] = None,
) -> np.ndarray:
    """Create a random ``size`` x ``size`` occupancy grid.

    Each cell is independently an obstacle (1) with probability
    ``obstacle_prob``, otherwise free (0). The top-left and bottom-right
    corners are always cleared so they can serve as start and goal.

    Args:
        size: Side length of the square grid; must be at least 1.
        obstacle_prob: Probability that a cell is an obstacle.
        seed: Optional seed for a private ``numpy.random.Generator``, giving
            reproducible mazes without touching global RNG state. When
            ``None`` (the default) the global NumPy random state is used,
            exactly as before, so ``np.random.seed`` still controls output.

    Returns:
        Integer array of shape ``(size, size)``.

    Raises:
        ValueError: If ``size`` is less than 1.
    """
    if size < 1:
        # Without this check, size 0 fails later with an opaque IndexError
        # on the corner assignment below.
        raise ValueError(f"size must be >= 1, got {size}")
    if seed is None:
        samples = np.random.rand(size, size)
    else:
        samples = np.random.default_rng(seed).random((size, size))
    grid = (samples < obstacle_prob).astype(int)
    grid[0, 0] = 0  # Ensure start is free
    grid[-1, -1] = 0  # Ensure goal is free
    return grid

def visualize_path(grid: np.ndarray, path: Optional[List[Tuple[int, int]]]) -> str:
    """Render *grid* and an optional *path* as ASCII art.

    Cells equal to 1 are drawn as ``#`` and every other cell as ``.``, the
    same obstacle test ``astar`` applies. Path cells are drawn as ``*``,
    with the first cell marked ``S`` and the last ``G``.

    Args:
        grid: 2D occupancy grid (1 = obstacle).
        path: Sequence of (row, col) cells, or ``None``/empty to draw none.

    Returns:
        One line of text per grid row, joined with newlines.
    """
    # Build the character array directly from the values. Converting with
    # astype(str) would turn float or bool grids into "0.0"/"True" strings
    # that never match '0'/'1' and would garble the output.
    display = np.where(grid == 1, '#', '.')

    if path:
        for pos in path:
            display[pos] = '*'
        display[path[0]] = 'S'
        display[path[-1]] = 'G'

    return '\n'.join(''.join(row) for row in display)

def _timed(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and return ``(result, elapsed_seconds)``.

    Uses ``time.perf_counter`` (monotonic, high resolution) rather than
    ``time.time``, which can jump with wall-clock adjustments and is too
    coarse for millisecond-scale measurements on some platforms.
    """
    t0 = time.perf_counter()
    result = func(*args, **kwargs)
    return result, time.perf_counter() - t0


def _report(label, path, nodes, seconds):
    """Print one heuristic's benchmark results."""
    print(f"\n{label} heuristic:")
    # A path of k cells takes k - 1 moves.
    print(f"  Path length: {len(path) - 1 if path else 'No path'}")
    print(f"  Nodes explored: {nodes}")
    print(f"  Time: {seconds*1000:.2f}ms")


def main() -> None:
    """Run a small demo and a Manhattan-vs-Euclidean benchmark on random mazes."""
    banner = "=" * 70
    print("\n" + banner)
    print("A* PATHFINDING ALGORITHM")
    print(banner)

    # Test 1: small maze, with the found path drawn
    print("\nTest 1: Small maze (10x10)")
    grid = create_maze(10, 0.2)
    (path, nodes), elapsed = _timed(astar, grid, (0, 0), (9, 9), heuristic='manhattan')

    if path:
        # len(path) counts cells; the number of moves is one fewer.
        print(f"✅ Path found! Length: {len(path) - 1} steps")
        print(f"   Nodes explored: {nodes}")
        print(f"   Time: {elapsed*1000:.2f}ms")
        print("\nPath visualization:")
        print(visualize_path(grid, path))
    else:
        print("❌ No path found")

    # Test 2: larger maze, both heuristics on the same grid
    print("\n" + banner)
    print("Test 2: Large maze benchmark (50x50)")
    print(banner)

    grid = create_maze(50, 0.2)
    start, goal = (0, 0), (49, 49)
    (path_m, nodes_m), time_m = _timed(astar, grid, start, goal, heuristic='manhattan')
    (path_e, nodes_e), time_e = _timed(astar, grid, start, goal, heuristic='euclidean')

    _report("Manhattan", path_m, nodes_m, time_m)
    _report("Euclidean", path_e, nodes_e, time_e)

    print("\n✅ A* pathfinding implementation complete!")
    print("   Used in: Game AI, robotics, navigation systems")


if __name__ == "__main__":
    main()
