"""
Q-Learning - Model-free reinforcement learning
Learn optimal policies through trial and error
"""
import numpy as np
import time

class GridWorld:
    """Simple deterministic grid environment for RL.

    The agent starts at ``(0, 0)`` and must reach ``(size-1, size-1)``.
    A move that would leave the grid or enter an obstacle cell leaves the
    agent where it is.

    Args:
        size: Side length of the square grid (must be >= 1).
        obstacles: Optional iterable of ``(row, col)`` cells to block. When
            omitted, ``DEFAULT_OBSTACLES`` is used. Cells outside the grid or
            on the start/goal cell are ignored.
    """

    # Default obstacle layout as (row, col) cells.
    DEFAULT_OBSTACLES = ((1, 1), (2, 2), (3, 1))
    # Row/column delta for each action: 0=up, 1=down, 2=left, 3=right.
    MOVES = ((-1, 0), (1, 0), (0, -1), (0, 1))

    def __init__(self, size=5, obstacles=None):
        if size < 1:
            raise ValueError(f"size must be >= 1, got {size}")
        self.size = size
        self.start = (0, 0)
        self.goal = (size-1, size-1)
        candidates = self.DEFAULT_OBSTACLES if obstacles is None else obstacles
        # An obstacle on the goal (e.g. (2, 2) when size == 3) would make the
        # goal unreachable, and one on the start is meaningless, so both are
        # dropped. Off-grid cells can never be entered, so they are dropped too.
        self.obstacles = []
        for cell in candidates:
            cell = tuple(cell)
            if self._in_bounds(cell) and cell not in (self.start, self.goal):
                self.obstacles.append(cell)
        self.reset()

    @property
    def n_states(self):
        """Number of cells in the grid (size of a tabular state space)."""
        return self.size * self.size

    def state_index(self, pos=None):
        """Flatten ``pos`` (default: the current position) to a row-major index."""
        row, col = self.pos if pos is None else pos
        return row * self.size + col

    def _in_bounds(self, pos):
        """Return True if ``pos`` lies inside the grid."""
        return 0 <= pos[0] < self.size and 0 <= pos[1] < self.size

    def reset(self):
        """Move the agent back to the start cell and return that position."""
        self.pos = self.start
        return self.pos

    def step(self, action):
        """Apply ``action`` and return ``(position, reward, done)``.

        Actions: 0=up, 1=down, 2=left, 3=right. Blocked moves (off-grid or
        into an obstacle) leave the position unchanged. Being on the goal
        after the move yields reward 10 and ``done=True``; otherwise the
        step costs -0.1.
        """
        d_row, d_col = self.MOVES[action]
        new_pos = (self.pos[0] + d_row, self.pos[1] + d_col)

        if self._in_bounds(new_pos) and new_pos not in self.obstacles:
            self.pos = new_pos

        if self.pos == self.goal:
            return self.pos, 10, True
        return self.pos, -0.1, False

class QLearner:
    """Tabular Q-learning agent with epsilon-greedy exploration.

    Args:
        n_states: Number of discrete states (rows of the Q-table).
        n_actions: Number of discrete actions (columns of the Q-table).
        lr: Learning rate applied to each temporal-difference error.
        gamma: Discount factor for future rewards.
        epsilon: Probability of a random action while training.
    """

    def __init__(self, n_states, n_actions, lr=0.1, gamma=0.95, epsilon=0.1):
        # Q-values start at zero for every (state, action) pair.
        self.q_table = np.zeros((n_states, n_actions))
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.n_actions = n_actions

    def get_action(self, state_idx, training=True):
        """Choose an action for ``state_idx`` epsilon-greedily.

        When ``training`` is true, a uniformly random action is returned with
        probability ``epsilon``. Otherwise the greedy action is returned;
        ``np.argmax`` breaks ties toward the lowest action index.
        """
        if training and np.random.rand() < self.epsilon:
            return np.random.randint(self.n_actions)
        # Convert to a plain int so both branches return the same type.
        return int(np.argmax(self.q_table[state_idx]))

    def update(self, state_idx, action, reward, next_state_idx, done=False):
        """Apply one Q-learning update for the transition (s, a, r, s').

        Args:
            state_idx: Index of the state the action was taken in.
            action: Index of the action taken.
            reward: Immediate reward received.
            next_state_idx: Index of the resulting state.
            done: True if the resulting state is terminal. The target is then
                just ``reward``, with no bootstrap from the next state.
        """
        if done:
            td_target = reward
        else:
            td_target = reward + self.gamma * np.max(self.q_table[next_state_idx])
        td_error = td_target - self.q_table[state_idx, action]
        self.q_table[state_idx, action] += self.lr * td_error

def _state_index(env, pos):
    """Flatten a (row, col) grid position into a row-major Q-table index."""
    return pos[0] * env.size + pos[1]


def _run_episode(env, agent, max_steps, training):
    """Run one episode of at most ``max_steps`` steps from ``env.reset()``.

    When ``training`` is true, the agent acts epsilon-greedily and its
    Q-table is updated after every transition. Otherwise it acts greedily
    and the Q-table is not modified.

    Returns:
        ``(total_reward, reached_goal)`` for the episode.
    """
    state_idx = _state_index(env, env.reset())
    total_reward = 0
    for _ in range(max_steps):
        action = agent.get_action(state_idx, training=training)
        next_state, reward, done = env.step(action)
        next_state_idx = _state_index(env, next_state)
        if training:
            agent.update(state_idx, action, reward, next_state_idx)
        state_idx = next_state_idx
        total_reward += reward
        if done:
            return total_reward, True
    return total_reward, False


def train_q_learning(episodes=1000, *, size=5, max_steps=100, test_episodes=100,
                     test_max_steps=50, log_every=200, seed=None):
    """Train a Q-learning agent on a GridWorld, then evaluate its greedy policy.

    Args:
        episodes: Number of training episodes.
        size: Side length of the grid world.
        max_steps: Step cap per training episode.
        test_episodes: Number of greedy evaluation episodes (must be >= 1).
        test_max_steps: Step cap per evaluation episode.
        log_every: Print a progress line every this many training episodes.
            Use 0 to disable progress lines.
        seed: If given, seeds NumPy's global RNG for reproducible runs.

    Returns:
        ``(avg_test_reward, success_rate)``. ``success_rate`` is the fraction
        of evaluation episodes that reached the goal.

    Raises:
        ValueError: If ``test_episodes`` is less than 1.
    """
    if test_episodes < 1:
        raise ValueError(f"test_episodes must be >= 1, got {test_episodes}")
    if seed is not None:
        np.random.seed(seed)

    print("\n" + "="*70)
    print("Q-LEARNING REINFORCEMENT LEARNING")
    print(f"Training for {episodes} episodes")
    print("="*70)

    env = GridWorld(size=size)
    # Size the Q-table from the environment rather than hard-coding 25, so a
    # different grid size cannot produce out-of-range state indices.
    agent = QLearner(n_states=env.size * env.size, n_actions=4)

    rewards_history = []
    for ep in range(episodes):
        total_reward, _ = _run_episode(env, agent, max_steps, training=True)
        rewards_history.append(total_reward)

        if log_every and (ep + 1) % log_every == 0:
            avg_reward = np.mean(rewards_history[-100:])
            print(f"Episode {ep+1}/{episodes} - Avg reward (last 100): {avg_reward:.2f}")

    # Test learned policy
    print("\n" + "="*70)
    print("TESTING LEARNED POLICY")
    print("="*70)

    test_rewards = []
    successes = 0
    for _ in range(test_episodes):
        total_reward, reached_goal = _run_episode(env, agent, test_max_steps,
                                                  training=False)
        test_rewards.append(total_reward)
        # Count success from the done flag rather than a reward threshold,
        # which only held for the default reward and step-limit values.
        if reached_goal:
            successes += 1

    avg_test = np.mean(test_rewards)
    success_rate = successes / test_episodes

    print(f"\nTest Results ({test_episodes} episodes):")
    print(f"  Average reward: {avg_test:.2f}")
    print(f"  Success rate: {success_rate*100:.1f}%")

    if success_rate >= 0.8:
        print("\n🎯 GOAL ACHIEVED! >80% success rate")

    return avg_test, success_rate

if __name__ == "__main__":
    # Script entry point: run the full train-then-evaluate demo, then echo
    # the headline success metric once more as a summary.
    mean_test_reward, goal_fraction = train_q_learning(episodes=1000)
    print("\n✅ Q-Learning implementation complete!")
    print(f"   Success rate: {goal_fraction*100:.1f}%")
