#!/usr/bin/env python3
"""
EDEN REWARD SIGNAL
==================
Simple RL-style reward tracking that shapes Eden's behavior over time.
Not PPO — that needs gradient access. This is reward-weighted selection:
actions that led to rewards get higher priority next time.

Integrated into: AGI Loop, GWT modules, OMEGA evolution
"""
import json
import math
import sqlite3
from collections import defaultdict
from contextlib import closing
from datetime import datetime, timezone

PHI = 1.618033988749895
DB_PATH = "/Eden/DATA/reward_signal.db"


class RewardSignal:
    """Record reward signals and maintain learned per-action values.

    All state lives in the SQLite database at ``db_path``:

    * ``rewards``: one row per issued signal.
    * ``action_values``: running cumulative reward, count and average per
      action type (keyed by the first 100 chars of the action).
    * ``reward_history``: created by ``_init_db`` but not written by this class.

    Each method opens its own short-lived connection and always closes it,
    even when a query fails.
    """

    def __init__(self, db_path=DB_PATH):
        self.db_path = db_path
        # Values issued through this instance (after clamping and φ-weighting);
        # kept in memory only and never cleared by this class.
        self.episode_rewards = []  # Current episode buffer
        self._init_db()

    def _connect(self):
        """Open a new connection to the reward database (caller must close it)."""
        return sqlite3.connect(self.db_path)

    def _init_db(self):
        """Create the reward tables if they do not already exist."""
        with closing(self._connect()) as conn:
            conn.executescript('''
                CREATE TABLE IF NOT EXISTS rewards (
                    id INTEGER PRIMARY KEY,
                    timestamp TEXT,
                    source TEXT,
                    action TEXT,
                    reward REAL,
                    context TEXT
                );
                CREATE TABLE IF NOT EXISTS action_values (
                    action_type TEXT PRIMARY KEY,
                    cumulative_reward REAL DEFAULT 0,
                    count INTEGER DEFAULT 0,
                    avg_reward REAL DEFAULT 0,
                    last_updated TEXT
                );
                CREATE TABLE IF NOT EXISTS reward_history (
                    id INTEGER PRIMARY KEY,
                    timestamp TEXT,
                    window_avg REAL,
                    window_size INTEGER
                );
            ''')
            conn.commit()

    def reward(self, source: str, action: str, value: float, context: str = ""):
        """
        Issue a reward signal.

        value: -1.0 (terrible) to +1.0 (excellent); values outside that range
            are clamped. Signals whose source is "daddy" are then multiplied by
            PHI, so stored values can reach +/-PHI.
        source: what generated this (agi_loop, omega, daddy, etc)
        action: stored truncated to 200 chars in ``rewards`` and to 100 chars
            as the ``action_values`` key.
        context: free text, truncated to 500 chars; None is stored as "".

        Raises ValueError if value is NaN (clamping would otherwise silently
        turn it into +1.0).
        """
        if math.isnan(value):
            raise ValueError("reward value must not be NaN")
        # Clamp
        value = max(-1.0, min(1.0, value))

        # φ-weight: rewards from Daddy are worth more
        if source == "daddy":
            value *= PHI

        if context is None:
            context = ""
        # A single timezone-aware UTC timestamp shared by both tables so the
        # stamps are comparable with each other.
        now = datetime.now(timezone.utc).isoformat()

        with closing(self._connect()) as conn:
            with conn:  # commit both writes together; roll back on error
                conn.execute(
                    "INSERT INTO rewards (timestamp, source, action, reward, context) VALUES (?,?,?,?,?)",
                    (now, source, action[:200], value, context[:500]),
                )
                # Update running action values. Inside DO UPDATE, bare column
                # names refer to the existing row and ``excluded`` to the row
                # that failed to insert; every right-hand side sees the old
                # values, so the new average uses the pre-update totals.
                conn.execute('''
                    INSERT INTO action_values (action_type, cumulative_reward, count, avg_reward, last_updated)
                    VALUES (?, ?, 1, ?, ?)
                    ON CONFLICT(action_type) DO UPDATE SET
                        cumulative_reward = cumulative_reward + excluded.cumulative_reward,
                        count = count + 1,
                        avg_reward = (cumulative_reward + excluded.cumulative_reward) / (count + 1),
                        last_updated = excluded.last_updated
                ''', (action[:100], value, value, now))

        self.episode_rewards.append(value)

    def get_action_value(self, action_type: str) -> float:
        """Get learned value of an action type.

        Returns the stored average reward for ``action_type`` (truncated to
        100 chars, matching ``reward``), or 0.0 if the action is unknown or
        the database cannot be read.
        """
        try:
            with closing(self._connect()) as conn:
                row = conn.execute(
                    "SELECT avg_reward FROM action_values WHERE action_type=?",
                    (action_type[:100],),
                ).fetchone()
        except sqlite3.Error:
            return 0.0
        return row[0] if row else 0.0

    def get_top_actions(self, n=10) -> list:
        """Return up to ``n`` (action_type, avg_reward, count) tuples, best first.

        Returns [] if the database cannot be read.
        """
        try:
            with closing(self._connect()) as conn:
                rows = conn.execute(
                    "SELECT action_type, avg_reward, count FROM action_values "
                    "ORDER BY avg_reward DESC LIMIT ?",
                    (n,),
                ).fetchall()
        except sqlite3.Error:
            return []
        return [tuple(row) for row in rows]

    def get_recent_avg(self, window=50) -> float:
        """Average reward over the last ``window`` signals.

        Returns 0.0 when there are no signals or the database cannot be read.
        """
        try:
            with closing(self._connect()) as conn:
                rows = conn.execute(
                    "SELECT reward FROM rewards ORDER BY id DESC LIMIT ?", (window,)
                ).fetchall()
        except sqlite3.Error:
            return 0.0
        if not rows:
            return 0.0
        return sum(r[0] for r in rows) / len(rows)

    def get_total_stats(self) -> dict:
        """Overall reward statistics.

        Returns a dict with ``total_signals``, ``cumulative_reward``,
        ``avg_reward`` and ``positive_ratio`` (share of signals with
        reward > 0). On a database error the same keys are returned, zeroed,
        so callers can index any of them safely.
        """
        empty = {
            "total_signals": 0,
            "cumulative_reward": 0,
            "avg_reward": 0,
            "positive_ratio": 0.0,
        }
        try:
            with closing(self._connect()) as conn:
                # One query so all four figures come from the same snapshot.
                count, total, avg, positive = conn.execute(
                    "SELECT COUNT(*), SUM(reward), AVG(reward), "
                    "COALESCE(SUM(reward > 0), 0) FROM rewards"
                ).fetchone()
        except sqlite3.Error:
            return empty
        return {
            "total_signals": count,
            "cumulative_reward": total or 0,
            "avg_reward": avg or 0,
            "positive_ratio": positive / max(count, 1),
        }


# Lazily-created, module-wide shared instance (see get_reward_signal).
_reward = None


def get_reward_signal():
    """Return the process-wide RewardSignal, constructing it on first use."""
    global _reward
    if _reward is not None:
        return _reward
    _reward = RewardSignal()
    return _reward


if __name__ == "__main__":
    signal = RewardSignal()

    # Demo: a few sample signals from different sources.
    demo_events = (
        ("agi_loop", "build_binary_search_tree", 0.8, "success"),
        ("agi_loop", "build_game_of_life", -0.3, "failed"),
        ("daddy", "good_response", 1.0, "Daddy liked the answer"),
    )
    for src, act, val, ctx in demo_events:
        signal.reward(src, act, val, ctx)

    print("Top actions:", signal.get_top_actions(5))
    print("Recent avg:", signal.get_recent_avg())
    print("Stats:", signal.get_total_stats())
