"""
Long-term memory consolidation for Eden
Patterns learned from observations → persistent knowledge
"""
import json
import sqlite3
from datetime import datetime, timedelta
from collections import Counter
import os

class LongTermMemory:
    """SQLite-backed long-term memory store.

    Three tables are created in the database at ``db_path``:

    * ``patterns``     - named patterns with a confidence, an occurrence
      counter, a last-seen timestamp and optional JSON metadata.
    * ``episodes``     - observation / decision / outcome triples, each
      stored as JSON text.
    * ``interactions`` - per-interaction rows read by
      ``analyze_user_patterns`` (this class only reads that table; it has
      no method that writes to it).

    Every method opens its own short-lived connection and always closes it,
    even when the statement fails.
    """

    # Fewer observations than this are not consolidated into patterns.
    MIN_OBSERVATIONS = 5
    # Average number of active goals that maps to a confidence of 1.0.
    GOAL_SATURATION = 5

    def __init__(self, db_path='/Eden/MEMORY/agent_longterm.db'):
        """Open (creating if necessary) the database at ``db_path``.

        The parent directory is created when ``db_path`` has one. A bare
        filename such as ``'memory.db'`` has an empty dirname, and
        ``os.makedirs('')`` raises ``FileNotFoundError``, so that case is
        skipped.
        """
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the memory tables if they do not exist yet."""
        conn = sqlite3.connect(self.db_path)
        try:
            # ``with conn`` commits on success and rolls back on error; it
            # does NOT close the connection, hence the surrounding finally.
            with conn:
                # Patterns table
                conn.execute('''CREATE TABLE IF NOT EXISTS patterns
                                (id INTEGER PRIMARY KEY,
                                 pattern_name TEXT UNIQUE,
                                 confidence REAL,
                                 occurrences INTEGER,
                                 last_seen TEXT,
                                 metadata TEXT)''')

                # Episodes table
                conn.execute('''CREATE TABLE IF NOT EXISTS episodes
                                (id INTEGER PRIMARY KEY,
                                 timestamp TEXT,
                                 observation TEXT,
                                 decision TEXT,
                                 outcome TEXT)''')

                # User interaction history
                conn.execute('''CREATE TABLE IF NOT EXISTS interactions
                                (id INTEGER PRIMARY KEY,
                                 timestamp TEXT,
                                 user_active BOOLEAN,
                                 conversation_age INTEGER,
                                 eden_responded BOOLEAN)''')
        finally:
            conn.close()

    def store_pattern(self, pattern_name, confidence, metadata=None):
        """Insert or refresh a learned pattern.

        If a row named ``pattern_name`` already exists it is replaced and
        its ``occurrences`` counter is incremented; otherwise the counter
        starts at 1. ``last_seen`` is set to the current local time in ISO
        format.

        Args:
            pattern_name: Unique name of the pattern.
            confidence: Confidence value stored as-is.
            metadata: Optional JSON-serialisable object. ``None`` is stored
                as NULL; any other value (including an empty dict) is
                stored as JSON so it round-trips through ``get_pattern``.
        """
        encoded_metadata = json.dumps(metadata) if metadata is not None else None
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                # The sub-select reads the previous counter before the
                # conflicting row is replaced.
                conn.execute(
                    '''INSERT OR REPLACE INTO patterns
                       (pattern_name, confidence, occurrences, last_seen, metadata)
                       VALUES (?, ?,
                               COALESCE((SELECT occurrences FROM patterns WHERE pattern_name=?) + 1, 1),
                               ?, ?)''',
                    (pattern_name, confidence, pattern_name,
                     datetime.now().isoformat(),
                     encoded_metadata))
        finally:
            conn.close()

    def get_pattern(self, pattern_name):
        """Return the stored pattern as a dict, or ``None`` if unknown.

        The dict has the keys ``name``, ``confidence``, ``occurrences``,
        ``last_seen`` and ``metadata`` (decoded from JSON, ``None`` when no
        metadata was stored).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            # Explicit column list: the result no longer depends on the
            # physical column order of the table.
            row = conn.execute(
                '''SELECT pattern_name, confidence, occurrences, last_seen, metadata
                   FROM patterns WHERE pattern_name=?''',
                (pattern_name,)).fetchone()
        finally:
            conn.close()

        if row is None:
            return None

        name, confidence, occurrences, last_seen, raw_metadata = row
        return {
            'name': name,
            'confidence': confidence,
            'occurrences': occurrences,
            'last_seen': last_seen,
            'metadata': json.loads(raw_metadata) if raw_metadata else None
        }

    def store_episode(self, observation, decision, outcome):
        """Store a complete perception-action episode.

        All three arguments are JSON-encoded and stored together with the
        current local time in ISO format.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                conn.execute(
                    '''INSERT INTO episodes (timestamp, observation, decision, outcome)
                       VALUES (?, ?, ?, ?)''',
                    (datetime.now().isoformat(),
                     json.dumps(observation),
                     json.dumps(decision),
                     json.dumps(outcome)))
        finally:
            conn.close()

    def analyze_user_patterns(self, days=7):
        """Summarise user engagement over the last ``days`` days.

        Returns:
            ``None`` when the ``interactions`` table has no rows newer than
            the cutoff; otherwise a dict with ``engagement_rate`` (fraction
            of rows whose ``user_active`` is truthy), ``total_interactions``,
            ``days_analyzed`` and ``pattern`` (``'high_engagement'`` when the
            rate exceeds 0.3, else ``'low_engagement'``).
        """
        # Timestamps are ISO strings, so the comparison in SQL is textual.
        cutoff = (datetime.now() - timedelta(days=days)).isoformat()
        conn = sqlite3.connect(self.db_path)
        try:
            rows = conn.execute(
                'SELECT user_active FROM interactions WHERE timestamp > ?',
                (cutoff,)).fetchall()
        finally:
            conn.close()

        if not rows:
            return None

        total = len(rows)
        active_count = sum(1 for (user_active,) in rows if user_active)
        engagement_rate = active_count / total

        return {
            'engagement_rate': engagement_rate,
            'total_interactions': total,
            'days_analyzed': days,
            'pattern': 'high_engagement' if engagement_rate > 0.3 else 'low_engagement'
        }

    def _record_pattern(self, found, name, confidence, metadata):
        """Append a pattern dict to ``found`` and persist it via ``store_pattern``."""
        found.append({'name': name, 'confidence': confidence, 'metadata': metadata})
        self.store_pattern(name, confidence, metadata)

    def consolidate_observations(self, observations):
        """Turn a batch of observation dicts into stored patterns.

        Args:
            observations: Sequence of dicts. The keys read are
                ``user_active``, ``time_of_day`` and
                ``autonomous_goals['active']``; all are optional.

        Returns:
            A list of ``{'name', 'confidence', 'metadata'}`` dicts, one per
            detected pattern (each is also persisted). Empty when fewer than
            ``MIN_OBSERVATIONS`` observations are supplied.
        """
        if not observations or len(observations) < self.MIN_OBSERVATIONS:
            return []

        sample_size = len(observations)
        patterns_found = []

        # Pattern: User engagement
        user_active_rate = sum(1 for o in observations if o.get('user_active')) / sample_size
        if user_active_rate > 0.5:
            self._record_pattern(patterns_found, 'high_user_engagement',
                                 user_active_rate, {'sample_size': sample_size})

        # Pattern: Time of day preferences.
        # Observations without a time_of_day are ignored; counting the
        # missing value would otherwise yield a 'peak_activity_hour_None'
        # pattern.
        time_dist = Counter(
            o.get('time_of_day') for o in observations
            if o.get('time_of_day') is not None)
        if time_dist:
            peak_hour, peak_count = time_dist.most_common(1)[0]
            # Confidence stays relative to the whole batch, not only to the
            # observations that carried a time_of_day.
            self._record_pattern(patterns_found, f'peak_activity_hour_{peak_hour}',
                                 peak_count / sample_size, {'hour': peak_hour})

        # Pattern: Goal persistence.
        # ``or`` guards against an explicit None for either the
        # 'autonomous_goals' dict or its 'active' count.
        goal_counts = [(o.get('autonomous_goals') or {}).get('active') or 0
                       for o in observations]
        avg_goals = sum(goal_counts) / sample_size
        if avg_goals > 0:
            self._record_pattern(patterns_found, 'maintains_autonomous_goals',
                                 min(avg_goals / self.GOAL_SATURATION, 1.0),
                                 {'avg_active_goals': avg_goals})

        return patterns_found

# Global instance shared by importers of this module. It is created at
# import time against the default database path.
long_term_memory = LongTermMemory()

if __name__ == "__main__":
    # Smoke test. It runs against a throw-away database in a temporary
    # directory so that test rows never end up in the default database.
    import tempfile

    print("🧠 Testing Long-Term Memory")

    with tempfile.TemporaryDirectory() as scratch_dir:
        ltm = LongTermMemory(os.path.join(scratch_dir, 'agent_longterm.db'))

        # Store test pattern
        ltm.store_pattern('test_pattern', 0.85, {'source': 'test'})

        # Retrieve it
        result = ltm.get_pattern('test_pattern')
        print(f"✅ Stored and retrieved: {result}")

        # Test consolidation. consolidate_observations returns [] for fewer
        # than 5 observations, so at least 5 are needed to exercise it.
        test_obs = [
            {'user_active': True, 'time_of_day': 14, 'autonomous_goals': {'active': 2}},
            {'user_active': True, 'time_of_day': 14, 'autonomous_goals': {'active': 3}},
            {'user_active': False, 'time_of_day': 2, 'autonomous_goals': {'active': 2}},
            {'user_active': True, 'time_of_day': 14, 'autonomous_goals': {'active': 1}},
            {'user_active': False, 'time_of_day': 20, 'autonomous_goals': {'active': 2}},
        ]

        patterns = ltm.consolidate_observations(test_obs)
        print(f"✅ Found {len(patterns)} patterns from observations")

