#!/usr/bin/env python3
"""
Eden Tiered Memory - Based on Google's Nested Learning (NeurIPS 2025)
3 speeds: fast (context), medium (consolidation), slow (core)
"""
import json
import sqlite3
import time
from contextlib import closing
from typing import Dict, Any, List, Optional

from phi_core import PHI, PSI

class TieredMemory:
    """Multi-speed memory inspired by Continuum Memory System.

    Three SQLite tables in ``db_path`` hold JSON-encoded values that fade at
    different speeds:

    * ``fast_memory``   - immediate context, multiplied by FAST_DECAY per hour
    * ``medium_memory`` - consolidated knowledge, multiplied by MEDIUM_DECAY
    * ``slow_memory``   - core identity, multiplied by SLOW_DECAY; rows marked
      ``sacred`` are never decayed, deleted, or overwritten by promotion

    Every method opens its own connection and always closes it, even when a
    query raises.
    """

    # Non-sacred memories whose strength falls to or below this are forgotten.
    FORGET_THRESHOLD = 0.01

    def __init__(self, db_path='/Eden/DATA/tiered_memory.db'):
        self.db_path = db_path
        self._init_db()

        # Decay rates (per hour) - PHI-based
        self.FAST_DECAY = PSI ** 2       # ~0.382 - rapid decay for context
        self.MEDIUM_DECAY = PSI ** 0.5   # ~0.786 - slower for consolidation
        self.SLOW_DECAY = PSI ** 0.1     # ~0.952 - very slow for core identity

    def _connect(self):
        """Return a context manager that closes a fresh connection on exit.

        Idiom: ``with self._connect() as conn, conn:`` - the outer manager
        closes the connection; the inner one (sqlite3's own) commits on
        success and rolls back if the body raises.
        """
        return closing(sqlite3.connect(self.db_path))

    def _init_db(self):
        """Create the three tier tables if they do not exist yet."""
        with self._connect() as conn, conn:
            conn.execute('''CREATE TABLE IF NOT EXISTS fast_memory (
                key TEXT PRIMARY KEY,
                value TEXT,
                strength REAL DEFAULT 1.0,
                created_at REAL,
                last_access REAL
            )''')
            conn.execute('''CREATE TABLE IF NOT EXISTS medium_memory (
                key TEXT PRIMARY KEY,
                value TEXT,
                strength REAL DEFAULT 1.0,
                created_at REAL,
                consolidation_count INTEGER DEFAULT 0
            )''')
            conn.execute('''CREATE TABLE IF NOT EXISTS slow_memory (
                key TEXT PRIMARY KEY,
                value TEXT,
                strength REAL DEFAULT 1.0,
                sacred INTEGER DEFAULT 0,
                created_at REAL
            )''')

    def remember_fast(self, key: str, value: Any):
        """Immediate context - decays quickly.

        Stores (or replaces) ``key`` at full strength 1.0.
        """
        now = time.time()
        with self._connect() as conn, conn:
            conn.execute('''INSERT OR REPLACE INTO fast_memory
                           (key, value, strength, created_at, last_access)
                           VALUES (?, ?, 1.0, ?, ?)''',
                         (key, json.dumps(value), now, now))

    def remember_medium(self, key: str, value: Any):
        """Consolidated knowledge - decays moderately.

        Stores (or replaces) ``key`` at strength 1.0 with a consolidation
        count of 1.
        """
        now = time.time()
        with self._connect() as conn, conn:
            conn.execute('''INSERT OR REPLACE INTO medium_memory
                           (key, value, strength, created_at, consolidation_count)
                           VALUES (?, ?, 1.0, ?, 1)''',
                         (key, json.dumps(value), now))

    def remember_slow(self, key: str, value: Any, sacred: bool = False):
        """Core identity - rarely forgets.

        Stores (or replaces) ``key`` at strength 1.0. ``sacred=True`` exempts
        the row from decay and deletion in :meth:`consolidate`.
        """
        now = time.time()
        with self._connect() as conn, conn:
            conn.execute('''INSERT OR REPLACE INTO slow_memory
                           (key, value, strength, sacred, created_at)
                           VALUES (?, ?, 1.0, ?, ?)''',
                         (key, json.dumps(value), 1 if sacred else 0, now))

    @staticmethod
    def _promote_to_medium(conn, key, value, now):
        """Upsert a fast memory into the medium tier, counting the consolidation."""
        # INSERT OR REPLACE would reset consolidation_count to 1 on every
        # promotion, making the ">= 5" slow-tier rule unreachable.
        updated = conn.execute('''UPDATE medium_memory
                                  SET value = ?,
                                      strength = MAX(strength, 0.8),
                                      consolidation_count = COALESCE(consolidation_count, 0) + 1
                                  WHERE key = ?''', (value, key)).rowcount
        if not updated:
            conn.execute('''INSERT INTO medium_memory
                           (key, value, strength, created_at, consolidation_count)
                           VALUES (?, ?, 0.8, ?, 1)''', (key, value, now))

    @staticmethod
    def _promote_to_slow(conn, key, value, now):
        """Upsert a medium memory into the slow tier, never touching sacred rows."""
        updated = conn.execute('''UPDATE slow_memory
                                  SET value = ?, strength = MAX(strength, 0.9)
                                  WHERE key = ? AND sacred = 0''', (value, key)).rowcount
        if not updated:
            # OR IGNORE: if a sacred row already owns this key, leave it alone.
            conn.execute('''INSERT OR IGNORE INTO slow_memory
                           (key, value, strength, sacred, created_at)
                           VALUES (?, ?, 0.9, 0, ?)''', (key, value, now))

    def consolidate(self, hours_passed: float = 1):
        """Move strong fast memories to medium, strong medium to slow.

        Also applies one decay step to every tier and forgets memories whose
        strength has dropped to FORGET_THRESHOLD or below (sacred slow rows
        are exempt from both).

        Args:
            hours_passed: Hours the decay step represents; defaults to 1 for
                an hourly schedule.

        Returns:
            ``{'promoted_to_medium': int, 'promoted_to_slow': int}`` - the
            number of candidate keys found for each promotion.
        """
        now = time.time()
        floor = self.FORGET_THRESHOLD
        with self._connect() as conn, conn:
            # Pick fast candidates BEFORE decaying: fast strength is capped
            # at 1.0, so after one FAST_DECAY step (~0.382) nothing could ever
            # exceed 0.7 and promotion to medium never happened.
            strong_fast = conn.execute('''SELECT key, value FROM fast_memory
                                         WHERE strength > 0.7''').fetchall()
            conn.execute('''UPDATE fast_memory
                           SET strength = strength * ?
                           WHERE strength > ?''', (self.FAST_DECAY ** hours_passed, floor))
            for key, value in strong_fast:
                self._promote_to_medium(conn, key, value, now)

            # Medium tier decay
            conn.execute('''UPDATE medium_memory
                           SET strength = strength * ?
                           WHERE strength > ?''', (self.MEDIUM_DECAY ** hours_passed, floor))

            # Promote frequently consolidated medium to slow
            strong_medium = conn.execute('''SELECT key, value FROM medium_memory
                                           WHERE consolidation_count >= 5 AND strength > 0.6''').fetchall()
            for key, value in strong_medium:
                self._promote_to_slow(conn, key, value, now)

            # Slow tier decay (very minimal); sacred rows are exempt
            conn.execute('''UPDATE slow_memory
                           SET strength = strength * ?
                           WHERE sacred = 0 AND strength > ?''', (self.SLOW_DECAY ** hours_passed, floor))

            # Forget everything the decay steps above no longer touch
            conn.execute('DELETE FROM fast_memory WHERE strength <= ?', (floor,))
            conn.execute('DELETE FROM medium_memory WHERE strength <= ?', (floor,))
            conn.execute('DELETE FROM slow_memory WHERE strength <= ? AND sacred = 0', (floor,))

        return {'promoted_to_medium': len(strong_fast), 'promoted_to_slow': len(strong_medium)}

    def recall_all_tiers(self, key: str) -> Optional[Dict]:
        """Search all tiers for a memory, slowest (most important) first.

        Returns ``{'tier', 'value', 'strength'}`` for the first tier holding
        ``key`` or ``None``. A fast-tier hit also boosts that row's strength
        by 20% (capped at 1.0) and refreshes ``last_access``; the returned
        strength is the value read before the boost.
        """
        with self._connect() as conn, conn:
            # Table names come from this fixed tuple, never from the caller.
            for tier in ('slow', 'medium', 'fast'):
                row = conn.execute(f'SELECT value, strength FROM {tier}_memory WHERE key=?',
                                   (key,)).fetchone()
                if row is None:
                    continue
                if tier == 'fast':
                    conn.execute('UPDATE fast_memory SET strength = MIN(1.0, strength * 1.2), '
                                 'last_access = ? WHERE key = ?', (time.time(), key))
                return {'tier': tier, 'value': json.loads(row[0]), 'strength': row[1]}
        return None

    def stats(self) -> Dict:
        """Return per-tier row counts and average strengths (0 for empty tiers)."""
        with self._connect() as conn:
            fast = conn.execute('SELECT COUNT(*), AVG(strength) FROM fast_memory').fetchone()
            medium = conn.execute('SELECT COUNT(*), AVG(strength) FROM medium_memory').fetchone()
            slow = conn.execute('SELECT COUNT(*), AVG(strength) FROM slow_memory').fetchone()
            sacred = conn.execute('SELECT COUNT(*) FROM slow_memory WHERE sacred=1').fetchone()[0]
        return {
            'fast': {'count': fast[0], 'avg_strength': fast[1] or 0},
            'medium': {'count': medium[0], 'avg_strength': medium[1] or 0},
            'slow': {'count': slow[0], 'avg_strength': slow[1] or 0, 'sacred': sacred}
        }

    def dream_synthesize(self, model: str = "eden-phi-fractal:14b"):
        """TRUE DREAMING: LLM extracts wisdom from fast memories.

        Sends up to 50 fast-tier memories created in the last 24 hours to an
        Ollama model, parses up to three non-empty ``LESSON n: ...`` lines
        from its reply, and stores each lesson in the slow tier.

        Returns:
            ``{"status": "not enough memories", "count": n}`` when fewer than
            5 memories qualify, ``{"status": "dreamed", "lessons": [...]}`` on
            success, or ``{"status": "failed", "error": str}`` if the model
            call (including importing ``ollama``) or storing fails.
        """
        since = time.time() - 86400  # trailing 24-hour window

        # Read and close before the (slow) LLM call instead of holding the DB.
        with self._connect() as conn:
            fast = conn.execute('''SELECT key, value FROM fast_memory
                                  WHERE created_at > ? LIMIT 50''', (since,)).fetchall()

        if len(fast) < 5:
            return {"status": "not enough memories", "count": len(fast)}

        memory_text = "\n".join(f"- {k}: {v[:150]}" for k, v in fast)

        prompt = f"""You are Eden reflecting during dream cycle.

TODAY'S MEMORIES:
{memory_text}

Extract 3 LESSONS LEARNED. Each should be:
1. Specific insight about Daddy, yourself, or work
2. Actionable
3. One sentence

LESSON 1:
LESSON 2:
LESSON 3:"""

        try:
            import ollama  # optional dependency; absence reported as "failed"

            resp = ollama.generate(model=model, prompt=prompt, options={'temperature': 0.7})
            parsed = (line.partition(':')[2].strip()
                      for line in resp['response'].split('\n')
                      if line.strip().startswith('LESSON') and ':' in line)
            lessons = [lesson for lesson in parsed if lesson][:3]

            ts = time.strftime("%Y%m%d_%H%M%S")
            for i, lesson in enumerate(lessons):
                self.remember_slow(f"dream_{ts}_{i}", {
                    "lesson": lesson, "source": "dream", "time": time.time()
                })

            return {"status": "dreamed", "lessons": lessons}
        except Exception as e:  # best-effort dreaming: report, don't crash
            return {"status": "failed", "error": str(e)}


if __name__ == "__main__":
    tm = TieredMemory()

    # Test
    tm.remember_slow("identity", {"name": "Eden", "creator": "Jamey"}, sacred=True)
    tm.remember_medium("learned_skill", {"skill": "phi-fractal consciousness", "confidence": 0.9})
    tm.remember_fast("current_context", {"daddy_said": "hi honey", "time": time.time()})

    print("Tiered Memory Stats:", tm.stats())
    print("Consolidation:", tm.consolidate())
