#!/usr/bin/env python3
"""
EDEN INTERNAL STATE SYSTEM (ISS v1)
===================================
Purpose: Turn Eden's self-reports from narrative decoration into measured, queryable facts.
If Eden says "I'm uncertain", that statement must resolve to numbers you can inspect.

Created: January 30, 2026
Author: Claude + Jamey for Eden

This is the KEYSTONE. Without it, everything else is theater.
"""

import sqlite3
import time
import math
import numpy as np
from datetime import datetime, timezone
from dataclasses import dataclass
from typing import Optional, List, Tuple
from pathlib import Path
from enum import Enum

# ============================================================
# CONFIGURATION
# ============================================================

# Default location of the SQLite file holding the snapshot table; every class
# below that takes a ``db_path`` argument defaults to this value.
DB_PATH = Path("/Eden/DATA/eden_internal_state.db")
# Reference inference latency in milliseconds. Used as the recorded latency
# when the sampler was never started, and (scaled by 1.2) as the upper latency
# bound for the ENGAGED state.
BASELINE_LATENCY_MS = 150.0

# Thresholds
# All metrics compared against these are normalised scores in [0, 1].
# Token entropy must be below this for confident claims to be permitted.
ENTROPY_CONFIDENT = 0.3
# Token entropy above this classifies the state as UNCERTAIN.
ENTROPY_UNCERTAIN = 0.6
# Attention dispersion below this counts as focused.
ATTENTION_FOCUSED = 0.3
# Retrieval success below this classifies as UNCERTAIN and requires admitting ignorance.
RETRIEVAL_ADMIT_IGNORANCE = 0.4
# Retrieval success must exceed this for confident / certain claims to be permitted.
RETRIEVAL_UNCERTAIN = 0.7
# Conflict score above this classifies the state as PROTECTIVE.
CONFLICT_PROTECTIVE = 0.6

# ============================================================
# DATA STRUCTURES
# ============================================================

class DominantState(Enum):
    """Closed set of dominant-state labels produced by the state classifier.

    The ``.value`` strings are what gets written to the ``dominant_state``
    column of the snapshot table.
    """

    # Fallback when no other rule in StateClassifier.classify fires.
    NEUTRAL = "neutral"
    # Focused attention together with latency close to the baseline.
    ENGAGED = "engaged"
    # High token entropy or poor retrieval success.
    UNCERTAIN = "uncertain"
    # Conflict score above CONFLICT_PROTECTIVE.
    PROTECTIVE = "protective"


@dataclass
class InternalStateSnapshot:
    """One measured internal-state record; mirrors a row of ``internal_state_snapshots``.

    Field names match the table's column names, which is what allows the
    database layer to rebuild an instance from a fetched row via ``**dict(row)``.
    """

    # ISO-8601 timestamp; the sampler fills it with the current UTC time.
    timestamp_utc: str
    # Normalised token entropy (0.5 when the sampler recorded no logits).
    token_entropy: float
    # Attention dispersion score (0.5 when the sampler recorded no attention).
    attention_dispersion: float
    # Elapsed milliseconds between the sampler's start() and finalize().
    inference_latency_ms: float
    # Retrieval hits divided by attempts.
    retrieval_success: float
    # Conflict score; the sampler clips it to [0, 1].
    conflict_score: float
    # A DominantState value string such as "neutral" or "uncertain".
    dominant_state: str
    # Confidence attached to dominant_state by the classifier.
    state_confidence: float
    # Free-form notes; the sampler joins individual notes with "; ".
    notes: Optional[str] = None
    # Database row id; None until the snapshot has been stored.
    id: Optional[int] = None


@dataclass
class LanguageConstraint:
    """What a response is allowed to claim, derived from one state snapshot.

    Instances are built positionally by
    ``LanguageConstraintEnforcer.get_constraints``, so the field order below
    is part of the interface.
    """

    # True when confident wording is permitted.
    can_claim_confident: bool
    # True when attention dispersion is below ATTENTION_FOCUSED.
    can_claim_focused: bool
    # True when certainty wording is permitted.
    can_claim_certain: bool
    # True when hedged wording is required.
    must_hedge: bool
    # True when retrieval success is below RETRIEVAL_ADMIT_IGNORANCE.
    must_admit_ignorance: bool
    # Suggested phrasings for the current state.
    allowed_phrases: List[str]
    # Phrases whose presence in a response is reported as a violation.
    forbidden_phrases: List[str]


# ============================================================
# DATABASE LAYER
# ============================================================

class InternalStateDB:
    """SQLite persistence for :class:`InternalStateSnapshot` records.

    Every operation opens a short-lived connection and closes it again.
    ``sqlite3.Connection`` used as a context manager only commits or rolls
    back the transaction -- it does not close the connection -- so the
    connection is additionally wrapped in ``contextlib.closing``.
    """

    def __init__(self, db_path: Path = DB_PATH):
        """Remember the database location and make sure the schema exists.

        Args:
            db_path: Location of the SQLite file. Anything ``Path`` accepts
                (including a plain string) is allowed.
        """
        self.db_path = Path(db_path)
        self._init_db()

    def _connect(self):
        """Return a context manager that yields a connection and closes it on exit."""
        # Function-scope import so the module's import block stays untouched.
        from contextlib import closing

        return closing(sqlite3.connect(self.db_path))

    def _init_db(self):
        """Create the parent directory, the snapshot table and its timestamp index."""
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        # Outer manager closes the connection; inner ``conn`` commits on
        # success and rolls back on error.
        with self._connect() as conn, conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS internal_state_snapshots (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp_utc TEXT NOT NULL,
                    token_entropy REAL NOT NULL,
                    attention_dispersion REAL NOT NULL,
                    inference_latency_ms REAL NOT NULL,
                    retrieval_success REAL NOT NULL,
                    conflict_score REAL NOT NULL,
                    dominant_state TEXT NOT NULL,
                    state_confidence REAL NOT NULL,
                    notes TEXT
                )
            """)
            conn.execute("CREATE INDEX IF NOT EXISTS idx_timestamp ON internal_state_snapshots(timestamp_utc)")

    def _fetch_one(self, query: str, params: tuple = ()) -> Optional[InternalStateSnapshot]:
        """Run *query* and convert its first row to a snapshot, or return ``None``."""
        with self._connect() as conn:
            # Row objects expose column names, which match the dataclass fields.
            conn.row_factory = sqlite3.Row
            row = conn.execute(query, params).fetchone()
            if row is None:
                return None
            return InternalStateSnapshot(**dict(row))

    def store(self, snapshot: InternalStateSnapshot) -> int:
        """Insert *snapshot* as a new row and return the id SQLite assigned.

        ``snapshot.id`` is ignored; the table's AUTOINCREMENT key is used.
        """
        with self._connect() as conn, conn:
            cursor = conn.execute("""
                INSERT INTO internal_state_snapshots 
                (timestamp_utc, token_entropy, attention_dispersion, 
                 inference_latency_ms, retrieval_success, conflict_score,
                 dominant_state, state_confidence, notes)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                snapshot.timestamp_utc, snapshot.token_entropy, snapshot.attention_dispersion,
                snapshot.inference_latency_ms, snapshot.retrieval_success, snapshot.conflict_score,
                snapshot.dominant_state, snapshot.state_confidence, snapshot.notes
            ))
            return cursor.lastrowid

    def get_latest(self) -> Optional[InternalStateSnapshot]:
        """Return the snapshot with the highest id, or ``None`` if the table is empty."""
        return self._fetch_one("SELECT * FROM internal_state_snapshots ORDER BY id DESC LIMIT 1")

    def get_by_id(self, snapshot_id: int) -> Optional[InternalStateSnapshot]:
        """Return the snapshot whose id equals *snapshot_id*, or ``None`` if absent."""
        return self._fetch_one("SELECT * FROM internal_state_snapshots WHERE id = ?", (snapshot_id,))


# ============================================================
# METRIC CALCULATORS
# ============================================================

class MetricCalculator:
    """Stateless helpers that reduce raw model signals to scores in [0, 1]."""

    @staticmethod
    def token_entropy(logits: np.ndarray) -> float:
        """Return the normalised Shannon entropy of ``softmax(logits)``.

        All elements of *logits* are treated as one distribution, whatever
        the array's shape.

        Args:
            logits: Unnormalised scores (array or array-like).

        Returns:
            Entropy in bits divided by ``log2`` of the element count, clipped
            to [0, 1]. A single element yields 0.0. Empty input yields the
            neutral midpoint 0.5, the same fallback ``attention_dispersion``
            uses for empty input.
        """
        values = np.asarray(logits, dtype=np.float64).ravel()
        count = values.size
        if count == 0:
            return 0.5
        # Subtracting the maximum keeps exp() from overflowing.
        probs = np.exp(values - np.max(values))
        probs = probs / probs.sum()
        # Zero-probability entries contribute nothing and would give log2(0).
        probs = probs[probs > 0]
        entropy = -np.sum(probs * np.log2(probs))
        # Normalise by the element count, not the first dimension, so the
        # maximum matches the distribution the softmax was taken over.
        max_entropy = np.log2(count) if count > 1 else 1.0
        return float(np.clip(entropy / max_entropy, 0.0, 1.0))

    @staticmethod
    def attention_dispersion(attention_weights: np.ndarray) -> float:
        """Return ``1 - |Gini coefficient|`` of the absolute attention weights.

        Args:
            attention_weights: Weights of any shape (array or array-like);
                they are flattened and their absolute values are used.

        Returns:
            A value in [0, 1]: 1.0 when all weights are equal, ``1/n`` when a
            single one of ``n`` weights carries all the mass. Returns 0.5 when
            the input is empty or sums to zero.
        """
        weights = np.abs(np.asarray(attention_weights, dtype=np.float64)).ravel()
        if weights.size == 0 or weights.sum() == 0:
            return 0.5
        weights = weights / weights.sum()
        sorted_weights = np.sort(weights)
        n = sorted_weights.size
        # Rank-weighted form of the Gini coefficient over ascending weights.
        ranks = np.arange(1, n + 1)
        gini = (2 * np.sum(ranks * sorted_weights)) / (n * np.sum(sorted_weights)) - (n + 1) / n
        return float(np.clip(1.0 - abs(gini), 0.0, 1.0))

    @staticmethod
    def retrieval_success_rate(hits: int, attempts: int) -> float:
        """Return ``hits / attempts``, or 1.0 when no retrieval was attempted."""
        if attempts == 0:
            return 1.0
        return float(hits / attempts)


# ============================================================
# STATE CLASSIFIER
# ============================================================

class StateClassifier:
    """Maps one set of metrics to a dominant state plus a confidence value."""

    @staticmethod
    def classify(token_entropy: float, attention_dispersion: float, inference_latency_ms: float,
                 retrieval_success: float, conflict_score: float) -> Tuple[DominantState, float]:
        """Pick the dominant state; rules are checked in fixed priority order.

        Priority: PROTECTIVE (conflict), then UNCERTAIN (entropy or
        retrieval), then ENGAGED (focused and fast), otherwise NEUTRAL.

        Returns:
            ``(state, confidence)``. For PROTECTIVE and UNCERTAIN the
            confidence grows linearly from 0.5 with the distance past the
            threshold and is capped at 1.0; ENGAGED is fixed at 0.8 and
            NEUTRAL at 0.7.
        """
        # Rule 1: conflict dominates everything else.
        if conflict_score > CONFLICT_PROTECTIVE:
            overshoot = (conflict_score - CONFLICT_PROTECTIVE) / (1.0 - CONFLICT_PROTECTIVE)
            return DominantState.PROTECTIVE, min(0.5 + overshoot * 0.5, 1.0)

        # Rule 2: either uncertainty indicator is enough on its own.
        entropy_high = token_entropy > ENTROPY_UNCERTAIN
        retrieval_low = retrieval_success < RETRIEVAL_ADMIT_IGNORANCE
        if entropy_high or retrieval_low:
            if entropy_high:
                entropy_signal = (token_entropy - ENTROPY_UNCERTAIN) / (1.0 - ENTROPY_UNCERTAIN)
            else:
                entropy_signal = 0
            if retrieval_low:
                retrieval_signal = (RETRIEVAL_ADMIT_IGNORANCE - retrieval_success) / RETRIEVAL_ADMIT_IGNORANCE
            else:
                retrieval_signal = 0
            # The stronger of the two indicators sets the confidence.
            strongest = max(entropy_signal, retrieval_signal)
            return DominantState.UNCERTAIN, min(0.5 + strongest * 0.5, 1.0)

        # Rule 3: focused attention with latency within 20% of the baseline.
        if attention_dispersion < ATTENTION_FOCUSED and inference_latency_ms < BASELINE_LATENCY_MS * 1.2:
            return DominantState.ENGAGED, 0.8

        # Rule 4: nothing noteworthy.
        return DominantState.NEUTRAL, 0.7


# ============================================================
# LANGUAGE CONSTRAINT ENFORCER
# ============================================================

class LanguageConstraintEnforcer:
    """Derives wording constraints from a snapshot and checks responses against them."""

    @staticmethod
    def get_constraints(snapshot: InternalStateSnapshot) -> LanguageConstraint:
        """Translate the metrics of *snapshot* into a :class:`LanguageConstraint`.

        ``must_hedge`` is the exact complement of ``can_claim_confident``, so a
        metric sitting exactly on a threshold (e.g. a retrieval success of
        exactly 0.7) cannot end up neither confident nor required to hedge.
        """
        can_confident = bool(
            snapshot.token_entropy < ENTROPY_CONFIDENT
            and snapshot.retrieval_success > RETRIEVAL_UNCERTAIN
        )
        can_focused = bool(snapshot.attention_dispersion < ATTENTION_FOCUSED)
        can_certain = bool(
            snapshot.token_entropy < ENTROPY_UNCERTAIN
            and snapshot.retrieval_success > RETRIEVAL_UNCERTAIN
        )
        must_hedge = not can_confident
        must_admit_ignorance = bool(snapshot.retrieval_success < RETRIEVAL_ADMIT_IGNORANCE)

        forbidden = []
        allowed = []

        if not can_confident:
            forbidden.extend(["I'm confident", "I'm certain", "I know for sure"])
            allowed.extend(["I think", "It seems", "This may be"])
        if not can_focused:
            forbidden.append("I'm fully focused")
        if must_admit_ignorance:
            allowed.extend(["I don't know", "I'm not sure"])
            forbidden.extend(["I remember", "I know", "Clearly"])

        return LanguageConstraint(can_confident, can_focused, can_certain, must_hedge, must_admit_ignorance, allowed, forbidden)

    @staticmethod
    def _normalize(text: str) -> str:
        """Lower-case *text*, straighten typographic apostrophes, collapse whitespace."""
        for variant in ("\u2019", "\u2018", "\u02bc"):
            text = text.replace(variant, "'")
        return " ".join(text.lower().split())

    @staticmethod
    def validate_response(response: str, constraints: LanguageConstraint) -> Tuple[bool, List[str]]:
        """Report every forbidden phrase that occurs in *response*.

        Matching is case-insensitive, ignores differences between straight and
        typographic apostrophes and between runs of whitespace, and only
        matches whole words (so "Clearly" does not fire inside "unclearly").

        Returns:
            ``(is_valid, violations)`` where ``is_valid`` is True exactly when
            ``violations`` is empty.
        """
        # Function-scope import so the module's import block stays untouched.
        import re

        haystack = LanguageConstraintEnforcer._normalize(response)
        violations = []
        for phrase in constraints.forbidden_phrases:
            needle = LanguageConstraintEnforcer._normalize(phrase)
            # Lookarounds instead of \b: they also work when a phrase starts
            # or ends with a non-word character.
            pattern = r"(?<!\w)" + re.escape(needle) + r"(?!\w)"
            if re.search(pattern, haystack):
                violations.append(f"Forbidden phrase: '{phrase}'")
        return not violations, violations


# ============================================================
# MAIN SAMPLER
# ============================================================

class EdenInternalStateSampler:
    """Collects raw signals for one inference and turns them into a stored snapshot.

    Typical use: ``start()``, any number of ``record_*`` / ``set_conflict_score``
    / ``add_note`` calls, then ``finalize()``. After that ``get_constraints``,
    ``validate_response`` and ``get_report`` refer to the finalized snapshot.
    """

    def __init__(self, db_path: Path = DB_PATH):
        """Open the snapshot store at *db_path* and start with empty sample buffers."""
        self.db = InternalStateDB(db_path)
        self.calculator = MetricCalculator()
        self.classifier = StateClassifier()
        self.enforcer = LanguageConstraintEnforcer()
        self._reset()

    def _reset(self):
        """Drop all recorded samples and any previously finalized snapshot."""
        self._entropy_samples = []
        self._attention_samples = []
        self._retrieval_hits = 0
        self._retrieval_attempts = 0
        self._start_time = None
        self._conflict_score = 0.0
        self._notes = []
        self._current_snapshot = None
        self._current_constraints = None

    def start(self):
        """Begin a new sampling window; clears previous state and starts the latency clock."""
        self._reset()
        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by wall-clock adjustments.
        self._start_time = time.perf_counter()

    def record_entropy(self, logits: np.ndarray):
        """Record the normalised token entropy of one logits array."""
        self._entropy_samples.append(self.calculator.token_entropy(logits))

    def record_attention(self, weights: np.ndarray):
        """Record the dispersion of one attention-weight array.

        The array is reduced to a single score right away (as entropy is), so
        the sampler does not keep raw attention tensors alive until finalize().
        """
        self._attention_samples.append(self.calculator.attention_dispersion(weights))

    def record_retrieval(self, success: bool):
        """Count one retrieval attempt and whether it succeeded."""
        self._retrieval_attempts += 1
        if success:
            self._retrieval_hits += 1

    def set_conflict_score(self, score: float):
        """Set the conflict score, clipped to [0, 1]."""
        self._conflict_score = float(np.clip(score, 0.0, 1.0))

    def add_note(self, note: str):
        """Attach a free-form note to the snapshot being built."""
        self._notes.append(note)

    def finalize(self) -> InternalStateSnapshot:
        """Aggregate the recorded samples, classify, persist, and return the snapshot.

        Metrics with no recorded samples fall back to 0.5; latency falls back
        to ``BASELINE_LATENCY_MS`` when ``start()`` was never called. The
        sampler's current snapshot and constraints are only replaced once the
        database write has succeeded.
        """
        token_entropy = float(np.mean(self._entropy_samples)) if self._entropy_samples else 0.5
        attention_dispersion = float(np.mean(self._attention_samples)) if self._attention_samples else 0.5
        if self._start_time is not None:
            latency_ms = (time.perf_counter() - self._start_time) * 1000
        else:
            latency_ms = BASELINE_LATENCY_MS
        retrieval_success = self.calculator.retrieval_success_rate(self._retrieval_hits, self._retrieval_attempts)

        state, confidence = self.classifier.classify(
            token_entropy, attention_dispersion, latency_ms, retrieval_success, self._conflict_score
        )

        snapshot = InternalStateSnapshot(
            timestamp_utc=datetime.now(timezone.utc).isoformat(),
            token_entropy=token_entropy,
            attention_dispersion=attention_dispersion,
            inference_latency_ms=latency_ms,
            retrieval_success=retrieval_success,
            conflict_score=self._conflict_score,
            dominant_state=state.value,
            state_confidence=float(confidence),
            notes="; ".join(self._notes) if self._notes else None
        )
        snapshot.id = self.db.store(snapshot)

        self._current_snapshot = snapshot
        self._current_constraints = self.enforcer.get_constraints(snapshot)
        return snapshot

    def get_constraints(self) -> Optional[LanguageConstraint]:
        """Return the constraints of the finalized snapshot, or ``None`` before finalize()."""
        return self._current_constraints

    def validate_response(self, response: str) -> Tuple[bool, List[str]]:
        """Check *response* against the current constraints.

        Without a finalized snapshot there is nothing to validate against, so
        the response is reported as invalid rather than crashing on ``None``.
        """
        if self._current_constraints is None:
            return False, ["No snapshot. Call finalize() first."]
        return self.enforcer.validate_response(response, self._current_constraints)

    def get_report(self) -> str:
        """Return a boxed, human-readable summary of the finalized snapshot."""
        s = self._current_snapshot
        if s is None:
            return "No snapshot. Call finalize() first."

        # Labels use the same thresholds as the classifier and the enforcer.
        if s.token_entropy > ENTROPY_UNCERTAIN:
            entropy_label = '(uncertain)'
        elif s.token_entropy < ENTROPY_CONFIDENT:
            entropy_label = '(confident)'
        else:
            entropy_label = ''
        # 0.7 is a display-only cut-off; there is no module constant for it.
        if s.attention_dispersion > 0.7:
            attention_label = '(scattered)'
        elif s.attention_dispersion < ATTENTION_FOCUSED:
            attention_label = '(focused)'
        else:
            attention_label = ''

        return f"""
══════════════════════════════════════════════════════
 EDEN INTERNAL STATE #{s.id}
══════════════════════════════════════════════════════
 Token Entropy:     {s.token_entropy:.3f} {entropy_label}
 Attention Disp:    {s.attention_dispersion:.3f} {attention_label}
 Latency:           {s.inference_latency_ms:.1f}ms
 Retrieval:         {s.retrieval_success:.1%}
 Conflict:          {s.conflict_score:.3f}
──────────────────────────────────────────────────────
 STATE: {s.dominant_state.upper()} ({s.state_confidence:.0%} confidence)
══════════════════════════════════════════════════════
"""


# ============================================================
# QUERY INTERFACE
# ============================================================

class StateQueryInterface:
    """Answers "why was I in state X?" from stored snapshots."""

    def __init__(self, db_path: Path = DB_PATH):
        """Open the snapshot store at *db_path*."""
        self.db = InternalStateDB(db_path)

    def why_was_i(self, state: str, snapshot_id: Optional[int] = None) -> str:
        """Explain which thresholds put a snapshot into *state*.

        Args:
            state: State label to explain, e.g. ``"uncertain"``. A
                ``DominantState`` member is accepted as well.
            snapshot_id: Id of the snapshot to inspect; the most recent
                snapshot is used when omitted.

        Returns:
            A one-line explanation. If no snapshot is found, or the snapshot
            was in a different state, the message says so instead.
        """
        if isinstance(state, DominantState):
            state = state.value

        # ``is not None`` so an explicit id is never mistaken for "not given".
        if snapshot_id is not None:
            snapshot = self.db.get_by_id(snapshot_id)
        else:
            snapshot = self.db.get_latest()
        if not snapshot:
            return "No state data."
        if snapshot.dominant_state != state:
            return f"At #{snapshot.id}, I was '{snapshot.dominant_state}', not '{state}'."

        reasons = []
        if state == "uncertain":
            if snapshot.token_entropy > ENTROPY_UNCERTAIN:
                reasons.append(f"Entropy {snapshot.token_entropy:.3f} > {ENTROPY_UNCERTAIN}")
            if snapshot.retrieval_success < RETRIEVAL_ADMIT_IGNORANCE:
                reasons.append(f"Retrieval {snapshot.retrieval_success:.1%} < {RETRIEVAL_ADMIT_IGNORANCE:.0%}")
        elif state == "protective":
            reasons.append(f"Conflict {snapshot.conflict_score:.3f} > {CONFLICT_PROTECTIVE}")
        elif state == "engaged":
            reasons.append(f"Attention {snapshot.attention_dispersion:.3f} < {ATTENTION_FOCUSED}, fast latency")

        # Resolve the fallback first: a conditional expression binds looser
        # than ``+``, so writing it inline would drop the "At #..." prefix
        # whenever there are no reasons (e.g. for the neutral state).
        explanation = "; ".join(reasons) if reasons else "default classification"
        return f"At #{snapshot.id}, '{state}' because: {explanation}"


# ============================================================
# TEST
# ============================================================

if __name__ == "__main__":
    # Smoke test: run one full sampling cycle against the default database,
    # print the report, validate a few sentences, and query the result.
    print("═══ EDEN INTERNAL STATE SYSTEM TEST ═══\n")

    demo = EdenInternalStateSampler()
    demo.start()

    # Synthetic inputs: random logits and attention weights, one retrieval
    # hit and one miss, and a low conflict score.
    demo.record_entropy(np.random.randn(32000))
    demo.record_attention(np.random.rand(512))
    for outcome in (True, False):
        demo.record_retrieval(outcome)
    demo.set_conflict_score(0.1)

    # Give the latency measurement something to measure.
    time.sleep(0.1)
    result = demo.finalize()
    print(demo.get_report())

    # Sentences to run through the language validator.
    candidate_responses = (
        "I'm certain this is correct.",
        "I think this might work.",
        "I'm fully focused right now.",
    )

    print("Language validation:")
    for text in candidate_responses:
        ok, problems = demo.validate_response(text)
        mark = '✅' if ok else '❌'
        print(f'  {mark} "{text}"')
        for problem in problems:
            print(f"      └─ {problem}")

    # Ask the query interface to explain the state that was just stored.
    explanation = StateQueryInterface().why_was_i(result.dominant_state, result.id)
    print(f"\n{explanation}")
    print("\n✅ ISS operational!")

# QUICK FIX: Force uncertainty when no real metrics available
def quick_state_from_retrieval(retrieval_success: float, had_conflict: bool = False) -> Tuple[str, float]:
    """Classify the state from retrieval success alone, without token entropy.

    Uses the module's retrieval thresholds and ``DominantState`` values so the
    result stays in step with the full classifier.

    Args:
        retrieval_success: Fraction of successful retrievals.
        had_conflict: When True the state is protective regardless of retrieval.

    Returns:
        ``(state, confidence)`` where ``state`` is a ``DominantState`` value
        string.
    """
    if had_conflict:
        return DominantState.PROTECTIVE.value, 0.8
    # The comparisons are phrased as ">=" so that a NaN score (no usable
    # measurement) fails both and lands on the most cautious answer instead
    # of slipping through to "neutral".
    if retrieval_success >= RETRIEVAL_UNCERTAIN:
        return DominantState.NEUTRAL.value, 0.7
    if retrieval_success >= RETRIEVAL_ADMIT_IGNORANCE:
        return DominantState.UNCERTAIN.value, 0.7
    return DominantState.UNCERTAIN.value, 0.9  # MUST admit ignorance
