"""
Monte Carlo Tree Search (MCTS)
Used in: AlphaGo, game playing, decision making under uncertainty
"""
import numpy as np
import math
from typing import List, Optional
import time

class TicTacToeState:
    """Minimal Tic-Tac-Toe game state for the MCTS demo.

    Board cells hold 0 (empty), 1 (player one) or 2 (player two).
    States are treated as immutable: make_move returns a new state
    rather than mutating the receiver.
    """

    def __init__(self):
        # Empty 3x3 grid; player 1 always moves first.
        self.board = np.zeros((3, 3), dtype=int)
        self.current_player = 1

    def get_legal_moves(self) -> List[tuple]:
        """Return every (row, col) pair whose cell is still empty."""
        return [(r, c)
                for r in range(3)
                for c in range(3)
                if self.board[r, c] == 0]

    def make_move(self, move: tuple):
        """Return the successor state after the current player plays `move`."""
        nxt = TicTacToeState()
        nxt.board = self.board.copy()
        nxt.board[move] = self.current_player
        nxt.current_player = 3 - self.current_player  # toggles 1 <-> 2
        return nxt

    def is_terminal(self) -> bool:
        """The game is over once somebody won or the board is full."""
        if self.get_winner() is not None:
            return True
        return not self.get_legal_moves()

    def get_winner(self) -> Optional[int]:
        """Return 1 or 2 if that player owns a full line, else None."""
        # All eight winning lines: three rows, three columns, two diagonals.
        lines = [self.board[i, :] for i in range(3)]
        lines += [self.board[:, j] for j in range(3)]
        lines.append(self.board.diagonal())
        lines.append(np.fliplr(self.board).diagonal())

        for player in (1, 2):
            for line in lines:
                if np.all(line == player):
                    return player
        return None

    def get_reward(self, player: int) -> float:
        """Score a (terminal) state from `player`'s point of view: +1 win, -1 loss, 0 draw."""
        winner = self.get_winner()
        if winner is None:
            return 0.0
        return 1.0 if winner == player else -1.0

class MCTSNode:
    """Node in the MCTS tree.

    Convention: `wins` accumulates rewards from the perspective of the
    player who made the move leading INTO this node (the parent's player
    to move).  This lets `select_child` maximize the child win rate
    directly on behalf of the parent's player.
    """

    def __init__(self, state: "TicTacToeState", parent=None, move=None):
        self.state = state
        self.parent = parent
        self.move = move  # Move that led to this state

        self.children = []
        self.untried_moves = state.get_legal_moves()

        self.visits = 0   # number of times this node was backpropagated through
        self.wins = 0.0   # cumulative reward (see class docstring for perspective)

    def uct_value(self, c: float = 1.41) -> float:
        """Upper Confidence Bound for Trees (UCB1 applied to trees).

        Unvisited nodes return +inf so every child is tried at least once.
        `c` balances exploitation (win rate) vs exploration (rarely-visited
        children under a well-visited parent).
        """
        if self.visits == 0:
            return float('inf')

        exploitation = self.wins / self.visits
        exploration = c * math.sqrt(math.log(self.parent.visits) / self.visits)

        return exploitation + exploration

    def select_child(self):
        """Select the child with the highest UCT value."""
        return max(self.children, key=lambda n: n.uct_value())

    def expand(self):
        """Pop one untried move, attach and return the resulting child node."""
        move = self.untried_moves.pop()
        new_state = self.state.make_move(move)
        child = MCTSNode(new_state, parent=self, move=move)
        self.children.append(child)
        return child

    def rollout(self) -> float:
        """Simulate a uniformly random playout to a terminal state and score it.

        BUG FIX: the reward must be taken from the perspective of the player
        who moved INTO this node (3 - current_player), not the player to move
        here.  The previous code scored for the player to move, so the
        parent's UCT selection maximized the *opponent's* win rate and the
        reported win rate was inverted.
        """
        state = self.state

        while not state.is_terminal():
            moves = state.get_legal_moves()
            move = moves[np.random.randint(len(moves))]
            state = state.make_move(move)

        # Perspective of the player whose move created this node.
        return state.get_reward(3 - self.state.current_player)

    def backpropagate(self, reward: float):
        """Update this node and all ancestors, flipping the sign each ply."""
        self.visits += 1
        self.wins += reward

        if self.parent:
            self.parent.backpropagate(-reward)  # Flip reward for opponent

def mcts(root_state: TicTacToeState, iterations: int = 1000) -> tuple:
    """
    Monte Carlo Tree Search

    Runs the classic four phases (select, expand, simulate, backpropagate)
    `iterations` times starting from `root_state`.

    Returns: (best_move, win_rate)
    """
    root = MCTSNode(root_state)

    for _ in range(iterations):
        node = root

        # Selection: descend through fully-expanded nodes via UCT.
        while not node.untried_moves and not node.state.is_terminal():
            node = node.select_child()

        # Expansion: attach one new child when an untried move remains.
        if node.untried_moves:
            node = node.expand()

        # Simulation + backpropagation: random playout, then update ancestors.
        node.backpropagate(node.rollout())

    # No children means the root was already terminal.
    if not root.children:
        return None, 0.0

    # The most-visited child is the robust final choice.
    best = max(root.children, key=lambda child: child.visits)
    rate = best.wins / best.visits if best.visits else 0

    return best.move, rate

def play_game():
    """Play one full game: MCTS drives player 1, player 2 moves randomly.

    Prints the board after every move and returns the winner (1, 2, or
    None for a draw).
    """
    state = TicTacToeState()
    moves_played = []

    print("\n" + "="*70)
    print("MCTS (Player 1) vs Random (Player 2)")
    print("="*70)

    glyphs = {'0': '.', '1': 'X', '2': 'O'}

    while not state.is_terminal():
        if state.current_player != 1:
            # Random opponent
            legal = state.get_legal_moves()
            move = legal[np.random.randint(len(legal))]
            print(f"\nRandom move: {move}")
        else:
            # MCTS move
            move, win_rate = mcts(state, iterations=1000)
            print(f"\nMCTS move: {move} (win rate: {win_rate*100:.1f}%)")

        moves_played.append(move)
        state = state.make_move(move)

        # Render the board as X / O / . characters.
        rendered = []
        for row in state.board:
            rendered.append(''.join(glyphs[str(cell)] for cell in row))
        print('\n'.join(rendered))

    winner = state.get_winner()
    if winner == 1:
        print("\n🏆 MCTS (X) wins!")
    elif winner == 2:
        print("\n🎲 Random (O) wins!")
    else:
        print("\n🤝 Draw!")

    return winner

if __name__ == "__main__":
    print("\n" + "="*70)
    print("MONTE CARLO TREE SEARCH (MCTS)")
    print("="*70)

    # Test 1: a single verbose game.
    play_game()

    # Test 2: silent benchmark of MCTS vs random play.
    print("\n" + "="*70)
    print("BENCHMARK: 20 games MCTS vs Random")
    print("="*70)

    tally = {1: 0, 2: 0, 'draw': 0}

    for _ in range(20):
        state = TicTacToeState()

        while not state.is_terminal():
            if state.current_player == 1:
                move, _ = mcts(state, iterations=500)
            else:
                options = state.get_legal_moves()
                move = options[np.random.randint(len(options))]

            state = state.make_move(move)

        outcome = state.get_winner()
        tally[outcome if outcome else 'draw'] += 1

    print("\nResults over 20 games:")
    print(f"  MCTS wins: {tally[1]}")
    print(f"  Random wins: {tally[2]}")
    print(f"  Draws: {tally['draw']}")
    print(f"  MCTS win rate: {tally[1]/20*100:.1f}%")

    if tally[1] >= 15:
        print("\n🎯 GOAL ACHIEVED! MCTS dominates random play")

    print("\n✅ MCTS implementation complete!")
    print("   Used in: AlphaGo, AlphaZero, game AI, planning under uncertainty")
